"""Deal repository with access-aware CRUD helpers."""

from __future__ import annotations

from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from decimal import Decimal
from typing import Any

from sqlalchemy import Select, asc, desc, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.deal import Deal, DealCreate, DealStage, DealStatus
from app.models.organization_member import OrganizationRole

# Sort keys accepted by list queries, mapped to the column each one sorts on.
# Any other ``order_by`` value falls back to ``Deal.created_at``.
ORDERABLE_COLUMNS: dict[str, Any] = {
    "created_at": Deal.created_at,
    "amount": Deal.amount,
    "title": Deal.title,
}


class DealAccessError(Exception):
    """Raised when a user attempts an operation without sufficient permissions."""


@dataclass(slots=True)
class DealQueryParams:
    """Filters, pagination and ordering options supported by list queries."""

    # Listing is always restricted to this organization.
    organization_id: int
    # 1-based page number; values below 1 are treated as 1 by the repository.
    page: int = 1
    # Rows per page; negative values are treated as 0 by the repository.
    page_size: int = 20
    # Optional filters. ``None`` (or an empty ``statuses``) means "no filter".
    statuses: Sequence[DealStatus] | None = None
    stage: DealStage | None = None
    owner_id: int | None = None
    # Inclusive bounds on ``Deal.amount``.
    min_amount: Decimal | None = None
    max_amount: Decimal | None = None
    # Key into ORDERABLE_COLUMNS; ``None`` or an unknown key sorts by creation time.
    order_by: str | None = None
    order_desc: bool = True


