Compare commits
5 Commits
c972d79ba9
...
a4c3864ef6
| Author | SHA1 | Date |
|---|---|---|
|
|
a4c3864ef6 | |
|
|
8492a0aed1 | |
|
|
969a1b5905 | |
|
|
8c326501bf | |
|
|
4b45073bd3 |
|
|
@ -2,7 +2,7 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
|
@ -10,9 +10,17 @@ from app.core.config import settings
|
|||
from app.core.database import get_session
|
||||
from app.core.security import jwt_service, password_hasher
|
||||
from app.models.user import User
|
||||
from app.repositories.deal_repo import DealRepository
|
||||
from app.repositories.org_repo import OrganizationRepository
|
||||
from app.repositories.user_repo import UserRepository
|
||||
from app.services.auth_service import AuthService
|
||||
from app.services.deal_service import DealService
|
||||
from app.services.organization_service import (
|
||||
OrganizationAccessDeniedError,
|
||||
OrganizationContext,
|
||||
OrganizationContextMissingError,
|
||||
OrganizationService,
|
||||
)
|
||||
from app.services.user_service import UserService
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/token")
|
||||
|
|
@ -32,6 +40,14 @@ def get_organization_repository(session: AsyncSession = Depends(get_db_session))
|
|||
return OrganizationRepository(session=session)
|
||||
|
||||
|
||||
def get_deal_repository(session: AsyncSession = Depends(get_db_session)) -> DealRepository:
|
||||
return DealRepository(session=session)
|
||||
|
||||
|
||||
def get_deal_service(repo: DealRepository = Depends(get_deal_repository)) -> DealService:
|
||||
return DealService(repository=repo)
|
||||
|
||||
|
||||
def get_user_service(repo: UserRepository = Depends(get_user_repository)) -> UserService:
|
||||
return UserService(user_repository=repo, password_hasher=password_hasher)
|
||||
|
||||
|
|
@ -46,6 +62,12 @@ def get_auth_service(
|
|||
)
|
||||
|
||||
|
||||
def get_organization_service(
|
||||
repo: OrganizationRepository = Depends(get_organization_repository),
|
||||
) -> OrganizationService:
|
||||
return OrganizationService(repository=repo)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
repo: UserRepository = Depends(get_user_repository),
|
||||
|
|
@ -67,3 +89,16 @@ async def get_current_user(
|
|||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
async def get_organization_context(
|
||||
x_organization_id: int | None = Header(default=None, alias="X-Organization-Id"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: OrganizationService = Depends(get_organization_service),
|
||||
) -> OrganizationContext:
|
||||
try:
|
||||
return await service.get_context(user_id=current_user.id, organization_id=x_organization_id)
|
||||
except OrganizationContextMissingError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
except OrganizationAccessDeniedError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
"""Activity timeline API stubs."""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, status
|
||||
from fastapi import APIRouter, Depends, status
|
||||
|
||||
from app.api.deps import get_organization_context
|
||||
from app.services.organization_service import OrganizationContext
|
||||
|
||||
from .models import ActivityCommentPayload
|
||||
|
||||
|
|
@ -13,14 +16,21 @@ def _stub(endpoint: str) -> dict[str, str]:
|
|||
|
||||
|
||||
@router.get("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||
async def list_activities(deal_id: int) -> dict[str, str]:
|
||||
async def list_activities(
|
||||
deal_id: int,
|
||||
context: OrganizationContext = Depends(get_organization_context),
|
||||
) -> dict[str, str]:
|
||||
"""Placeholder for listing deal activities."""
|
||||
_ = deal_id
|
||||
_ = (deal_id, context)
|
||||
return _stub("GET /deals/{deal_id}/activities")
|
||||
|
||||
|
||||
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||
async def create_activity_comment(deal_id: int, payload: ActivityCommentPayload) -> dict[str, str]:
|
||||
async def create_activity_comment(
|
||||
deal_id: int,
|
||||
payload: ActivityCommentPayload,
|
||||
context: OrganizationContext = Depends(get_organization_context),
|
||||
) -> dict[str, str]:
|
||||
"""Placeholder for adding a comment activity to a deal."""
|
||||
_ = (deal_id, payload)
|
||||
_ = (deal_id, payload, context)
|
||||
return _stub("POST /deals/{deal_id}/activities")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
"""Analytics API stubs (deal summary and funnel)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Query, status
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
|
||||
from app.api.deps import get_organization_context
|
||||
from app.services.organization_service import OrganizationContext
|
||||
|
||||
router = APIRouter(prefix="/analytics", tags=["analytics"])
|
||||
|
||||
|
|
@ -11,13 +14,19 @@ def _stub(endpoint: str) -> dict[str, str]:
|
|||
|
||||
|
||||
@router.get("/deals/summary", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||
async def deals_summary(days: int = Query(30, ge=1, le=180)) -> dict[str, str]:
|
||||
async def deals_summary(
|
||||
days: int = Query(30, ge=1, le=180),
|
||||
context: OrganizationContext = Depends(get_organization_context),
|
||||
) -> dict[str, str]:
|
||||
"""Placeholder for aggregated deal statistics."""
|
||||
_ = days
|
||||
_ = (days, context)
|
||||
return _stub("GET /analytics/deals/summary")
|
||||
|
||||
|
||||
@router.get("/deals/funnel", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||
async def deals_funnel() -> dict[str, str]:
|
||||
async def deals_funnel(
|
||||
context: OrganizationContext = Depends(get_organization_context),
|
||||
) -> dict[str, str]:
|
||||
"""Placeholder for funnel analytics."""
|
||||
_ = context
|
||||
return _stub("GET /analytics/deals/funnel")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
"""Contact API stubs required by the spec."""
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Query, status
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
|
||||
from app.api.deps import get_organization_context
|
||||
from app.services.organization_service import OrganizationContext
|
||||
|
||||
from .models import ContactCreatePayload
|
||||
|
||||
|
|
@ -18,13 +21,18 @@ async def list_contacts(
|
|||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: str | None = None,
|
||||
owner_id: int | None = None,
|
||||
context: OrganizationContext = Depends(get_organization_context),
|
||||
) -> dict[str, str]:
|
||||
"""Placeholder list endpoint supporting the required filters."""
|
||||
_ = context
|
||||
return _stub("GET /contacts")
|
||||
|
||||
|
||||
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||
async def create_contact(payload: ContactCreatePayload) -> dict[str, str]:
|
||||
async def create_contact(
|
||||
payload: ContactCreatePayload,
|
||||
context: OrganizationContext = Depends(get_organization_context),
|
||||
) -> dict[str, str]:
|
||||
"""Placeholder for creating a contact within the current organization."""
|
||||
_ = payload
|
||||
_ = (payload, context)
|
||||
return _stub("POST /contacts")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,10 @@ from __future__ import annotations
|
|||
|
||||
from decimal import Decimal
|
||||
|
||||
from fastapi import APIRouter, Query, status
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
|
||||
from app.api.deps import get_organization_context
|
||||
from app.services.organization_service import OrganizationContext
|
||||
|
||||
from .models import DealCreatePayload, DealUpdatePayload
|
||||
|
||||
|
|
@ -25,21 +28,29 @@ async def list_deals(
|
|||
owner_id: int | None = None,
|
||||
order_by: str | None = None,
|
||||
order: str | None = Query(default=None, pattern="^(asc|desc)$"),
|
||||
context: OrganizationContext = Depends(get_organization_context),
|
||||
) -> dict[str, str]:
|
||||
"""Placeholder for deal filtering endpoint."""
|
||||
_ = (status_filter,)
|
||||
_ = (status_filter, context)
|
||||
return _stub("GET /deals")
|
||||
|
||||
|
||||
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||
async def create_deal(payload: DealCreatePayload) -> dict[str, str]:
|
||||
async def create_deal(
|
||||
payload: DealCreatePayload,
|
||||
context: OrganizationContext = Depends(get_organization_context),
|
||||
) -> dict[str, str]:
|
||||
"""Placeholder for creating a new deal."""
|
||||
_ = payload
|
||||
_ = (payload, context)
|
||||
return _stub("POST /deals")
|
||||
|
||||
|
||||
@router.patch("/{deal_id}", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||
async def update_deal(deal_id: int, payload: DealUpdatePayload) -> dict[str, str]:
|
||||
async def update_deal(
|
||||
deal_id: int,
|
||||
payload: DealUpdatePayload,
|
||||
context: OrganizationContext = Depends(get_organization_context),
|
||||
) -> dict[str, str]:
|
||||
"""Placeholder for modifying deal status or stage."""
|
||||
_ = (deal_id, payload)
|
||||
_ = (deal_id, payload, context)
|
||||
return _stub("PATCH /deals/{deal_id}")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,10 @@ from __future__ import annotations
|
|||
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Query, status
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
|
||||
from app.api.deps import get_organization_context
|
||||
from app.services.organization_service import OrganizationContext
|
||||
|
||||
from .models import TaskCreatePayload
|
||||
|
||||
|
|
@ -20,13 +23,18 @@ async def list_tasks(
|
|||
only_open: bool = False,
|
||||
due_before: date | None = Query(default=None),
|
||||
due_after: date | None = Query(default=None),
|
||||
context: OrganizationContext = Depends(get_organization_context),
|
||||
) -> dict[str, str]:
|
||||
"""Placeholder for task filtering endpoint."""
|
||||
_ = context
|
||||
return _stub("GET /tasks")
|
||||
|
||||
|
||||
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||
async def create_task(payload: TaskCreatePayload) -> dict[str, str]:
|
||||
async def create_task(
|
||||
payload: TaskCreatePayload,
|
||||
context: OrganizationContext = Depends(get_organization_context),
|
||||
) -> dict[str, str]:
|
||||
"""Placeholder for creating a task linked to a deal."""
|
||||
_ = payload
|
||||
_ = (payload, context)
|
||||
return _stub("POST /tasks")
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import DateTime, Enum as SqlEnum, ForeignKey, Integer, func, text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.types import JSON as GenericJSON, TypeDecorator
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import Base
|
||||
|
|
@ -16,10 +17,25 @@ from app.models.base import Base
|
|||
class ActivityType(StrEnum):
|
||||
COMMENT = "comment"
|
||||
STATUS_CHANGED = "status_changed"
|
||||
STAGE_CHANGED = "stage_changed"
|
||||
TASK_CREATED = "task_created"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class JSONBCompat(TypeDecorator):
|
||||
"""Uses JSONB on Postgres and plain JSON elsewhere for testability."""
|
||||
|
||||
impl = JSONB
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect): # type: ignore[override]
|
||||
if dialect.name == "sqlite":
|
||||
from sqlalchemy.dialects.sqlite import JSON as SQLiteJSON # local import
|
||||
|
||||
return dialect.type_descriptor(SQLiteJSON())
|
||||
return dialect.type_descriptor(JSONB())
|
||||
|
||||
|
||||
class Activity(Base):
|
||||
"""Represents a timeline event for a deal."""
|
||||
|
||||
|
|
@ -32,9 +48,9 @@ class Activity(Base):
|
|||
)
|
||||
type: Mapped[ActivityType] = mapped_column(SqlEnum(ActivityType, name="activity_type"), nullable=False)
|
||||
payload: Mapped[dict[str, Any]] = mapped_column(
|
||||
JSONB,
|
||||
JSONBCompat().with_variant(GenericJSON(), "sqlite"),
|
||||
nullable=False,
|
||||
server_default=text("'{}'::jsonb"),
|
||||
server_default=text("'{}'"),
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
|
|
|
|||
|
|
@ -0,0 +1,152 @@
|
|||
"""Deal repository with access-aware CRUD helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Select, asc, desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.deal import Deal, DealCreate, DealStage, DealStatus
|
||||
from app.models.organization_member import OrganizationRole
|
||||
|
||||
|
||||
ORDERABLE_COLUMNS: dict[str, Any] = {
|
||||
"created_at": Deal.created_at,
|
||||
"amount": Deal.amount,
|
||||
"title": Deal.title,
|
||||
}
|
||||
|
||||
|
||||
class DealAccessError(Exception):
|
||||
"""Raised when a user attempts an operation without sufficient permissions."""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DealQueryParams:
|
||||
"""Filters supported by list queries."""
|
||||
|
||||
organization_id: int
|
||||
page: int = 1
|
||||
page_size: int = 20
|
||||
statuses: Sequence[DealStatus] | None = None
|
||||
stage: DealStage | None = None
|
||||
owner_id: int | None = None
|
||||
min_amount: Decimal | None = None
|
||||
max_amount: Decimal | None = None
|
||||
order_by: str | None = None
|
||||
order_desc: bool = True
|
||||
|
||||
|
||||
class DealRepository:
|
||||
"""Provides CRUD helpers for deals with role-aware filtering."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
@property
|
||||
def session(self) -> AsyncSession:
|
||||
return self._session
|
||||
|
||||
async def list(
|
||||
self,
|
||||
*,
|
||||
params: DealQueryParams,
|
||||
role: OrganizationRole,
|
||||
user_id: int,
|
||||
) -> Sequence[Deal]:
|
||||
stmt = select(Deal).where(Deal.organization_id == params.organization_id)
|
||||
stmt = self._apply_filters(stmt, params, role, user_id)
|
||||
stmt = self._apply_ordering(stmt, params)
|
||||
|
||||
offset = (max(params.page, 1) - 1) * params.page_size
|
||||
stmt = stmt.offset(offset).limit(params.page_size)
|
||||
result = await self._session.scalars(stmt)
|
||||
return result.all()
|
||||
|
||||
async def get(
|
||||
self,
|
||||
deal_id: int,
|
||||
*,
|
||||
organization_id: int,
|
||||
role: OrganizationRole,
|
||||
user_id: int,
|
||||
require_owner: bool = False,
|
||||
) -> Deal | None:
|
||||
stmt = select(Deal).where(Deal.id == deal_id, Deal.organization_id == organization_id)
|
||||
stmt = self._apply_role_clause(stmt, role, user_id, require_owner=require_owner)
|
||||
result = await self._session.scalars(stmt)
|
||||
return result.first()
|
||||
|
||||
async def create(
|
||||
self,
|
||||
data: DealCreate,
|
||||
*,
|
||||
role: OrganizationRole,
|
||||
user_id: int,
|
||||
) -> Deal:
|
||||
if role == OrganizationRole.MEMBER and data.owner_id != user_id:
|
||||
raise DealAccessError("Members can only create deals they own")
|
||||
deal = Deal(**data.model_dump())
|
||||
self._session.add(deal)
|
||||
await self._session.flush()
|
||||
return deal
|
||||
|
||||
async def update(
|
||||
self,
|
||||
deal: Deal,
|
||||
updates: Mapping[str, Any],
|
||||
*,
|
||||
role: OrganizationRole,
|
||||
user_id: int,
|
||||
) -> Deal:
|
||||
if role == OrganizationRole.MEMBER and deal.owner_id != user_id:
|
||||
raise DealAccessError("Members can only modify their own deals")
|
||||
for field, value in updates.items():
|
||||
if hasattr(deal, field):
|
||||
setattr(deal, field, value)
|
||||
await self._session.flush()
|
||||
return deal
|
||||
|
||||
def _apply_filters(
|
||||
self,
|
||||
stmt: Select[tuple[Deal]],
|
||||
params: DealQueryParams,
|
||||
role: OrganizationRole,
|
||||
user_id: int,
|
||||
) -> Select[tuple[Deal]]:
|
||||
if params.statuses:
|
||||
stmt = stmt.where(Deal.status.in_(params.statuses))
|
||||
if params.stage:
|
||||
stmt = stmt.where(Deal.stage == params.stage)
|
||||
if params.owner_id is not None:
|
||||
if role == OrganizationRole.MEMBER and params.owner_id != user_id:
|
||||
raise DealAccessError("Members cannot filter by other owners")
|
||||
stmt = stmt.where(Deal.owner_id == params.owner_id)
|
||||
if params.min_amount is not None:
|
||||
stmt = stmt.where(Deal.amount >= params.min_amount)
|
||||
if params.max_amount is not None:
|
||||
stmt = stmt.where(Deal.amount <= params.max_amount)
|
||||
|
||||
return self._apply_role_clause(stmt, role, user_id)
|
||||
|
||||
def _apply_role_clause(
|
||||
self,
|
||||
stmt: Select[tuple[Deal]],
|
||||
role: OrganizationRole,
|
||||
user_id: int,
|
||||
*,
|
||||
require_owner: bool = False,
|
||||
) -> Select[tuple[Deal]]:
|
||||
if role in {OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MANAGER}:
|
||||
return stmt
|
||||
if require_owner:
|
||||
return stmt.where(Deal.owner_id == user_id)
|
||||
return stmt
|
||||
|
||||
def _apply_ordering(self, stmt: Select[tuple[Deal]], params: DealQueryParams) -> Select[tuple[Deal]]:
|
||||
column = ORDERABLE_COLUMNS.get(params.order_by or "created_at", Deal.created_at)
|
||||
order_func = desc if params.order_desc else asc
|
||||
return stmt.order_by(order_func(column))
|
||||
|
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.organization import Organization, OrganizationCreate
|
||||
|
|
@ -42,6 +43,18 @@ class OrganizationRepository:
|
|||
result = await self._session.scalars(stmt)
|
||||
return result.unique().all()
|
||||
|
||||
async def get_membership(self, organization_id: int, user_id: int) -> OrganizationMember | None:
|
||||
stmt = (
|
||||
select(OrganizationMember)
|
||||
.where(
|
||||
OrganizationMember.organization_id == organization_id,
|
||||
OrganizationMember.user_id == user_id,
|
||||
)
|
||||
.options(selectinload(OrganizationMember.organization))
|
||||
)
|
||||
result = await self._session.scalars(stmt)
|
||||
return result.first()
|
||||
|
||||
async def create(self, data: OrganizationCreate) -> Organization:
|
||||
organization = Organization(name=data.name)
|
||||
self._session.add(organization)
|
||||
|
|
|
|||
|
|
@ -1 +1,11 @@
|
|||
"""Business logic services."""
|
||||
|
||||
from .deal_service import DealService # noqa: F401
|
||||
from .organization_service import ( # noqa: F401
|
||||
OrganizationAccessDeniedError,
|
||||
OrganizationContext,
|
||||
OrganizationContextMissingError,
|
||||
OrganizationService,
|
||||
)
|
||||
from .user_service import UserService # noqa: F401
|
||||
from .auth_service import AuthService # noqa: F401
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
"""Business logic for deals."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from app.models.activity import Activity, ActivityType
|
||||
from app.models.contact import Contact
|
||||
from app.models.deal import Deal, DealCreate, DealStage, DealStatus
|
||||
from app.models.organization_member import OrganizationRole
|
||||
from app.repositories.deal_repo import DealRepository
|
||||
from app.services.organization_service import OrganizationContext
|
||||
|
||||
|
||||
STAGE_ORDER = {
|
||||
stage: index
|
||||
for index, stage in enumerate(
|
||||
[
|
||||
DealStage.QUALIFICATION,
|
||||
DealStage.PROPOSAL,
|
||||
DealStage.NEGOTIATION,
|
||||
DealStage.CLOSED,
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class DealServiceError(Exception):
|
||||
"""Base class for deal service errors."""
|
||||
|
||||
|
||||
class DealOrganizationMismatchError(DealServiceError):
|
||||
"""Raised when attempting to use resources from another organization."""
|
||||
|
||||
|
||||
class DealStageTransitionError(DealServiceError):
|
||||
"""Raised when stage transition violates business rules."""
|
||||
|
||||
|
||||
class DealStatusValidationError(DealServiceError):
|
||||
"""Raised when invalid status transitions are requested."""
|
||||
|
||||
|
||||
class ContactHasDealsError(DealServiceError):
|
||||
"""Raised when attempting to delete a contact with active deals."""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DealUpdateData:
|
||||
"""Structured container for deal update operations."""
|
||||
|
||||
status: DealStatus | None = None
|
||||
stage: DealStage | None = None
|
||||
amount: Decimal | None = None
|
||||
currency: str | None = None
|
||||
|
||||
|
||||
class DealService:
|
||||
"""Encapsulates deal workflows and validations."""
|
||||
|
||||
def __init__(self, repository: DealRepository) -> None:
|
||||
self._repository = repository
|
||||
|
||||
async def create_deal(self, data: DealCreate, *, context: OrganizationContext) -> Deal:
|
||||
self._ensure_same_organization(data.organization_id, context)
|
||||
await self._ensure_contact_in_organization(data.contact_id, context.organization_id)
|
||||
return await self._repository.create(data=data, role=context.role, user_id=context.user_id)
|
||||
|
||||
async def update_deal(
|
||||
self,
|
||||
deal: Deal,
|
||||
updates: DealUpdateData,
|
||||
*,
|
||||
context: OrganizationContext,
|
||||
) -> Deal:
|
||||
self._ensure_same_organization(deal.organization_id, context)
|
||||
changes: dict[str, object] = {}
|
||||
stage_activity: tuple[ActivityType, dict[str, str]] | None = None
|
||||
status_activity: tuple[ActivityType, dict[str, str]] | None = None
|
||||
|
||||
if updates.amount is not None:
|
||||
changes["amount"] = updates.amount
|
||||
if updates.currency is not None:
|
||||
changes["currency"] = updates.currency
|
||||
|
||||
if updates.stage is not None and updates.stage != deal.stage:
|
||||
self._validate_stage_transition(deal.stage, updates.stage, context.role)
|
||||
changes["stage"] = updates.stage
|
||||
stage_activity = (
|
||||
ActivityType.STAGE_CHANGED,
|
||||
{"old_stage": deal.stage, "new_stage": updates.stage},
|
||||
)
|
||||
|
||||
if updates.status is not None and updates.status != deal.status:
|
||||
self._validate_status_transition(deal, updates)
|
||||
changes["status"] = updates.status
|
||||
status_activity = (
|
||||
ActivityType.STATUS_CHANGED,
|
||||
{"old_status": deal.status, "new_status": updates.status},
|
||||
)
|
||||
|
||||
if not changes:
|
||||
return deal
|
||||
|
||||
updated = await self._repository.update(deal, changes, role=context.role, user_id=context.user_id)
|
||||
await self._log_activities(
|
||||
deal_id=deal.id,
|
||||
author_id=context.user_id,
|
||||
activities=[activity for activity in [stage_activity, status_activity] if activity],
|
||||
)
|
||||
return updated
|
||||
|
||||
async def ensure_contact_can_be_deleted(self, contact_id: int) -> None:
|
||||
stmt = select(func.count()).select_from(Deal).where(Deal.contact_id == contact_id)
|
||||
count = await self._repository.session.scalar(stmt)
|
||||
if count and count > 0:
|
||||
raise ContactHasDealsError("Contact has related deals and cannot be deleted")
|
||||
|
||||
async def _log_activities(
|
||||
self,
|
||||
*,
|
||||
deal_id: int,
|
||||
author_id: int,
|
||||
activities: Iterable[tuple[ActivityType, dict[str, str]]],
|
||||
) -> None:
|
||||
entries = list(activities)
|
||||
if not entries:
|
||||
return
|
||||
for activity_type, payload in entries:
|
||||
activity = Activity(deal_id=deal_id, author_id=author_id, type=activity_type, payload=payload)
|
||||
self._repository.session.add(activity)
|
||||
await self._repository.session.flush()
|
||||
|
||||
def _ensure_same_organization(self, organization_id: int, context: OrganizationContext) -> None:
|
||||
if organization_id != context.organization_id:
|
||||
raise DealOrganizationMismatchError("Operation targets a different organization")
|
||||
|
||||
async def _ensure_contact_in_organization(self, contact_id: int, organization_id: int) -> Contact:
|
||||
contact = await self._repository.session.get(Contact, contact_id)
|
||||
if contact is None or contact.organization_id != organization_id:
|
||||
raise DealOrganizationMismatchError("Contact belongs to another organization")
|
||||
return contact
|
||||
|
||||
def _validate_stage_transition(
|
||||
self,
|
||||
current_stage: DealStage,
|
||||
new_stage: DealStage,
|
||||
role: OrganizationRole,
|
||||
) -> None:
|
||||
if STAGE_ORDER[new_stage] < STAGE_ORDER[current_stage] and role not in {
|
||||
OrganizationRole.OWNER,
|
||||
OrganizationRole.ADMIN,
|
||||
}:
|
||||
raise DealStageTransitionError("Stage rollback requires owner or admin role")
|
||||
|
||||
def _validate_status_transition(self, deal: Deal, updates: DealUpdateData) -> None:
|
||||
if updates.status != DealStatus.WON:
|
||||
return
|
||||
effective_amount = updates.amount if updates.amount is not None else deal.amount
|
||||
if effective_amount is None or Decimal(effective_amount) <= Decimal("0"):
|
||||
raise DealStatusValidationError("Amount must be greater than zero to mark a deal as won")
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
"""Organization-related business rules."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.models.organization import Organization
|
||||
from app.models.organization_member import OrganizationMember, OrganizationRole
|
||||
from app.repositories.org_repo import OrganizationRepository
|
||||
|
||||
|
||||
class OrganizationServiceError(Exception):
|
||||
"""Base class for organization service errors."""
|
||||
|
||||
|
||||
class OrganizationContextMissingError(OrganizationServiceError):
|
||||
"""Raised when the request lacks organization context."""
|
||||
|
||||
|
||||
class OrganizationAccessDeniedError(OrganizationServiceError):
|
||||
"""Raised when a user tries to work with a foreign organization."""
|
||||
|
||||
|
||||
class OrganizationForbiddenError(OrganizationServiceError):
|
||||
"""Raised when a user does not have enough privileges."""
|
||||
|
||||
|
||||
@dataclass(slots=True, frozen=True)
|
||||
class OrganizationContext:
|
||||
"""Resolved organization and membership information for a request."""
|
||||
|
||||
organization: Organization
|
||||
membership: OrganizationMember
|
||||
|
||||
@property
|
||||
def organization_id(self) -> int:
|
||||
return self.organization.id
|
||||
|
||||
@property
|
||||
def role(self) -> OrganizationRole:
|
||||
return self.membership.role
|
||||
|
||||
@property
|
||||
def user_id(self) -> int:
|
||||
return self.membership.user_id
|
||||
|
||||
|
||||
class OrganizationService:
|
||||
"""Encapsulates organization-specific policies."""
|
||||
|
||||
def __init__(self, repository: OrganizationRepository) -> None:
|
||||
self._repository = repository
|
||||
|
||||
async def get_context(self, *, user_id: int, organization_id: int | None) -> OrganizationContext:
|
||||
"""Resolve request context ensuring the user belongs to the given organization."""
|
||||
|
||||
if organization_id is None:
|
||||
raise OrganizationContextMissingError("X-Organization-Id header is required")
|
||||
|
||||
membership = await self._repository.get_membership(organization_id, user_id)
|
||||
if membership is None or membership.organization is None:
|
||||
raise OrganizationAccessDeniedError("Organization not found")
|
||||
|
||||
return OrganizationContext(organization=membership.organization, membership=membership)
|
||||
|
||||
def ensure_entity_in_context(self, *, entity_organization_id: int, context: OrganizationContext) -> None:
|
||||
"""Make sure a resource belongs to the current organization."""
|
||||
|
||||
if entity_organization_id != context.organization_id:
|
||||
raise OrganizationAccessDeniedError("Resource belongs to another organization")
|
||||
|
||||
def ensure_can_manage_settings(self, context: OrganizationContext) -> None:
|
||||
"""Allow only owner/admin to change organization-level settings."""
|
||||
|
||||
if context.role not in {OrganizationRole.OWNER, OrganizationRole.ADMIN}:
|
||||
raise OrganizationForbiddenError("Only owner/admin can modify organization settings")
|
||||
|
||||
def ensure_can_manage_entity(self, context: OrganizationContext) -> None:
|
||||
"""Managers/admins/owners may manage entities; members are restricted."""
|
||||
|
||||
if context.role == OrganizationRole.MEMBER:
|
||||
raise OrganizationForbiddenError("Members cannot manage shared entities")
|
||||
|
||||
def ensure_member_owns_entity(self, *, context: OrganizationContext, owner_id: int) -> None:
|
||||
"""Members can only mutate entities they own (contacts/deals/tasks)."""
|
||||
|
||||
if context.role == OrganizationRole.MEMBER and owner_id != context.user_id:
|
||||
raise OrganizationForbiddenError("Members can only modify their own records")
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
"""Add stage_changed activity type."""
|
||||
from __future__ import annotations
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "20251127_0002_stage_changed"
|
||||
down_revision: str | None = "20251122_0001"
|
||||
branch_labels: tuple[str, ...] | None = None
|
||||
depends_on: tuple[str, ...] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute("ALTER TYPE activity_type ADD VALUE IF NOT EXISTS 'stage_changed';")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("UPDATE activities SET type = 'status_changed' WHERE type = 'stage_changed';")
|
||||
op.execute("ALTER TYPE activity_type RENAME TO activity_type_old;")
|
||||
op.execute(
|
||||
"CREATE TYPE activity_type AS ENUM ('comment','status_changed','task_created','system');"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE activities ALTER COLUMN type TYPE activity_type USING type::text::activity_type;"
|
||||
)
|
||||
op.execute("DROP TYPE activity_type_old;")
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
"""Unit tests for DealService."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from decimal import Decimal
|
||||
import uuid
|
||||
|
||||
import pytest # type: ignore[import-not-found]
|
||||
import pytest_asyncio # type: ignore[import-not-found]
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.models.activity import Activity, ActivityType
|
||||
from app.models.base import Base
|
||||
from app.models.contact import Contact
|
||||
from app.models.deal import DealCreate, DealStage, DealStatus
|
||||
from app.models.organization import Organization
|
||||
from app.models.organization_member import OrganizationMember, OrganizationRole
|
||||
from app.models.user import User
|
||||
from app.repositories.deal_repo import DealRepository
|
||||
from app.services.deal_service import (
|
||||
ContactHasDealsError,
|
||||
DealOrganizationMismatchError,
|
||||
DealService,
|
||||
DealStageTransitionError,
|
||||
DealStatusValidationError,
|
||||
DealUpdateData,
|
||||
)
|
||||
from app.services.organization_service import OrganizationContext
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def session() -> AsyncGenerator[AsyncSession, None]:
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
future=True,
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def _make_organization(name: str) -> Organization:
|
||||
org = Organization(name=name)
|
||||
return org
|
||||
|
||||
|
||||
def _make_user(email_suffix: str) -> User:
|
||||
return User(
|
||||
email=f"user-{email_suffix}@example.com",
|
||||
hashed_password="hashed",
|
||||
name="Test User",
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
|
||||
def _make_context(org: Organization, user: User, role: OrganizationRole) -> OrganizationContext:
|
||||
membership = OrganizationMember(organization_id=org.id, user_id=user.id, role=role)
|
||||
return OrganizationContext(organization=org, membership=membership)
|
||||
|
||||
|
||||
async def _persist_base(session: AsyncSession, *, role: OrganizationRole = OrganizationRole.MANAGER) -> tuple[
|
||||
OrganizationContext,
|
||||
Contact,
|
||||
DealRepository,
|
||||
]:
|
||||
org = _make_organization(name=f"Org-{uuid.uuid4()}"[:8])
|
||||
user = _make_user(email_suffix=str(uuid.uuid4())[:8])
|
||||
session.add_all([org, user])
|
||||
await session.flush()
|
||||
|
||||
contact = Contact(
|
||||
organization_id=org.id,
|
||||
owner_id=user.id,
|
||||
name="John Doe",
|
||||
email="john@example.com",
|
||||
)
|
||||
session.add(contact)
|
||||
await session.flush()
|
||||
|
||||
context = _make_context(org, user, role)
|
||||
repo = DealRepository(session=session)
|
||||
return context, contact, repo
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_deal_rejects_foreign_contact(session: AsyncSession) -> None:
|
||||
context, contact, repo = await _persist_base(session)
|
||||
|
||||
other_org = _make_organization(name="Other")
|
||||
other_user = _make_user(email_suffix="other")
|
||||
session.add_all([other_org, other_user])
|
||||
await session.flush()
|
||||
|
||||
service = DealService(repository=repo)
|
||||
payload = DealCreate(
|
||||
organization_id=other_org.id,
|
||||
contact_id=contact.id,
|
||||
owner_id=context.user_id,
|
||||
title="Website Redesign",
|
||||
amount=None,
|
||||
)
|
||||
|
||||
other_context = _make_context(other_org, other_user, OrganizationRole.MANAGER)
|
||||
|
||||
with pytest.raises(DealOrganizationMismatchError):
|
||||
await service.create_deal(payload, context=other_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_rollback_requires_admin(session: AsyncSession) -> None:
|
||||
context, contact, repo = await _persist_base(session, role=OrganizationRole.MANAGER)
|
||||
service = DealService(repository=repo)
|
||||
|
||||
deal = await service.create_deal(
|
||||
DealCreate(
|
||||
organization_id=context.organization_id,
|
||||
contact_id=contact.id,
|
||||
owner_id=context.user_id,
|
||||
title="Migration",
|
||||
amount=Decimal("5000"),
|
||||
),
|
||||
context=context,
|
||||
)
|
||||
deal.stage = DealStage.PROPOSAL
|
||||
|
||||
with pytest.raises(DealStageTransitionError):
|
||||
await service.update_deal(
|
||||
deal,
|
||||
DealUpdateData(stage=DealStage.QUALIFICATION),
|
||||
context=context,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_rollback_allowed_for_admin(session: AsyncSession) -> None:
|
||||
context, contact, repo = await _persist_base(session, role=OrganizationRole.ADMIN)
|
||||
service = DealService(repository=repo)
|
||||
|
||||
deal = await service.create_deal(
|
||||
DealCreate(
|
||||
organization_id=context.organization_id,
|
||||
contact_id=contact.id,
|
||||
owner_id=context.user_id,
|
||||
title="Rollout",
|
||||
amount=Decimal("1000"),
|
||||
),
|
||||
context=context,
|
||||
)
|
||||
deal.stage = DealStage.NEGOTIATION
|
||||
|
||||
updated = await service.update_deal(
|
||||
deal,
|
||||
DealUpdateData(stage=DealStage.PROPOSAL),
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert updated.stage == DealStage.PROPOSAL
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_won_requires_positive_amount(session: AsyncSession) -> None:
|
||||
context, contact, repo = await _persist_base(session)
|
||||
service = DealService(repository=repo)
|
||||
|
||||
deal = await service.create_deal(
|
||||
DealCreate(
|
||||
organization_id=context.organization_id,
|
||||
contact_id=contact.id,
|
||||
owner_id=context.user_id,
|
||||
title="Zero",
|
||||
amount=None,
|
||||
),
|
||||
context=context,
|
||||
)
|
||||
|
||||
with pytest.raises(DealStatusValidationError):
|
||||
await service.update_deal(
|
||||
deal,
|
||||
DealUpdateData(status=DealStatus.WON),
|
||||
context=context,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updates_create_activity_records(session: AsyncSession) -> None:
|
||||
context, contact, repo = await _persist_base(session)
|
||||
service = DealService(repository=repo)
|
||||
|
||||
deal = await service.create_deal(
|
||||
DealCreate(
|
||||
organization_id=context.organization_id,
|
||||
contact_id=contact.id,
|
||||
owner_id=context.user_id,
|
||||
title="Activity",
|
||||
amount=Decimal("100"),
|
||||
),
|
||||
context=context,
|
||||
)
|
||||
|
||||
await service.update_deal(
|
||||
deal,
|
||||
DealUpdateData(
|
||||
stage=DealStage.PROPOSAL,
|
||||
status=DealStatus.WON,
|
||||
amount=Decimal("5000"),
|
||||
),
|
||||
context=context,
|
||||
)
|
||||
|
||||
result = await session.scalars(select(Activity).where(Activity.deal_id == deal.id))
|
||||
activity_types = {activity.type for activity in result.all()}
|
||||
assert ActivityType.STAGE_CHANGED in activity_types
|
||||
assert ActivityType.STATUS_CHANGED in activity_types
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_contact_delete_guard(session: AsyncSession) -> None:
|
||||
context, contact, repo = await _persist_base(session)
|
||||
service = DealService(repository=repo)
|
||||
|
||||
deal = await service.create_deal(
|
||||
DealCreate(
|
||||
organization_id=context.organization_id,
|
||||
contact_id=contact.id,
|
||||
owner_id=context.user_id,
|
||||
title="To Delete",
|
||||
amount=Decimal("100"),
|
||||
),
|
||||
context=context,
|
||||
)
|
||||
|
||||
with pytest.raises(ContactHasDealsError):
|
||||
await service.ensure_contact_can_be_deleted(contact.id)
|
||||
|
||||
await session.delete(deal)
|
||||
await session.flush()
|
||||
|
||||
await service.ensure_contact_can_be_deleted(contact.id)
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
"""Unit tests for OrganizationService."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # type: ignore[import-not-found]
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.organization import Organization
|
||||
from app.models.organization_member import OrganizationMember, OrganizationRole
|
||||
from app.repositories.org_repo import OrganizationRepository
|
||||
from app.services.organization_service import (
|
||||
OrganizationAccessDeniedError,
|
||||
OrganizationContext,
|
||||
OrganizationContextMissingError,
|
||||
OrganizationForbiddenError,
|
||||
OrganizationService,
|
||||
)
|
||||
|
||||
|
||||
class StubOrganizationRepository(OrganizationRepository):
|
||||
"""Simple in-memory stand-in for OrganizationRepository."""
|
||||
|
||||
def __init__(self, membership: OrganizationMember | None) -> None:
|
||||
super().__init__(session=MagicMock(spec=AsyncSession))
|
||||
self._membership = membership
|
||||
|
||||
async def get_membership(self, organization_id: int, user_id: int) -> OrganizationMember | None: # pragma: no cover - helper
|
||||
if (
|
||||
self._membership
|
||||
and self._membership.organization_id == organization_id
|
||||
and self._membership.user_id == user_id
|
||||
):
|
||||
return self._membership
|
||||
return None
|
||||
|
||||
|
||||
def make_membership(role: OrganizationRole, *, organization_id: int = 1, user_id: int = 10) -> OrganizationMember:
|
||||
organization = Organization(name="Acme Inc")
|
||||
organization.id = organization_id
|
||||
membership = OrganizationMember(
|
||||
organization_id=organization_id,
|
||||
user_id=user_id,
|
||||
role=role,
|
||||
)
|
||||
membership.organization = organization
|
||||
return membership
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_context_success() -> None:
|
||||
membership = make_membership(OrganizationRole.MANAGER)
|
||||
service = OrganizationService(StubOrganizationRepository(membership))
|
||||
|
||||
context = await service.get_context(user_id=membership.user_id, organization_id=membership.organization_id)
|
||||
|
||||
assert context.organization_id == membership.organization_id
|
||||
assert context.role == OrganizationRole.MANAGER
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_context_missing_header() -> None:
|
||||
service = OrganizationService(StubOrganizationRepository(None))
|
||||
|
||||
with pytest.raises(OrganizationContextMissingError):
|
||||
await service.get_context(user_id=1, organization_id=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_context_access_denied() -> None:
|
||||
service = OrganizationService(StubOrganizationRepository(None))
|
||||
|
||||
with pytest.raises(OrganizationAccessDeniedError):
|
||||
await service.get_context(user_id=1, organization_id=99)
|
||||
|
||||
|
||||
def test_ensure_can_manage_settings_blocks_manager() -> None:
|
||||
membership = make_membership(OrganizationRole.MANAGER)
|
||||
organization = membership.organization
|
||||
assert organization is not None
|
||||
context = OrganizationContext(organization=organization, membership=membership)
|
||||
service = OrganizationService(StubOrganizationRepository(membership))
|
||||
|
||||
with pytest.raises(OrganizationForbiddenError):
|
||||
service.ensure_can_manage_settings(context)
|
||||
|
||||
|
||||
def test_member_must_own_entity() -> None:
|
||||
membership = make_membership(OrganizationRole.MEMBER)
|
||||
organization = membership.organization
|
||||
assert organization is not None
|
||||
context = OrganizationContext(organization=organization, membership=membership)
|
||||
service = OrganizationService(StubOrganizationRepository(membership))
|
||||
|
||||
with pytest.raises(OrganizationForbiddenError):
|
||||
service.ensure_member_owns_entity(context=context, owner_id=999)
|
||||
|
||||
# Same owner should pass silently.
|
||||
service.ensure_member_owns_entity(context=context, owner_id=membership.user_id)
|
||||
Loading…
Reference in New Issue