73 lines
2.1 KiB
Python
73 lines
2.1 KiB
Python
"""Repository helpers for deal activities."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Sequence
|
|
from dataclasses import dataclass
|
|
|
|
from sqlalchemy import Select, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.activity import Activity, ActivityCreate
|
|
from app.models.deal import Deal
|
|
|
|
|
|
class ActivityOrganizationMismatchError(Exception):
    """Signal that an activity's deal is missing or owned by a different organization.

    ``ActivityRepository.create`` raises this for both cases with the same
    message, so callers cannot distinguish a nonexistent deal from a foreign one.
    """
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class ActivityQueryParams:
|
|
"""Filtering options for fetching activities."""
|
|
|
|
organization_id: int
|
|
deal_id: int
|
|
limit: int | None = None
|
|
offset: int = 0
|
|
|
|
|
|
class ActivityRepository:
    """CRUD helpers for the ``Activity`` model.

    Reads and writes are scoped to an organization through the owning
    ``Deal``. The repository flushes but never commits, so transaction
    control stays with the caller that owns the session.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to an async SQLAlchemy session.

        Args:
            session: Session used for every query and insert.
        """
        self._session = session

    @property
    def session(self) -> AsyncSession:
        """The ``AsyncSession`` this repository operates on (read-only)."""
        return self._session

    # NOTE: this method is named ``list`` and therefore shadows the builtin
    # inside the class body; avoid using the builtin ``list`` below it here.
    async def list(self, *, params: ActivityQueryParams) -> Sequence[Activity]:
        """Return a deal's activities, oldest first.

        Only activities whose deal belongs to ``params.organization_id`` are
        returned, so a deal from another organization yields an empty result.

        Args:
            params: Deal/organization filter plus optional offset/limit window.

        Returns:
            Activities ordered by ``created_at``, ties broken by ``id``.
        """
        stmt = (
            select(Activity)
            .join(Deal, Deal.id == Activity.deal_id)
            .where(
                Activity.deal_id == params.deal_id,
                Deal.organization_id == params.organization_id,
            )
            # ``created_at`` alone is not unique; without a unique tie-breaker
            # the database may order equal timestamps differently per query,
            # making OFFSET/LIMIT pages overlap or skip rows.
            .order_by(Activity.created_at, Activity.id)
        )
        stmt = self._apply_window(stmt, params)
        result = await self._session.scalars(stmt)
        return result.all()

    async def create(self, data: ActivityCreate, *, organization_id: int) -> Activity:
        """Insert a new activity on a deal owned by ``organization_id``.

        The new row is flushed (not committed), so server-generated values
        such as the primary key are available on the returned object.

        Args:
            data: Validated creation payload; its ``deal_id`` selects the deal.
            organization_id: Organization the caller acts on behalf of.

        Returns:
            The newly added, flushed ``Activity`` instance.

        Raises:
            ActivityOrganizationMismatchError: If the deal does not exist or
                belongs to another organization. Both cases raise the same
                error (presumably so other organizations' deals are not
                revealed — confirm with API owners).
        """
        deal = await self._session.get(Deal, data.deal_id)
        if deal is None or deal.organization_id != organization_id:
            raise ActivityOrganizationMismatchError("Deal belongs to another organization")

        activity = Activity(**data.model_dump())
        self._session.add(activity)
        await self._session.flush()
        return activity

    def _apply_window(
        self,
        stmt: Select[tuple[Activity]],
        params: ActivityQueryParams,
    ) -> Select[tuple[Activity]]:
        """Apply ``params.offset`` and ``params.limit`` to ``stmt``.

        A zero offset and a ``None`` limit are skipped, so no redundant
        OFFSET/LIMIT clause is emitted.

        Args:
            stmt: Select statement to paginate.
            params: Source of the offset/limit values.

        Returns:
            The (possibly) windowed statement.
        """
        if params.offset:
            stmt = stmt.offset(params.offset)
        if params.limit is not None:
            stmt = stmt.limit(params.limit)
        return stmt
|