deals&activities #4

Merged
k1nq merged 6 commits from deals&activities into dev 2025-11-27 11:18:51 +00:00
16 changed files with 1009 additions and 42 deletions

View File

@ -2,7 +2,7 @@
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
import jwt import jwt
from fastapi import Depends, HTTPException, status from fastapi import Depends, Header, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -10,9 +10,17 @@ from app.core.config import settings
from app.core.database import get_session from app.core.database import get_session
from app.core.security import jwt_service, password_hasher from app.core.security import jwt_service, password_hasher
from app.models.user import User from app.models.user import User
from app.repositories.deal_repo import DealRepository
from app.repositories.org_repo import OrganizationRepository from app.repositories.org_repo import OrganizationRepository
from app.repositories.user_repo import UserRepository from app.repositories.user_repo import UserRepository
from app.services.auth_service import AuthService from app.services.auth_service import AuthService
from app.services.deal_service import DealService
from app.services.organization_service import (
OrganizationAccessDeniedError,
OrganizationContext,
OrganizationContextMissingError,
OrganizationService,
)
from app.services.user_service import UserService from app.services.user_service import UserService
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/token")
@ -32,6 +40,14 @@ def get_organization_repository(session: AsyncSession = Depends(get_db_session))
return OrganizationRepository(session=session) return OrganizationRepository(session=session)
def get_deal_repository(session: AsyncSession = Depends(get_db_session)) -> DealRepository:
return DealRepository(session=session)
def get_deal_service(repo: DealRepository = Depends(get_deal_repository)) -> DealService:
return DealService(repository=repo)
def get_user_service(repo: UserRepository = Depends(get_user_repository)) -> UserService: def get_user_service(repo: UserRepository = Depends(get_user_repository)) -> UserService:
return UserService(user_repository=repo, password_hasher=password_hasher) return UserService(user_repository=repo, password_hasher=password_hasher)
@ -46,6 +62,12 @@ def get_auth_service(
) )
def get_organization_service(
repo: OrganizationRepository = Depends(get_organization_repository),
) -> OrganizationService:
return OrganizationService(repository=repo)
async def get_current_user( async def get_current_user(
token: str = Depends(oauth2_scheme), token: str = Depends(oauth2_scheme),
repo: UserRepository = Depends(get_user_repository), repo: UserRepository = Depends(get_user_repository),
@ -67,3 +89,16 @@ async def get_current_user(
if user is None: if user is None:
raise credentials_exception raise credentials_exception
return user return user
async def get_organization_context(
x_organization_id: int | None = Header(default=None, alias="X-Organization-Id"),
current_user: User = Depends(get_current_user),
service: OrganizationService = Depends(get_organization_service),
) -> OrganizationContext:
try:
return await service.get_context(user_id=current_user.id, organization_id=x_organization_id)
except OrganizationContextMissingError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
except OrganizationAccessDeniedError as exc:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc

View File

@ -1,7 +1,10 @@
"""Activity timeline API stubs.""" """Activity timeline API stubs."""
from __future__ import annotations from __future__ import annotations
from fastapi import APIRouter, status from fastapi import APIRouter, Depends, status
from app.api.deps import get_organization_context
from app.services.organization_service import OrganizationContext
from .models import ActivityCommentPayload from .models import ActivityCommentPayload
@ -13,14 +16,21 @@ def _stub(endpoint: str) -> dict[str, str]:
@router.get("/", status_code=status.HTTP_501_NOT_IMPLEMENTED) @router.get("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
async def list_activities(deal_id: int) -> dict[str, str]: async def list_activities(
deal_id: int,
context: OrganizationContext = Depends(get_organization_context),
) -> dict[str, str]:
"""Placeholder for listing deal activities.""" """Placeholder for listing deal activities."""
_ = deal_id _ = (deal_id, context)
return _stub("GET /deals/{deal_id}/activities") return _stub("GET /deals/{deal_id}/activities")
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED) @router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
async def create_activity_comment(deal_id: int, payload: ActivityCommentPayload) -> dict[str, str]: async def create_activity_comment(
deal_id: int,
payload: ActivityCommentPayload,
context: OrganizationContext = Depends(get_organization_context),
) -> dict[str, str]:
"""Placeholder for adding a comment activity to a deal.""" """Placeholder for adding a comment activity to a deal."""
_ = (deal_id, payload) _ = (deal_id, payload, context)
return _stub("POST /deals/{deal_id}/activities") return _stub("POST /deals/{deal_id}/activities")

View File

@ -1,7 +1,10 @@
"""Analytics API stubs (deal summary and funnel).""" """Analytics API stubs (deal summary and funnel)."""
from __future__ import annotations from __future__ import annotations
from fastapi import APIRouter, Query, status from fastapi import APIRouter, Depends, Query, status
from app.api.deps import get_organization_context
from app.services.organization_service import OrganizationContext
router = APIRouter(prefix="/analytics", tags=["analytics"]) router = APIRouter(prefix="/analytics", tags=["analytics"])
@ -11,13 +14,19 @@ def _stub(endpoint: str) -> dict[str, str]:
@router.get("/deals/summary", status_code=status.HTTP_501_NOT_IMPLEMENTED) @router.get("/deals/summary", status_code=status.HTTP_501_NOT_IMPLEMENTED)
async def deals_summary(days: int = Query(30, ge=1, le=180)) -> dict[str, str]: async def deals_summary(
days: int = Query(30, ge=1, le=180),
context: OrganizationContext = Depends(get_organization_context),
) -> dict[str, str]:
"""Placeholder for aggregated deal statistics.""" """Placeholder for aggregated deal statistics."""
_ = days _ = (days, context)
return _stub("GET /analytics/deals/summary") return _stub("GET /analytics/deals/summary")
@router.get("/deals/funnel", status_code=status.HTTP_501_NOT_IMPLEMENTED) @router.get("/deals/funnel", status_code=status.HTTP_501_NOT_IMPLEMENTED)
async def deals_funnel() -> dict[str, str]: async def deals_funnel(
context: OrganizationContext = Depends(get_organization_context),
) -> dict[str, str]:
"""Placeholder for funnel analytics.""" """Placeholder for funnel analytics."""
_ = context
return _stub("GET /analytics/deals/funnel") return _stub("GET /analytics/deals/funnel")

View File

@ -1,7 +1,10 @@
"""Contact API stubs required by the spec.""" """Contact API stubs required by the spec."""
from __future__ import annotations from __future__ import annotations
from fastapi import APIRouter, Query, status from fastapi import APIRouter, Depends, Query, status
from app.api.deps import get_organization_context
from app.services.organization_service import OrganizationContext
from .models import ContactCreatePayload from .models import ContactCreatePayload
@ -18,13 +21,18 @@ async def list_contacts(
page_size: int = Query(20, ge=1, le=100), page_size: int = Query(20, ge=1, le=100),
search: str | None = None, search: str | None = None,
owner_id: int | None = None, owner_id: int | None = None,
context: OrganizationContext = Depends(get_organization_context),
) -> dict[str, str]: ) -> dict[str, str]:
"""Placeholder list endpoint supporting the required filters.""" """Placeholder list endpoint supporting the required filters."""
_ = context
return _stub("GET /contacts") return _stub("GET /contacts")
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED) @router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
async def create_contact(payload: ContactCreatePayload) -> dict[str, str]: async def create_contact(
payload: ContactCreatePayload,
context: OrganizationContext = Depends(get_organization_context),
) -> dict[str, str]:
"""Placeholder for creating a contact within the current organization.""" """Placeholder for creating a contact within the current organization."""
_ = payload _ = (payload, context)
return _stub("POST /contacts") return _stub("POST /contacts")

View File

@ -5,16 +5,29 @@ from decimal import Decimal
from pydantic import BaseModel from pydantic import BaseModel
from app.models.deal import DealCreate, DealStage, DealStatus
class DealCreatePayload(BaseModel): class DealCreatePayload(BaseModel):
contact_id: int contact_id: int
title: str title: str
amount: Decimal | None = None amount: Decimal | None = None
currency: str | None = None currency: str | None = None
owner_id: int | None = None
def to_domain(self, *, organization_id: int, fallback_owner: int) -> DealCreate:
return DealCreate(
organization_id=organization_id,
contact_id=self.contact_id,
owner_id=self.owner_id or fallback_owner,
title=self.title,
amount=self.amount,
currency=self.currency,
)
class DealUpdatePayload(BaseModel): class DealUpdatePayload(BaseModel):
status: str | None = None status: DealStatus | None = None
stage: str | None = None stage: DealStage | None = None
amount: Decimal | None = None amount: Decimal | None = None
currency: str | None = None currency: str | None = None

View File

@ -1,20 +1,27 @@
"""Deal API stubs covering list/create/update operations.""" """Deal API endpoints backed by DealService."""
from __future__ import annotations from __future__ import annotations
from decimal import Decimal from decimal import Decimal
from fastapi import APIRouter, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.api.deps import get_deal_repository, get_deal_service, get_organization_context
from app.models.deal import DealRead, DealStage, DealStatus
from app.repositories.deal_repo import DealRepository, DealAccessError, DealQueryParams
from app.services.deal_service import (
DealService,
DealStageTransitionError,
DealStatusValidationError,
DealUpdateData,
)
from app.services.organization_service import OrganizationContext
from .models import DealCreatePayload, DealUpdatePayload from .models import DealCreatePayload, DealUpdatePayload
router = APIRouter(prefix="/deals", tags=["deals"]) router = APIRouter(prefix="/deals", tags=["deals"])
def _stub(endpoint: str) -> dict[str, str]: @router.get("/", response_model=list[DealRead])
return {"detail": f"{endpoint} is not implemented yet"}
@router.get("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
async def list_deals( async def list_deals(
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100), page_size: int = Query(20, ge=1, le=100),
@ -24,22 +31,90 @@ async def list_deals(
stage: str | None = None, stage: str | None = None,
owner_id: int | None = None, owner_id: int | None = None,
order_by: str | None = None, order_by: str | None = None,
order: str | None = Query(default=None, pattern="^(asc|desc)$"), order: str | None = Query(default="desc", pattern="^(asc|desc)$"),
) -> dict[str, str]: context: OrganizationContext = Depends(get_organization_context),
"""Placeholder for deal filtering endpoint.""" repo: DealRepository = Depends(get_deal_repository),
_ = (status_filter,) ) -> list[DealRead]:
return _stub("GET /deals") """List deals for the current organization with optional filters."""
try:
statuses_value = [DealStatus(value) for value in status_filter] if status_filter else None
stage_value = DealStage(stage) if stage else None
except ValueError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid deal filter") from exc
params = DealQueryParams(
organization_id=context.organization_id,
page=page,
page_size=page_size,
statuses=statuses_value,
stage=stage_value,
owner_id=owner_id,
min_amount=min_amount,
max_amount=max_amount,
order_by=order_by,
order_desc=(order != "asc"),
)
try:
deals = await repo.list(params=params, role=context.role, user_id=context.user_id)
except DealAccessError as exc:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(exc)) from exc
return [DealRead.model_validate(deal) for deal in deals]
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED) @router.post("/", response_model=DealRead, status_code=status.HTTP_201_CREATED)
async def create_deal(payload: DealCreatePayload) -> dict[str, str]: async def create_deal(
"""Placeholder for creating a new deal.""" payload: DealCreatePayload,
_ = payload context: OrganizationContext = Depends(get_organization_context),
return _stub("POST /deals") service: DealService = Depends(get_deal_service),
) -> DealRead:
"""Create a new deal within the current organization."""
data = payload.to_domain(organization_id=context.organization_id, fallback_owner=context.user_id)
try:
deal = await service.create_deal(data, context=context)
except DealAccessError as exc:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(exc)) from exc
except DealStatusValidationError as exc: # pragma: no cover - creation shouldn't trigger
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
return DealRead.model_validate(deal)
@router.patch("/{deal_id}", status_code=status.HTTP_501_NOT_IMPLEMENTED) @router.patch("/{deal_id}", response_model=DealRead)
async def update_deal(deal_id: int, payload: DealUpdatePayload) -> dict[str, str]: async def update_deal(
"""Placeholder for modifying deal status or stage.""" deal_id: int,
_ = (deal_id, payload) payload: DealUpdatePayload,
return _stub("PATCH /deals/{deal_id}") context: OrganizationContext = Depends(get_organization_context),
repo: DealRepository = Depends(get_deal_repository),
service: DealService = Depends(get_deal_service),
) -> DealRead:
"""Update deal status, stage, or financial data."""
existing = await repo.get(
deal_id,
organization_id=context.organization_id,
role=context.role,
user_id=context.user_id,
)
if existing is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Deal not found")
updates = DealUpdateData(
status=payload.status,
stage=payload.stage,
amount=payload.amount,
currency=payload.currency,
)
try:
deal = await service.update_deal(existing, updates, context=context)
except DealAccessError as exc:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(exc)) from exc
except DealStageTransitionError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
except DealStatusValidationError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
return DealRead.model_validate(deal)

