"""Task repository providing role-aware CRUD helpers.""" from __future__ import annotations from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import datetime from typing import Any from sqlalchemy import Select, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from app.models.deal import Deal from app.models.organization_member import OrganizationRole from app.models.task import Task, TaskCreate class TaskAccessError(Exception): """Raised when a user attempts to modify a forbidden task.""" class TaskOrganizationMismatchError(Exception): """Raised when a task or deal belongs to another organization.""" @dataclass(slots=True) class TaskQueryParams: """Filtering options supported by list queries.""" organization_id: int deal_id: int | None = None only_open: bool = False due_before: datetime | None = None due_after: datetime | None = None class TaskRepository: """Encapsulates database access for Task entities.""" def __init__(self, session: AsyncSession) -> None: self._session = session @property def session(self) -> AsyncSession: return self._session async def list(self, *, params: TaskQueryParams) -> Sequence[Task]: stmt = ( select(Task) .join(Deal, Deal.id == Task.deal_id) .where(Deal.organization_id == params.organization_id) .options(selectinload(Task.deal)) .order_by(Task.due_date.is_(None), Task.due_date, Task.id) ) stmt = self._apply_filters(stmt, params) result = await self._session.scalars(stmt) return result.all() async def get(self, task_id: int, *, organization_id: int) -> Task | None: stmt = ( select(Task) .join(Deal, Deal.id == Task.deal_id) .where(Task.id == task_id, Deal.organization_id == organization_id) .options(selectinload(Task.deal)) ) result = await self._session.scalars(stmt) return result.first() async def create( self, data: TaskCreate, *, organization_id: int, role: OrganizationRole, user_id: int, ) -> Task: deal = await self._session.get(Deal, data.deal_id) if deal is None or deal.organization_id != organization_id: raise TaskOrganizationMismatchError("Deal belongs to another organization") if role == OrganizationRole.MEMBER and deal.owner_id != user_id: raise TaskAccessError("Members can only create tasks for their own deals") task = Task(**data.model_dump()) self._session.add(task) await self._session.flush() return task async def update( self, task: Task, updates: Mapping[str, Any], *, role: OrganizationRole, user_id: int, ) -> Task: owner_id = await self._resolve_task_owner(task) if owner_id is None: raise TaskOrganizationMismatchError("Task is missing an owner context") if role == OrganizationRole.MEMBER and owner_id != user_id: raise TaskAccessError("Members can only modify their own tasks") for field, value in updates.items(): if hasattr(task, field): setattr(task, field, value) await self._session.flush() return task def _apply_filters( self, stmt: Select[tuple[Task]], params: TaskQueryParams ) -> Select[tuple[Task]]: if params.deal_id is not None: stmt = stmt.where(Task.deal_id == params.deal_id) if params.only_open: stmt = stmt.where(Task.is_done.is_(False)) if params.due_before is not None: stmt = stmt.where(Task.due_date <= params.due_before) if params.due_after is not None: stmt = stmt.where(Task.due_date >= params.due_after) return stmt async def _resolve_task_owner(self, task: Task) -> int | None: if task.deal is not None: return task.deal.owner_id stmt = select(Deal.owner_id).where(Deal.id == task.deal_id) return await self._session.scalar(stmt)