test_task_crm/app/repositories/deal_repo.py

158 lines
4.9 KiB
Python

"""Deal repository with access-aware CRUD helpers."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from decimal import Decimal
from typing import Any
from sqlalchemy import Select, asc, desc, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.deal import Deal, DealCreate, DealStage, DealStatus
from app.models.organization_member import OrganizationRole
# Whitelist of sort keys accepted from callers, mapped to the Deal column each
# one orders by. Built from the attribute names so key and column cannot drift.
ORDERABLE_COLUMNS: dict[str, Any] = {
    name: getattr(Deal, name) for name in ("created_at", "amount", "title")
}
class DealAccessError(Exception):
    """Signal that the acting user lacks permission for a deal operation."""
@dataclass(slots=True)
class DealQueryParams:
"""Filters supported by list queries."""
organization_id: int
page: int = 1
page_size: int = 20
statuses: Sequence[DealStatus] | None = None
stage: DealStage | None = None
owner_id: int | None = None
min_amount: Decimal | None = None
max_amount: Decimal | None = None
order_by: str | None = None
order_desc: bool = True
class DealRepository:
    """Provide CRUD helpers for deals with role-aware access checks.

    Every query is scoped to a single organization. OWNER, ADMIN and MANAGER
    roles are never restricted by deal ownership; any other role is subject
    to the ownership rules documented on the individual methods. The
    repository only flushes the session; it never commits.
    """

    # Fields ``update`` never overwrites: changing them would re-key the row
    # or move it out of the organization that every query here is scoped to.
    _IMMUTABLE_FIELDS = frozenset({"id", "organization_id"})

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to *session*."""
        self._session = session

    @property
    def session(self) -> AsyncSession:
        """Return the session this repository operates on."""
        return self._session

    async def list(
        self,
        *,
        params: DealQueryParams,
        role: OrganizationRole,
        user_id: int,
    ) -> Sequence[Deal]:
        """Return one page of the organization's deals matching *params*.

        Args:
            params: Organization scope, filters, ordering and pagination.
            role: Role of the acting user inside the organization.
            user_id: Identifier of the acting user.

        Returns:
            The deals on the requested page, in the requested order.

        Raises:
            DealAccessError: If a member filters by an owner other than
                themselves.
        """
        stmt = select(Deal).where(Deal.organization_id == params.organization_id)
        stmt = self._apply_filters(stmt, params, role, user_id)
        stmt = self._apply_ordering(stmt, params)
        # Clamp both values: a page or page size below 1 would otherwise
        # yield a negative OFFSET/LIMIT, which databases reject.
        page = max(params.page, 1)
        page_size = max(params.page_size, 1)
        stmt = stmt.offset((page - 1) * page_size).limit(page_size)
        result = await self._session.scalars(stmt)
        return result.all()

    async def get(
        self,
        deal_id: int,
        *,
        organization_id: int,
        role: OrganizationRole,
        user_id: int,
        require_owner: bool = False,
    ) -> Deal | None:
        """Fetch a single deal of *organization_id* by primary key.

        Args:
            deal_id: Primary key of the deal.
            organization_id: Organization the deal must belong to.
            role: Role of the acting user inside the organization.
            user_id: Identifier of the acting user.
            require_owner: When true, a user whose role is not OWNER, ADMIN
                or MANAGER only gets the deal if they own it.

        Returns:
            The deal, or ``None`` when no row satisfies all conditions.
        """
        stmt = select(Deal).where(
            Deal.id == deal_id,
            Deal.organization_id == organization_id,
        )
        stmt = self._apply_role_clause(stmt, role, user_id, require_owner=require_owner)
        result = await self._session.scalars(stmt)
        return result.first()

    async def create(
        self,
        data: DealCreate,
        *,
        role: OrganizationRole,
        user_id: int,
    ) -> Deal:
        """Add a deal built from *data* to the session and flush it.

        Args:
            data: Payload whose dumped fields become the ``Deal`` attributes.
            role: Role of the acting user inside the organization.
            user_id: Identifier of the acting user.

        Returns:
            The newly added deal.

        Raises:
            DealAccessError: If a member tries to create a deal owned by
                another user.
        """
        if role == OrganizationRole.MEMBER and data.owner_id != user_id:
            raise DealAccessError("Members can only create deals they own")
        deal = Deal(**data.model_dump())
        self._session.add(deal)
        await self._session.flush()
        return deal

    async def update(
        self,
        deal: Deal,
        updates: Mapping[str, Any],
        *,
        role: OrganizationRole,
        user_id: int,
    ) -> Deal:
        """Apply *updates* to *deal*, then flush and refresh it.

        Keys that are not attributes of the deal are ignored, as are
        underscore-prefixed names and the fields in ``_IMMUTABLE_FIELDS``.

        Args:
            deal: The deal instance to modify.
            updates: Mapping of attribute name to new value.
            role: Role of the acting user inside the organization.
            user_id: Identifier of the acting user.

        Returns:
            The same deal instance after the refresh.

        Raises:
            DealAccessError: If a member modifies a deal they do not own, or
                tries to hand one of their deals over to another user.
        """
        if role == OrganizationRole.MEMBER:
            if deal.owner_id != user_id:
                raise DealAccessError("Members can only modify their own deals")
            # Mirror the rule enforced by ``create``: a member may only own
            # deals themselves. Checked before any attribute is touched so a
            # rejected request leaves the instance unmodified.
            if "owner_id" in updates and updates["owner_id"] != user_id:
                raise DealAccessError("Members cannot reassign deals to other owners")
        for field, value in updates.items():
            # ``hasattr`` alone also accepts ORM bookkeeping attributes and
            # the identity/tenant columns; overwriting those corrupts the
            # instance or moves the row across organizations.
            if field.startswith("_") or field in self._IMMUTABLE_FIELDS:
                continue
            if hasattr(deal, field):
                setattr(deal, field, value)
        await self._session.flush()
        await self._session.refresh(deal)
        return deal

    def _apply_filters(
        self,
        stmt: Select[tuple[Deal]],
        params: DealQueryParams,
        role: OrganizationRole,
        user_id: int,
    ) -> Select[tuple[Deal]]:
        """Add the optional filters from *params*, then the role clause.

        Raises:
            DealAccessError: If a member filters by an owner other than
                themselves.
        """
        # An empty ``statuses`` sequence is treated like ``None``: no filter.
        if params.statuses:
            stmt = stmt.where(Deal.status.in_(params.statuses))
        # Compare against None so a falsy enum member is still applied.
        if params.stage is not None:
            stmt = stmt.where(Deal.stage == params.stage)
        if params.owner_id is not None:
            if role == OrganizationRole.MEMBER and params.owner_id != user_id:
                raise DealAccessError("Members cannot filter by other owners")
            stmt = stmt.where(Deal.owner_id == params.owner_id)
        if params.min_amount is not None:
            stmt = stmt.where(Deal.amount >= params.min_amount)
        if params.max_amount is not None:
            stmt = stmt.where(Deal.amount <= params.max_amount)
        return self._apply_role_clause(stmt, role, user_id)

    def _apply_role_clause(
        self,
        stmt: Select[tuple[Deal]],
        role: OrganizationRole,
        user_id: int,
        *,
        require_owner: bool = False,
    ) -> Select[tuple[Deal]]:
        """Restrict *stmt* to deals owned by *user_id* when the role demands it.

        The restriction is added only when ``require_owner`` is true and the
        role is not OWNER, ADMIN or MANAGER; otherwise *stmt* is returned
        unchanged.
        """
        # NOTE(review): with require_owner=False (the listing path) a member
        # is not restricted to their own deals, although ``_apply_filters``
        # forbids members from filtering by another owner - confirm intent.
        privileged = role in {
            OrganizationRole.OWNER,
            OrganizationRole.ADMIN,
            OrganizationRole.MANAGER,
        }
        if require_owner and not privileged:
            return stmt.where(Deal.owner_id == user_id)
        return stmt

    def _apply_ordering(
        self,
        stmt: Select[tuple[Deal]],
        params: DealQueryParams,
    ) -> Select[tuple[Deal]]:
        """Order *stmt* by the requested column, falling back to creation time.

        ``params.order_by`` is looked up in ``ORDERABLE_COLUMNS``; ``None`` or
        an unknown key sorts by ``Deal.created_at``.
        """
        column = ORDERABLE_COLUMNS.get(params.order_by or "created_at", Deal.created_at)
        direction = desc if params.order_desc else asc
        # The sortable columns are not unique, so rows with equal values
        # could change places between page requests. The primary key as a
        # secondary sort key makes OFFSET/LIMIT pagination deterministic.
        return stmt.order_by(direction(column), direction(Deal.id))