View File

@ -3,7 +3,10 @@ from __future__ import annotations
from datetime import date from datetime import date
from fastapi import APIRouter, Query, status from fastapi import APIRouter, Depends, Query, status
from app.api.deps import get_organization_context
from app.services.organization_service import OrganizationContext
from .models import TaskCreatePayload from .models import TaskCreatePayload
@ -20,13 +23,18 @@ async def list_tasks(
only_open: bool = False, only_open: bool = False,
due_before: date | None = Query(default=None), due_before: date | None = Query(default=None),
due_after: date | None = Query(default=None), due_after: date | None = Query(default=None),
context: OrganizationContext = Depends(get_organization_context),
) -> dict[str, str]: ) -> dict[str, str]:
"""Placeholder for task filtering endpoint.""" """Placeholder for task filtering endpoint."""
_ = context
return _stub("GET /tasks") return _stub("GET /tasks")
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED) @router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
async def create_task(payload: TaskCreatePayload) -> dict[str, str]: async def create_task(
payload: TaskCreatePayload,
context: OrganizationContext = Depends(get_organization_context),
) -> dict[str, str]:
"""Placeholder for creating a task linked to a deal.""" """Placeholder for creating a task linked to a deal."""
_ = payload _ = (payload, context)
return _stub("POST /tasks") return _stub("POST /tasks")