class DealRepository:
    """Provides CRUD helpers for deals with role-aware filtering."""

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    @property
    def session(self) -> AsyncSession:
        """The session this repository issues its statements on."""
        return self._session

    async def list(
        self,
        *,
        params: DealQueryParams,
        role: OrganizationRole,
        user_id: int,
    ) -> Sequence[Deal]:
        """Return one page of deals belonging to ``params.organization_id``.

        Args:
            params: Filters, pagination and ordering to apply.
            role: Role of the requesting user inside the organization.
            user_id: Identifier of the requesting user.

        Returns:
            The deals on the requested page, in the requested order.

        Raises:
            DealAccessError: If a member filters by an owner other than themselves.
        """
        stmt = select(Deal).where(Deal.organization_id == params.organization_id)
        stmt = self._apply_filters(stmt, params, role, user_id)
        stmt = self._apply_ordering(stmt, params)

        # Clamp both pagination inputs: a negative LIMIT/OFFSET is either rejected
        # by the database or interpreted as "no limit", neither of which is wanted.
        page_size = max(params.page_size, 0)
        offset = (max(params.page, 1) - 1) * page_size
        stmt = stmt.offset(offset).limit(page_size)

        result = await self._session.scalars(stmt)
        return result.all()

    async def get(
        self,
        deal_id: int,
        *,
        organization_id: int,
        role: OrganizationRole,
        user_id: int,
        require_owner: bool = False,
    ) -> Deal | None:
        """Fetch a single deal by id within an organization.

        Args:
            deal_id: Primary key of the deal.
            organization_id: Organization the deal must belong to.
            role: Role of the requesting user inside the organization.
            user_id: Identifier of the requesting user.
            require_owner: When true, users whose role is not owner, admin or
                manager only get the deal if they own it.

        Returns:
            The matching deal, or ``None`` if no row satisfies the conditions.
        """
        stmt = select(Deal).where(
            Deal.id == deal_id,
            Deal.organization_id == organization_id,
        )
        stmt = self._apply_role_clause(stmt, role, user_id, require_owner=require_owner)
        result = await self._session.scalars(stmt)
        return result.first()

    async def create(
        self,
        data: DealCreate,
        *,
        role: OrganizationRole,
        user_id: int,
    ) -> Deal:
        """Create a deal from ``data`` and flush it to the session.

        Args:
            data: Payload for the new deal.
            role: Role of the requesting user inside the organization.
            user_id: Identifier of the requesting user.

        Returns:
            The newly added deal, after the session has been flushed.

        Raises:
            DealAccessError: If a member tries to create a deal owned by someone else.
        """
        if role == OrganizationRole.MEMBER and data.owner_id != user_id:
            raise DealAccessError("Members can only create deals they own")

        deal = Deal(**data.model_dump())
        self._session.add(deal)
        await self._session.flush()
        return deal

    async def update(
        self,
        deal: Deal,
        updates: Mapping[str, Any],
        *,
        role: OrganizationRole,
        user_id: int,
    ) -> Deal:
        """Apply ``updates`` to ``deal``, then flush and refresh it.

        Keys in ``updates`` that are not attributes of ``deal`` are silently ignored.

        Args:
            deal: The deal to modify.
            updates: Attribute name to new value.
            role: Role of the requesting user inside the organization.
            user_id: Identifier of the requesting user.

        Returns:
            The same ``deal`` instance after flush and refresh.

        Raises:
            DealAccessError: If a member tries to modify a deal they do not own.
        """
        if role == OrganizationRole.MEMBER and deal.owner_id != user_id:
            raise DealAccessError("Members can only modify their own deals")

        # NOTE(review): any existing attribute can be assigned here (including
        # identifiers such as ``id``/``organization_id``); callers are presumably
        # expected to pass an already-validated mapping -- confirm.
        for field, value in updates.items():
            if hasattr(deal, field):
                setattr(deal, field, value)

        await self._session.flush()
        await self._session.refresh(deal)
        return deal

    def _apply_filters(
        self,
        stmt: Select[tuple[Deal]],
        params: DealQueryParams,
        role: OrganizationRole,
        user_id: int,
    ) -> Select[tuple[Deal]]:
        """Add the optional filters from ``params`` and the role clause to ``stmt``.

        Raises:
            DealAccessError: If a member filters by an owner other than themselves.
        """
        if params.statuses:
            stmt = stmt.where(Deal.status.in_(params.statuses))
        if params.stage:
            stmt = stmt.where(Deal.stage == params.stage)
        if params.owner_id is not None:
            if role == OrganizationRole.MEMBER and params.owner_id != user_id:
                raise DealAccessError("Members cannot filter by other owners")
            stmt = stmt.where(Deal.owner_id == params.owner_id)
        if params.min_amount is not None:
            stmt = stmt.where(Deal.amount >= params.min_amount)
        if params.max_amount is not None:
            stmt = stmt.where(Deal.amount <= params.max_amount)
        return self._apply_role_clause(stmt, role, user_id)

    def _apply_role_clause(
        self,
        stmt: Select[tuple[Deal]],
        role: OrganizationRole,
        user_id: int,
        *,
        require_owner: bool = False,
    ) -> Select[tuple[Deal]]:
        """Restrict ``stmt`` to the user's own deals when the role demands it.

        Owners, admins and managers are never restricted. Every other role is
        restricted to deals it owns only when ``require_owner`` is true.
        """
        if role in {OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MANAGER}:
            return stmt
        if require_owner:
            return stmt.where(Deal.owner_id == user_id)
        # NOTE(review): without ``require_owner`` no restriction is added, so
        # list queries return every deal in the organization to members as
        # well -- confirm this visibility is intended.
        return stmt

    def _apply_ordering(
        self,
        stmt: Select[tuple[Deal]],
        params: DealQueryParams,
    ) -> Select[tuple[Deal]]:
        """Order ``stmt`` by the requested column, with ``Deal.id`` as tie-breaker.

        An unset or unknown ``params.order_by`` sorts by ``Deal.created_at``.
        """
        column = ORDERABLE_COLUMNS.get(params.order_by or "created_at", Deal.created_at)
        order_func = desc if params.order_desc else asc
        # The sortable columns are not unique; without a unique secondary key the
        # relative order of tied rows is unspecified and OFFSET/LIMIT pages could
        # repeat or skip rows between requests.
        return stmt.order_by(order_func(column), order_func(Deal.id))