test_task_crm/app/repositories/task_repo.py

127 lines
4.2 KiB
Python

"""Task repository providing role-aware CRUD helpers."""
from __future__ import annotations

from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Any

from sqlalchemy import Select, select
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.deal import Deal
from app.models.organization_member import OrganizationRole
from app.models.task import Task, TaskCreate
class TaskAccessError(Exception):
    """Signal that the acting user's role forbids touching a task.

    The task repository raises it when a user with the ``MEMBER`` role tries
    to create or modify a task on a deal owned by somebody else.
    """
class TaskOrganizationMismatchError(Exception):
    """Signal that a task or deal lies outside the caller's organization.

    The task repository also raises it when the referenced deal does not
    exist, or when the owner of a task's deal cannot be resolved.
    """
@dataclass(slots=True)
class TaskQueryParams:
"""Filtering options supported by list queries."""
organization_id: int
deal_id: int | None = None
only_open: bool = False
due_before: datetime | None = None
due_after: datetime | None = None
class TaskRepository:
    """Encapsulate database access for :class:`Task` entities.

    Every query is scoped to one organization by joining each task to its
    parent :class:`Deal`. Write helpers additionally enforce that users with
    the ``MEMBER`` role only act on tasks whose deal they own; other roles
    are not restricted by this class.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    @property
    def session(self) -> AsyncSession:
        """Return the async session this repository issues queries on."""
        return self._session

    async def list(self, *, params: TaskQueryParams) -> Sequence[Task]:
        """Return the tasks of ``params.organization_id`` matching ``params``.

        Results are ordered by due date with undated tasks last, then by id
        for a stable order. Each task's ``deal`` is eagerly loaded.
        """
        stmt = (
            select(Task)
            .join(Deal, Deal.id == Task.deal_id)
            .where(Deal.organization_id == params.organization_id)
            .options(selectinload(Task.deal))
            # ``due_date IS NULL`` is false for dated tasks, so they sort
            # ahead of undated ones (false orders before true).
            .order_by(Task.due_date.is_(None), Task.due_date, Task.id)
        )
        stmt = self._apply_filters(stmt, params)
        result = await self._session.scalars(stmt)
        return result.all()

    async def get(self, task_id: int, *, organization_id: int) -> Task | None:
        """Return task ``task_id`` if its deal belongs to ``organization_id``.

        Returns ``None`` both for unknown ids and for tasks of another
        organization, so the two cases look the same to callers. The task's
        ``deal`` is eagerly loaded.
        """
        stmt = (
            select(Task)
            .join(Deal, Deal.id == Task.deal_id)
            .where(Task.id == task_id, Deal.organization_id == organization_id)
            .options(selectinload(Task.deal))
        )
        result = await self._session.scalars(stmt)
        return result.first()

    async def create(
        self,
        data: TaskCreate,
        *,
        organization_id: int,
        role: OrganizationRole,
        user_id: int,
    ) -> Task:
        """Create a task on deal ``data.deal_id`` and flush it.

        Raises:
            TaskOrganizationMismatchError: the deal does not exist or belongs
                to another organization (both cases raise the same error and
                message).
            TaskAccessError: a ``MEMBER`` targets a deal they do not own.
        """
        deal = await self._session.get(Deal, data.deal_id)
        if deal is None or deal.organization_id != organization_id:
            raise TaskOrganizationMismatchError("Deal belongs to another organization")
        if role == OrganizationRole.MEMBER and deal.owner_id != user_id:
            raise TaskAccessError("Members can only create tasks for their own deals")
        task = Task(**data.model_dump())
        self._session.add(task)
        # Flush, not commit: database-generated values (e.g. the primary key)
        # are populated while the transaction stays under the caller's control.
        await self._session.flush()
        return task

    async def update(
        self,
        task: Task,
        updates: Mapping[str, Any],
        *,
        role: OrganizationRole,
        user_id: int,
    ) -> Task:
        """Apply ``updates`` to ``task`` after enforcing ownership rules.

        Only mapped, non-primary-key column attributes of :class:`Task` are
        written; any other key in ``updates`` is ignored. Changing
        ``deal_id`` moves the task to another deal. This is only allowed
        within the task's current organization and, for members, only onto a
        deal they own. These are the same rules :meth:`create` applies.

        Raises:
            TaskOrganizationMismatchError: the owner of the task's current
                deal cannot be resolved, or the target deal is missing or
                belongs to another organization.
            TaskAccessError: a ``MEMBER`` edits a task whose deal they do not
                own, or moves it onto a deal they do not own.
        """
        owner_id = await self._resolve_task_owner(task)
        if owner_id is None:
            raise TaskOrganizationMismatchError("Task is missing an owner context")
        if role == OrganizationRole.MEMBER and owner_id != user_id:
            raise TaskAccessError("Members can only modify their own tasks")

        # Restrict writes to real columns. Blindly setattr-ing arbitrary keys
        # would let callers overwrite the primary key, relationships or any
        # other attribute of the ORM object.
        writable = self._writable_fields()
        changes = {
            field: value for field, value in updates.items() if field in writable
        }

        target_deal: Deal | None = None
        if "deal_id" in changes and changes["deal_id"] != task.deal_id:
            # Re-parenting must repeat the checks ``create`` performs;
            # otherwise a task could be moved into another organization or
            # onto a deal a member does not own.
            target_deal = await self._load_move_target(task, changes["deal_id"])
            if role == OrganizationRole.MEMBER and target_deal.owner_id != user_id:
                raise TaskAccessError("Members can only move tasks to their own deals")

        for field, value in changes.items():
            setattr(task, field, value)
        if target_deal is not None and self._deal_is_loaded(task):
            # Setting the FK column does not refresh an already-loaded
            # relationship; re-link it so ``task.deal`` is not stale.
            task.deal = target_deal
        await self._session.flush()
        return task

    def _apply_filters(
        self, stmt: Select[tuple[Task]], params: TaskQueryParams
    ) -> Select[tuple[Task]]:
        """Narrow ``stmt`` with the optional criteria set on ``params``.

        Due-date bounds are inclusive on both ends.
        """
        if params.deal_id is not None:
            stmt = stmt.where(Task.deal_id == params.deal_id)
        if params.only_open:
            stmt = stmt.where(Task.is_done.is_(False))
        if params.due_before is not None:
            stmt = stmt.where(Task.due_date <= params.due_before)
        if params.due_after is not None:
            stmt = stmt.where(Task.due_date >= params.due_after)
        return stmt

    @staticmethod
    def _deal_is_loaded(task: Task) -> bool:
        """Return True when ``task.deal`` can be read without emitting SQL."""
        # Touching an unloaded relationship triggers a lazy load, which is
        # implicit IO and fails under AsyncSession.
        return "deal" not in sa_inspect(task).unloaded

    @staticmethod
    def _writable_fields() -> frozenset[str]:
        """Return the keys of Task's mapped columns that are not primary keys."""
        mapper = sa_inspect(Task)
        return frozenset(
            prop.key
            for prop in mapper.column_attrs
            if not any(column.primary_key for column in prop.columns)
        )

    async def _load_move_target(self, task: Task, deal_id: int | None) -> Deal:
        """Return deal ``deal_id`` if it shares ``task``'s current organization.

        Raises:
            TaskOrganizationMismatchError: ``deal_id`` is ``None``, unknown,
                or refers to a deal of another organization.
        """
        if self._deal_is_loaded(task) and task.deal is not None:
            current_org_id = task.deal.organization_id
        else:
            current_org_id = await self._session.scalar(
                select(Deal.organization_id).where(Deal.id == task.deal_id)
            )
        target = None if deal_id is None else await self._session.get(Deal, deal_id)
        if (
            target is None
            or current_org_id is None
            or target.organization_id != current_org_id
        ):
            raise TaskOrganizationMismatchError("Deal belongs to another organization")
        return target

    async def _resolve_task_owner(self, task: Task) -> int | None:
        """Return the owner id of ``task``'s deal, or ``None`` if unresolvable.

        Uses the already-loaded relationship when available; otherwise it
        queries by ``task.deal_id`` instead of lazy-loading ``task.deal``.
        """
        if self._deal_is_loaded(task) and task.deal is not None:
            return task.deal.owner_id
        stmt = select(Deal.owner_id).where(Deal.id == task.deal_id)
        return await self._session.scalar(stmt)