View File

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import DateTime, Enum as SqlEnum, ForeignKey, Integer, func, text from sqlalchemy import DateTime, Enum as SqlEnum, ForeignKey, Integer, func, text
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import JSON as GenericJSON, TypeDecorator
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base from app.models.base import Base
@ -16,10 +17,25 @@ from app.models.base import Base
class ActivityType(StrEnum): class ActivityType(StrEnum):
COMMENT = "comment" COMMENT = "comment"
STATUS_CHANGED = "status_changed" STATUS_CHANGED = "status_changed"
STAGE_CHANGED = "stage_changed"
TASK_CREATED = "task_created" TASK_CREATED = "task_created"
SYSTEM = "system" SYSTEM = "system"
class JSONBCompat(TypeDecorator):
"""Uses JSONB on Postgres and plain JSON elsewhere for testability."""
impl = JSONB
cache_ok = True
def load_dialect_impl(self, dialect): # type: ignore[override]
if dialect.name == "sqlite":
from sqlalchemy.dialects.sqlite import JSON as SQLiteJSON # local import
return dialect.type_descriptor(SQLiteJSON())
return dialect.type_descriptor(JSONB())
class Activity(Base): class Activity(Base):
"""Represents a timeline event for a deal.""" """Represents a timeline event for a deal."""
@ -32,9 +48,9 @@ class Activity(Base):
) )
type: Mapped[ActivityType] = mapped_column(SqlEnum(ActivityType, name="activity_type"), nullable=False) type: Mapped[ActivityType] = mapped_column(SqlEnum(ActivityType, name="activity_type"), nullable=False)
payload: Mapped[dict[str, Any]] = mapped_column( payload: Mapped[dict[str, Any]] = mapped_column(
JSONB, JSONBCompat().with_variant(GenericJSON(), "sqlite"),
nullable=False, nullable=False,
server_default=text("'{}'::jsonb"), server_default=text("'{}'"),
) )
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False DateTime(timezone=True), server_default=func.now(), nullable=False

View File

@ -0,0 +1,152 @@
"""Deal repository with access-aware CRUD helpers."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from decimal import Decimal
from typing import Any
from sqlalchemy import Select, asc, desc, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.deal import Deal, DealCreate, DealStage, DealStatus
from app.models.organization_member import OrganizationRole
ORDERABLE_COLUMNS: dict[str, Any] = {
"created_at": Deal.created_at,
"amount": Deal.amount,
"title": Deal.title,
}
class DealAccessError(Exception):
"""Raised when a user attempts an operation without sufficient permissions."""
@dataclass(slots=True)
class DealQueryParams:
"""Filters supported by list queries."""
organization_id: int
page: int = 1
page_size: int = 20
statuses: Sequence[DealStatus] | None = None
stage: DealStage | None = None
owner_id: int | None = None
min_amount: Decimal | None = None
max_amount: Decimal | None = None
order_by: str | None = None
order_desc: bool = True
class DealRepository:
"""Provides CRUD helpers for deals with role-aware filtering."""
def __init__(self, session: AsyncSession) -> None:
self._session = session
@property
def session(self) -> AsyncSession:
return self._session
async def list(
self,
*,
params: DealQueryParams,
role: OrganizationRole,
user_id: int,
) -> Sequence[Deal]:
stmt = select(Deal).where(Deal.organization_id == params.organization_id)
stmt = self._apply_filters(stmt, params, role, user_id)
stmt = self._apply_ordering(stmt, params)
offset = (max(params.page, 1) - 1) * params.page_size
stmt = stmt.offset(offset).limit(params.page_size)
result = await self._session.scalars(stmt)
return result.all()
async def get(
self,
deal_id: int,
*,
organization_id: int,
role: OrganizationRole,
user_id: int,
require_owner: bool = False,
) -> Deal | None:
stmt = select(Deal).where(Deal.id == deal_id, Deal.organization_id == organization_id)
stmt = self._apply_role_clause(stmt, role, user_id, require_owner=require_owner)
result = await self._session.scalars(stmt)
return result.first()
async def create(
self,
data: DealCreate,
*,
role: OrganizationRole,
user_id: int,
) -> Deal:
if role == OrganizationRole.MEMBER and data.owner_id != user_id:
raise DealAccessError("Members can only create deals they own")
deal = Deal(**data.model_dump())
self._session.add(deal)
await self._session.flush()
return deal
async def update(
self,
deal: Deal,
updates: Mapping[str, Any],
*,
role: OrganizationRole,
user_id: int,
) -> Deal:
if role == OrganizationRole.MEMBER and deal.owner_id != user_id:
raise DealAccessError("Members can only modify their own deals")
for field, value in updates.items():
if hasattr(deal, field):
setattr(deal, field, value)
await self._session.flush()
return deal
def _apply_filters(
self,
stmt: Select[tuple[Deal]],
params: DealQueryParams,
role: OrganizationRole,
user_id: int,
) -> Select[tuple[Deal]]:
if params.statuses:
stmt = stmt.where(Deal.status.in_(params.statuses))
if params.stage:
stmt = stmt.where(Deal.stage == params.stage)
if params.owner_id is not None:
if role == OrganizationRole.MEMBER and params.owner_id != user_id:
raise DealAccessError("Members cannot filter by other owners")
stmt = stmt.where(Deal.owner_id == params.owner_id)
if params.min_amount is not None:
stmt = stmt.where(Deal.amount >= params.min_amount)
if params.max_amount is not None:
stmt = stmt.where(Deal.amount <= params.max_amount)
return self._apply_role_clause(stmt, role, user_id)
def _apply_role_clause(
self,
stmt: Select[tuple[Deal]],
role: OrganizationRole,
user_id: int,
*,
require_owner: bool = False,
) -> Select[tuple[Deal]]:
if role in {OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MANAGER}:
return stmt
if require_owner:
return stmt.where(Deal.owner_id == user_id)
return stmt
def _apply_ordering(self, stmt: Select[tuple[Deal]], params: DealQueryParams) -> Select[tuple[Deal]]:
column = ORDERABLE_COLUMNS.get(params.order_by or "created_at", Deal.created_at)
order_func = desc if params.order_desc else asc
return stmt.order_by(order_func(column))

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.models.organization import Organization, OrganizationCreate from app.models.organization import Organization, OrganizationCreate
@ -42,6 +43,18 @@ class OrganizationRepository:
result = await self._session.scalars(stmt) result = await self._session.scalars(stmt)
return result.unique().all() return result.unique().all()
async def get_membership(self, organization_id: int, user_id: int) -> OrganizationMember | None:
stmt = (
select(OrganizationMember)
.where(
OrganizationMember.organization_id == organization_id,
OrganizationMember.user_id == user_id,
)
.options(selectinload(OrganizationMember.organization))
)
result = await self._session.scalars(stmt)
return result.first()
async def create(self, data: OrganizationCreate) -> Organization: async def create(self, data: OrganizationCreate) -> Organization:
organization = Organization(name=data.name) organization = Organization(name=data.name)
self._session.add(organization) self._session.add(organization)

View File

@ -1 +1,9 @@
"""Business logic services.""" """Business logic services."""
from .organization_service import ( # noqa: F401
OrganizationAccessDeniedError,
OrganizationContext,
OrganizationContextMissingError,
OrganizationService,
)
from .user_service import UserService # noqa: F401
from .auth_service import AuthService # noqa: F401

View File

@ -0,0 +1,164 @@
"""Business logic for deals."""
from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
from decimal import Decimal
from sqlalchemy import func, select
from app.models.activity import Activity, ActivityType
from app.models.contact import Contact
from app.models.deal import Deal, DealCreate, DealStage, DealStatus
from app.models.organization_member import OrganizationRole
from app.repositories.deal_repo import DealRepository
from app.services.organization_service import OrganizationContext
STAGE_ORDER = {
stage: index
for index, stage in enumerate(
[
DealStage.QUALIFICATION,
DealStage.PROPOSAL,
DealStage.NEGOTIATION,
DealStage.CLOSED,
]
)
}
class DealServiceError(Exception):
"""Base class for deal service errors."""
class DealOrganizationMismatchError(DealServiceError):
"""Raised when attempting to use resources from another organization."""
class DealStageTransitionError(DealServiceError):
"""Raised when stage transition violates business rules."""
class DealStatusValidationError(DealServiceError):
"""Raised when invalid status transitions are requested."""
class ContactHasDealsError(DealServiceError):
"""Raised when attempting to delete a contact with active deals."""
@dataclass(slots=True)
class DealUpdateData:
"""Structured container for deal update operations."""
status: DealStatus | None = None
stage: DealStage | None = None
amount: Decimal | None = None
currency: str | None = None
class DealService:
"""Encapsulates deal workflows and validations."""
def __init__(self, repository: DealRepository) -> None:
self._repository = repository
async def create_deal(self, data: DealCreate, *, context: OrganizationContext) -> Deal:
self._ensure_same_organization(data.organization_id, context)
await self._ensure_contact_in_organization(data.contact_id, context.organization_id)
return await self._repository.create(data=data, role=context.role, user_id=context.user_id)
async def update_deal(
self,
deal: Deal,
updates: DealUpdateData,
*,
context: OrganizationContext,
) -> Deal:
self._ensure_same_organization(deal.organization_id, context)
changes: dict[str, object] = {}
stage_activity: tuple[ActivityType, dict[str, str]] | None = None
status_activity: tuple[ActivityType, dict[str, str]] | None = None
if updates.amount is not None:
changes["amount"] = updates.amount
if updates.currency is not None:
changes["currency"] = updates.currency
if updates.stage is not None and updates.stage != deal.stage:
self._validate_stage_transition(deal.stage, updates.stage, context.role)
changes["stage"] = updates.stage
stage_activity = (
ActivityType.STAGE_CHANGED,
{"old_stage": deal.stage, "new_stage": updates.stage},
)
if updates.status is not None and updates.status != deal.status:
self._validate_status_transition(deal, updates)
changes["status"] = updates.status
status_activity = (
ActivityType.STATUS_CHANGED,
{"old_status": deal.status, "new_status": updates.status},
)
if not changes:
return deal
updated = await self._repository.update(deal, changes, role=context.role, user_id=context.user_id)
await self._log_activities(
deal_id=deal.id,
author_id=context.user_id,
activities=[activity for activity in [stage_activity, status_activity] if activity],
)
return updated
async def ensure_contact_can_be_deleted(self, contact_id: int) -> None:
stmt = select(func.count()).select_from(Deal).where(Deal.contact_id == contact_id)
count = await self._repository.session.scalar(stmt)
if count and count > 0:
raise ContactHasDealsError("Contact has related deals and cannot be deleted")
async def _log_activities(
self,
*,
deal_id: int,
author_id: int,
activities: Iterable[tuple[ActivityType, dict[str, str]]],
) -> None:
entries = list(activities)
if not entries:
return
for activity_type, payload in entries:
activity = Activity(deal_id=deal_id, author_id=author_id, type=activity_type, payload=payload)
self._repository.session.add(activity)
await self._repository.session.flush()
def _ensure_same_organization(self, organization_id: int, context: OrganizationContext) -> None:
if organization_id != context.organization_id:
raise DealOrganizationMismatchError("Operation targets a different organization")
async def _ensure_contact_in_organization(self, contact_id: int, organization_id: int) -> Contact:
contact = await self._repository.session.get(Contact, contact_id)
if contact is None or contact.organization_id != organization_id:
raise DealOrganizationMismatchError("Contact belongs to another organization")
return contact
def _validate_stage_transition(
self,
current_stage: DealStage,
new_stage: DealStage,
role: OrganizationRole,
) -> None:
if STAGE_ORDER[new_stage] < STAGE_ORDER[current_stage] and role not in {
OrganizationRole.OWNER,
OrganizationRole.ADMIN,
}:
raise DealStageTransitionError("Stage rollback requires owner or admin role")
def _validate_status_transition(self, deal: Deal, updates: DealUpdateData) -> None:
if updates.status != DealStatus.WON:
return
effective_amount = updates.amount if updates.amount is not None else deal.amount
if effective_amount is None or Decimal(effective_amount) <= Decimal("0"):
raise DealStatusValidationError("Amount must be greater than zero to mark a deal as won")

View File

@ -0,0 +1,87 @@
"""Organization-related business rules."""
from __future__ import annotations
from dataclasses import dataclass
from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole
from app.repositories.org_repo import OrganizationRepository
class OrganizationServiceError(Exception):
"""Base class for organization service errors."""
class OrganizationContextMissingError(OrganizationServiceError):
"""Raised when the request lacks organization context."""
class OrganizationAccessDeniedError(OrganizationServiceError):
"""Raised when a user tries to work with a foreign organization."""
class OrganizationForbiddenError(OrganizationServiceError):
"""Raised when a user does not have enough privileges."""
@dataclass(slots=True, frozen=True)
class OrganizationContext:
"""Resolved organization and membership information for a request."""
organization: Organization
membership: OrganizationMember
@property
def organization_id(self) -> int:
return self.organization.id
@property
def role(self) -> OrganizationRole:
return self.membership.role
@property
def user_id(self) -> int:
return self.membership.user_id
class OrganizationService:
"""Encapsulates organization-specific policies."""
def __init__(self, repository: OrganizationRepository) -> None:
self._repository = repository
async def get_context(self, *, user_id: int, organization_id: int | None) -> OrganizationContext:
"""Resolve request context ensuring the user belongs to the given organization."""
if organization_id is None:
raise OrganizationContextMissingError("X-Organization-Id header is required")
membership = await self._repository.get_membership(organization_id, user_id)
if membership is None or membership.organization is None:
raise OrganizationAccessDeniedError("Organization not found")
return OrganizationContext(organization=membership.organization, membership=membership)
def ensure_entity_in_context(self, *, entity_organization_id: int, context: OrganizationContext) -> None:
"""Make sure a resource belongs to the current organization."""
if entity_organization_id != context.organization_id:
raise OrganizationAccessDeniedError("Resource belongs to another organization")
def ensure_can_manage_settings(self, context: OrganizationContext) -> None:
"""Allow only owner/admin to change organization-level settings."""
if context.role not in {OrganizationRole.OWNER, OrganizationRole.ADMIN}:
raise OrganizationForbiddenError("Only owner/admin can modify organization settings")
def ensure_can_manage_entity(self, context: OrganizationContext) -> None:
"""Managers/admins/owners may manage entities; members are restricted."""
if context.role == OrganizationRole.MEMBER:
raise OrganizationForbiddenError("Members cannot manage shared entities")
def ensure_member_owns_entity(self, *, context: OrganizationContext, owner_id: int) -> None:
"""Members can only mutate entities they own (contacts/deals/tasks)."""
if context.role == OrganizationRole.MEMBER and owner_id != context.user_id:
raise OrganizationForbiddenError("Members can only modify their own records")

View File

@ -0,0 +1,26 @@
"""Add stage_changed activity type."""
from __future__ import annotations
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "20251127_0002_stage_changed"
down_revision: str | None = "20251122_0001"
branch_labels: tuple[str, ...] | None = None
depends_on: tuple[str, ...] | None = None
def upgrade() -> None:
op.execute("ALTER TYPE activity_type ADD VALUE IF NOT EXISTS 'stage_changed';")
def downgrade() -> None:
op.execute("UPDATE activities SET type = 'status_changed' WHERE type = 'stage_changed';")
op.execute("ALTER TYPE activity_type RENAME TO activity_type_old;")
op.execute(
"CREATE TYPE activity_type AS ENUM ('comment','status_changed','task_created','system');"
)
op.execute(
"ALTER TABLE activities ALTER COLUMN type TYPE activity_type USING type::text::activity_type;"
)
op.execute("DROP TYPE activity_type_old;")

View File

@ -0,0 +1,244 @@
"""Unit tests for DealService."""
from __future__ import annotations
from collections.abc import AsyncGenerator
from decimal import Decimal
import uuid
import pytest # type: ignore[import-not-found]
import pytest_asyncio # type: ignore[import-not-found]
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.models.activity import Activity, ActivityType
from app.models.base import Base
from app.models.contact import Contact
from app.models.deal import DealCreate, DealStage, DealStatus
from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole
from app.models.user import User
from app.repositories.deal_repo import DealRepository
from app.services.deal_service import (
ContactHasDealsError,
DealOrganizationMismatchError,
DealService,
DealStageTransitionError,
DealStatusValidationError,
DealUpdateData,
)
from app.services.organization_service import OrganizationContext
@pytest_asyncio.fixture()
async def session() -> AsyncGenerator[AsyncSession, None]:
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
future=True,
poolclass=StaticPool,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async_session = async_sessionmaker(engine, expire_on_commit=False)
async with async_session() as session:
yield session
await engine.dispose()
def _make_organization(name: str) -> Organization:
org = Organization(name=name)
return org
def _make_user(email_suffix: str) -> User:
return User(
email=f"user-{email_suffix}@example.com",
hashed_password="hashed",
name="Test User",
is_active=True,
)
def _make_context(org: Organization, user: User, role: OrganizationRole) -> OrganizationContext:
membership = OrganizationMember(organization_id=org.id, user_id=user.id, role=role)
return OrganizationContext(organization=org, membership=membership)
async def _persist_base(session: AsyncSession, *, role: OrganizationRole = OrganizationRole.MANAGER) -> tuple[
OrganizationContext,
Contact,
DealRepository,
]:
org = _make_organization(name=f"Org-{uuid.uuid4()}"[:8])
user = _make_user(email_suffix=str(uuid.uuid4())[:8])
session.add_all([org, user])
await session.flush()
contact = Contact(
organization_id=org.id,
owner_id=user.id,
name="John Doe",
email="john@example.com",
)
session.add(contact)
await session.flush()
context = _make_context(org, user, role)
repo = DealRepository(session=session)
return context, contact, repo
@pytest.mark.asyncio
async def test_create_deal_rejects_foreign_contact(session: AsyncSession) -> None:
context, contact, repo = await _persist_base(session)
other_org = _make_organization(name="Other")
other_user = _make_user(email_suffix="other")
session.add_all([other_org, other_user])
await session.flush()
service = DealService(repository=repo)
payload = DealCreate(
organization_id=other_org.id,
contact_id=contact.id,
owner_id=context.user_id,
title="Website Redesign",
amount=None,
)
other_context = _make_context(other_org, other_user, OrganizationRole.MANAGER)
with pytest.raises(DealOrganizationMismatchError):
await service.create_deal(payload, context=other_context)
@pytest.mark.asyncio
async def test_stage_rollback_requires_admin(session: AsyncSession) -> None:
context, contact, repo = await _persist_base(session, role=OrganizationRole.MANAGER)
service = DealService(repository=repo)
deal = await service.create_deal(
DealCreate(
organization_id=context.organization_id,
contact_id=contact.id,
owner_id=context.user_id,
title="Migration",
amount=Decimal("5000"),
),
context=context,
)
deal.stage = DealStage.PROPOSAL
with pytest.raises(DealStageTransitionError):
await service.update_deal(
deal,
DealUpdateData(stage=DealStage.QUALIFICATION),
context=context,
)
@pytest.mark.asyncio
async def test_stage_rollback_allowed_for_admin(session: AsyncSession) -> None:
context, contact, repo = await _persist_base(session, role=OrganizationRole.ADMIN)
service = DealService(repository=repo)
deal = await service.create_deal(
DealCreate(
organization_id=context.organization_id,
contact_id=contact.id,
owner_id=context.user_id,
title="Rollout",
amount=Decimal("1000"),
),
context=context,
)
deal.stage = DealStage.NEGOTIATION
updated = await service.update_deal(
deal,
DealUpdateData(stage=DealStage.PROPOSAL),
context=context,
)
assert updated.stage == DealStage.PROPOSAL
@pytest.mark.asyncio
async def test_status_won_requires_positive_amount(session: AsyncSession) -> None:
context, contact, repo = await _persist_base(session)
service = DealService(repository=repo)
deal = await service.create_deal(
DealCreate(
organization_id=context.organization_id,
contact_id=contact.id,
owner_id=context.user_id,
title="Zero",
amount=None,
),
context=context,
)
with pytest.raises(DealStatusValidationError):
await service.update_deal(
deal,
DealUpdateData(status=DealStatus.WON),
context=context,
)
@pytest.mark.asyncio
async def test_updates_create_activity_records(session: AsyncSession) -> None:
context, contact, repo = await _persist_base(session)
service = DealService(repository=repo)
deal = await service.create_deal(
DealCreate(
organization_id=context.organization_id,
contact_id=contact.id,
owner_id=context.user_id,
title="Activity",
amount=Decimal("100"),
),
context=context,
)
await service.update_deal(
deal,
DealUpdateData(
stage=DealStage.PROPOSAL,
status=DealStatus.WON,
amount=Decimal("5000"),
),
context=context,
)
result = await session.scalars(select(Activity).where(Activity.deal_id == deal.id))
activity_types = {activity.type for activity in result.all()}
assert ActivityType.STAGE_CHANGED in activity_types
assert ActivityType.STATUS_CHANGED in activity_types
@pytest.mark.asyncio
async def test_contact_delete_guard(session: AsyncSession) -> None:
context, contact, repo = await _persist_base(session)
service = DealService(repository=repo)
deal = await service.create_deal(
DealCreate(
organization_id=context.organization_id,
contact_id=contact.id,
owner_id=context.user_id,
title="To Delete",
amount=Decimal("100"),
),
context=context,
)
with pytest.raises(ContactHasDealsError):
await service.ensure_contact_can_be_deleted(contact.id)
await session.delete(deal)
await session.flush()
await service.ensure_contact_can_be_deleted(contact.id)

View File

@ -0,0 +1,99 @@
"""Unit tests for OrganizationService."""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest # type: ignore[import-not-found]
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole
from app.repositories.org_repo import OrganizationRepository
from app.services.organization_service import (
OrganizationAccessDeniedError,
OrganizationContext,
OrganizationContextMissingError,
OrganizationForbiddenError,
OrganizationService,
)
class StubOrganizationRepository(OrganizationRepository):
"""Simple in-memory stand-in for OrganizationRepository."""
def __init__(self, membership: OrganizationMember | None) -> None:
super().__init__(session=MagicMock(spec=AsyncSession))
self._membership = membership
async def get_membership(self, organization_id: int, user_id: int) -> OrganizationMember | None: # pragma: no cover - helper
if (
self._membership
and self._membership.organization_id == organization_id
and self._membership.user_id == user_id
):
return self._membership
return None
def make_membership(role: OrganizationRole, *, organization_id: int = 1, user_id: int = 10) -> OrganizationMember:
organization = Organization(name="Acme Inc")
organization.id = organization_id
membership = OrganizationMember(
organization_id=organization_id,
user_id=user_id,
role=role,
)
membership.organization = organization
return membership
@pytest.mark.asyncio
async def test_get_context_success() -> None:
membership = make_membership(OrganizationRole.MANAGER)
service = OrganizationService(StubOrganizationRepository(membership))
context = await service.get_context(user_id=membership.user_id, organization_id=membership.organization_id)
assert context.organization_id == membership.organization_id
assert context.role == OrganizationRole.MANAGER
@pytest.mark.asyncio
async def test_get_context_missing_header() -> None:
service = OrganizationService(StubOrganizationRepository(None))
with pytest.raises(OrganizationContextMissingError):
await service.get_context(user_id=1, organization_id=None)
@pytest.mark.asyncio
async def test_get_context_access_denied() -> None:
service = OrganizationService(StubOrganizationRepository(None))
with pytest.raises(OrganizationAccessDeniedError):
await service.get_context(user_id=1, organization_id=99)
def test_ensure_can_manage_settings_blocks_manager() -> None:
membership = make_membership(OrganizationRole.MANAGER)
organization = membership.organization
assert organization is not None
context = OrganizationContext(organization=organization, membership=membership)
service = OrganizationService(StubOrganizationRepository(membership))
with pytest.raises(OrganizationForbiddenError):
service.ensure_can_manage_settings(context)
def test_member_must_own_entity() -> None:
membership = make_membership(OrganizationRole.MEMBER)
organization = membership.organization
assert organization is not None
context = OrganizationContext(organization=organization, membership=membership)
service = OrganizationService(StubOrganizationRepository(membership))
with pytest.raises(OrganizationForbiddenError):
service.ensure_member_owns_entity(context=context, owner_id=999)
# Same owner should pass silently.
service.ensure_member_owns_entity(context=context, owner_id=membership.user_id)