Compare commits
5 Commits
c972d79ba9
...
a4c3864ef6
| Author | SHA1 | Date |
|---|---|---|
|
|
a4c3864ef6 | |
|
|
8492a0aed1 | |
|
|
969a1b5905 | |
|
|
8c326501bf | |
|
|
4b45073bd3 |
|
|
@ -2,7 +2,7 @@
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, Header, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
@ -10,9 +10,17 @@ from app.core.config import settings
|
||||||
from app.core.database import get_session
|
from app.core.database import get_session
|
||||||
from app.core.security import jwt_service, password_hasher
|
from app.core.security import jwt_service, password_hasher
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.deal_repo import DealRepository
|
||||||
from app.repositories.org_repo import OrganizationRepository
|
from app.repositories.org_repo import OrganizationRepository
|
||||||
from app.repositories.user_repo import UserRepository
|
from app.repositories.user_repo import UserRepository
|
||||||
from app.services.auth_service import AuthService
|
from app.services.auth_service import AuthService
|
||||||
|
from app.services.deal_service import DealService
|
||||||
|
from app.services.organization_service import (
|
||||||
|
OrganizationAccessDeniedError,
|
||||||
|
OrganizationContext,
|
||||||
|
OrganizationContextMissingError,
|
||||||
|
OrganizationService,
|
||||||
|
)
|
||||||
from app.services.user_service import UserService
|
from app.services.user_service import UserService
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/token")
|
||||||
|
|
@ -32,6 +40,14 @@ def get_organization_repository(session: AsyncSession = Depends(get_db_session))
|
||||||
return OrganizationRepository(session=session)
|
return OrganizationRepository(session=session)
|
||||||
|
|
||||||
|
|
||||||
|
def get_deal_repository(session: AsyncSession = Depends(get_db_session)) -> DealRepository:
|
||||||
|
return DealRepository(session=session)
|
||||||
|
|
||||||
|
|
||||||
|
def get_deal_service(repo: DealRepository = Depends(get_deal_repository)) -> DealService:
|
||||||
|
return DealService(repository=repo)
|
||||||
|
|
||||||
|
|
||||||
def get_user_service(repo: UserRepository = Depends(get_user_repository)) -> UserService:
|
def get_user_service(repo: UserRepository = Depends(get_user_repository)) -> UserService:
|
||||||
return UserService(user_repository=repo, password_hasher=password_hasher)
|
return UserService(user_repository=repo, password_hasher=password_hasher)
|
||||||
|
|
||||||
|
|
@ -46,6 +62,12 @@ def get_auth_service(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_organization_service(
|
||||||
|
repo: OrganizationRepository = Depends(get_organization_repository),
|
||||||
|
) -> OrganizationService:
|
||||||
|
return OrganizationService(repository=repo)
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
token: str = Depends(oauth2_scheme),
|
token: str = Depends(oauth2_scheme),
|
||||||
repo: UserRepository = Depends(get_user_repository),
|
repo: UserRepository = Depends(get_user_repository),
|
||||||
|
|
@ -67,3 +89,16 @@ async def get_current_user(
|
||||||
if user is None:
|
if user is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_organization_context(
|
||||||
|
x_organization_id: int | None = Header(default=None, alias="X-Organization-Id"),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
service: OrganizationService = Depends(get_organization_service),
|
||||||
|
) -> OrganizationContext:
|
||||||
|
try:
|
||||||
|
return await service.get_context(user_id=current_user.id, organization_id=x_organization_id)
|
||||||
|
except OrganizationContextMissingError as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||||
|
except OrganizationAccessDeniedError as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
"""Activity timeline API stubs."""
|
"""Activity timeline API stubs."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter, Depends, status
|
||||||
|
|
||||||
|
from app.api.deps import get_organization_context
|
||||||
|
from app.services.organization_service import OrganizationContext
|
||||||
|
|
||||||
from .models import ActivityCommentPayload
|
from .models import ActivityCommentPayload
|
||||||
|
|
||||||
|
|
@ -13,14 +16,21 @@ def _stub(endpoint: str) -> dict[str, str]:
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
@router.get("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||||
async def list_activities(deal_id: int) -> dict[str, str]:
|
async def list_activities(
|
||||||
|
deal_id: int,
|
||||||
|
context: OrganizationContext = Depends(get_organization_context),
|
||||||
|
) -> dict[str, str]:
|
||||||
"""Placeholder for listing deal activities."""
|
"""Placeholder for listing deal activities."""
|
||||||
_ = deal_id
|
_ = (deal_id, context)
|
||||||
return _stub("GET /deals/{deal_id}/activities")
|
return _stub("GET /deals/{deal_id}/activities")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||||
async def create_activity_comment(deal_id: int, payload: ActivityCommentPayload) -> dict[str, str]:
|
async def create_activity_comment(
|
||||||
|
deal_id: int,
|
||||||
|
payload: ActivityCommentPayload,
|
||||||
|
context: OrganizationContext = Depends(get_organization_context),
|
||||||
|
) -> dict[str, str]:
|
||||||
"""Placeholder for adding a comment activity to a deal."""
|
"""Placeholder for adding a comment activity to a deal."""
|
||||||
_ = (deal_id, payload)
|
_ = (deal_id, payload, context)
|
||||||
return _stub("POST /deals/{deal_id}/activities")
|
return _stub("POST /deals/{deal_id}/activities")
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
"""Analytics API stubs (deal summary and funnel)."""
|
"""Analytics API stubs (deal summary and funnel)."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from fastapi import APIRouter, Query, status
|
from fastapi import APIRouter, Depends, Query, status
|
||||||
|
|
||||||
|
from app.api.deps import get_organization_context
|
||||||
|
from app.services.organization_service import OrganizationContext
|
||||||
|
|
||||||
router = APIRouter(prefix="/analytics", tags=["analytics"])
|
router = APIRouter(prefix="/analytics", tags=["analytics"])
|
||||||
|
|
||||||
|
|
@ -11,13 +14,19 @@ def _stub(endpoint: str) -> dict[str, str]:
|
||||||
|
|
||||||
|
|
||||||
@router.get("/deals/summary", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
@router.get("/deals/summary", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||||
async def deals_summary(days: int = Query(30, ge=1, le=180)) -> dict[str, str]:
|
async def deals_summary(
|
||||||
|
days: int = Query(30, ge=1, le=180),
|
||||||
|
context: OrganizationContext = Depends(get_organization_context),
|
||||||
|
) -> dict[str, str]:
|
||||||
"""Placeholder for aggregated deal statistics."""
|
"""Placeholder for aggregated deal statistics."""
|
||||||
_ = days
|
_ = (days, context)
|
||||||
return _stub("GET /analytics/deals/summary")
|
return _stub("GET /analytics/deals/summary")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/deals/funnel", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
@router.get("/deals/funnel", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||||
async def deals_funnel() -> dict[str, str]:
|
async def deals_funnel(
|
||||||
|
context: OrganizationContext = Depends(get_organization_context),
|
||||||
|
) -> dict[str, str]:
|
||||||
"""Placeholder for funnel analytics."""
|
"""Placeholder for funnel analytics."""
|
||||||
|
_ = context
|
||||||
return _stub("GET /analytics/deals/funnel")
|
return _stub("GET /analytics/deals/funnel")
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
"""Contact API stubs required by the spec."""
|
"""Contact API stubs required by the spec."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from fastapi import APIRouter, Query, status
|
from fastapi import APIRouter, Depends, Query, status
|
||||||
|
|
||||||
|
from app.api.deps import get_organization_context
|
||||||
|
from app.services.organization_service import OrganizationContext
|
||||||
|
|
||||||
from .models import ContactCreatePayload
|
from .models import ContactCreatePayload
|
||||||
|
|
||||||
|
|
@ -18,13 +21,18 @@ async def list_contacts(
|
||||||
page_size: int = Query(20, ge=1, le=100),
|
page_size: int = Query(20, ge=1, le=100),
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
owner_id: int | None = None,
|
owner_id: int | None = None,
|
||||||
|
context: OrganizationContext = Depends(get_organization_context),
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""Placeholder list endpoint supporting the required filters."""
|
"""Placeholder list endpoint supporting the required filters."""
|
||||||
|
_ = context
|
||||||
return _stub("GET /contacts")
|
return _stub("GET /contacts")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||||
async def create_contact(payload: ContactCreatePayload) -> dict[str, str]:
|
async def create_contact(
|
||||||
|
payload: ContactCreatePayload,
|
||||||
|
context: OrganizationContext = Depends(get_organization_context),
|
||||||
|
) -> dict[str, str]:
|
||||||
"""Placeholder for creating a contact within the current organization."""
|
"""Placeholder for creating a contact within the current organization."""
|
||||||
_ = payload
|
_ = (payload, context)
|
||||||
return _stub("POST /contacts")
|
return _stub("POST /contacts")
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,10 @@ from __future__ import annotations
|
||||||
|
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
from fastapi import APIRouter, Query, status
|
from fastapi import APIRouter, Depends, Query, status
|
||||||
|
|
||||||
|
from app.api.deps import get_organization_context
|
||||||
|
from app.services.organization_service import OrganizationContext
|
||||||
|
|
||||||
from .models import DealCreatePayload, DealUpdatePayload
|
from .models import DealCreatePayload, DealUpdatePayload
|
||||||
|
|
||||||
|
|
@ -25,21 +28,29 @@ async def list_deals(
|
||||||
owner_id: int | None = None,
|
owner_id: int | None = None,
|
||||||
order_by: str | None = None,
|
order_by: str | None = None,
|
||||||
order: str | None = Query(default=None, pattern="^(asc|desc)$"),
|
order: str | None = Query(default=None, pattern="^(asc|desc)$"),
|
||||||
|
context: OrganizationContext = Depends(get_organization_context),
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""Placeholder for deal filtering endpoint."""
|
"""Placeholder for deal filtering endpoint."""
|
||||||
_ = (status_filter,)
|
_ = (status_filter, context)
|
||||||
return _stub("GET /deals")
|
return _stub("GET /deals")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||||
async def create_deal(payload: DealCreatePayload) -> dict[str, str]:
|
async def create_deal(
|
||||||
|
payload: DealCreatePayload,
|
||||||
|
context: OrganizationContext = Depends(get_organization_context),
|
||||||
|
) -> dict[str, str]:
|
||||||
"""Placeholder for creating a new deal."""
|
"""Placeholder for creating a new deal."""
|
||||||
_ = payload
|
_ = (payload, context)
|
||||||
return _stub("POST /deals")
|
return _stub("POST /deals")
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/{deal_id}", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
@router.patch("/{deal_id}", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||||
async def update_deal(deal_id: int, payload: DealUpdatePayload) -> dict[str, str]:
|
async def update_deal(
|
||||||
|
deal_id: int,
|
||||||
|
payload: DealUpdatePayload,
|
||||||
|
context: OrganizationContext = Depends(get_organization_context),
|
||||||
|
) -> dict[str, str]:
|
||||||
"""Placeholder for modifying deal status or stage."""
|
"""Placeholder for modifying deal status or stage."""
|
||||||
_ = (deal_id, payload)
|
_ = (deal_id, payload, context)
|
||||||
return _stub("PATCH /deals/{deal_id}")
|
return _stub("PATCH /deals/{deal_id}")
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,10 @@ from __future__ import annotations
|
||||||
|
|
||||||
from datetime import date
|
from datetime import date
|
||||||
|
|
||||||
from fastapi import APIRouter, Query, status
|
from fastapi import APIRouter, Depends, Query, status
|
||||||
|
|
||||||
|
from app.api.deps import get_organization_context
|
||||||
|
from app.services.organization_service import OrganizationContext
|
||||||
|
|
||||||
from .models import TaskCreatePayload
|
from .models import TaskCreatePayload
|
||||||
|
|
||||||
|
|
@ -20,13 +23,18 @@ async def list_tasks(
|
||||||
only_open: bool = False,
|
only_open: bool = False,
|
||||||
due_before: date | None = Query(default=None),
|
due_before: date | None = Query(default=None),
|
||||||
due_after: date | None = Query(default=None),
|
due_after: date | None = Query(default=None),
|
||||||
|
context: OrganizationContext = Depends(get_organization_context),
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""Placeholder for task filtering endpoint."""
|
"""Placeholder for task filtering endpoint."""
|
||||||
|
_ = context
|
||||||
return _stub("GET /tasks")
|
return _stub("GET /tasks")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
@router.post("/", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
||||||
async def create_task(payload: TaskCreatePayload) -> dict[str, str]:
|
async def create_task(
|
||||||
|
payload: TaskCreatePayload,
|
||||||
|
context: OrganizationContext = Depends(get_organization_context),
|
||||||
|
) -> dict[str, str]:
|
||||||
"""Placeholder for creating a task linked to a deal."""
|
"""Placeholder for creating a task linked to a deal."""
|
||||||
_ = payload
|
_ = (payload, context)
|
||||||
return _stub("POST /tasks")
|
return _stub("POST /tasks")
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from typing import Any
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from sqlalchemy import DateTime, Enum as SqlEnum, ForeignKey, Integer, func, text
|
from sqlalchemy import DateTime, Enum as SqlEnum, ForeignKey, Integer, func, text
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.types import JSON as GenericJSON, TypeDecorator
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from app.models.base import Base
|
from app.models.base import Base
|
||||||
|
|
@ -16,10 +17,25 @@ from app.models.base import Base
|
||||||
class ActivityType(StrEnum):
|
class ActivityType(StrEnum):
|
||||||
COMMENT = "comment"
|
COMMENT = "comment"
|
||||||
STATUS_CHANGED = "status_changed"
|
STATUS_CHANGED = "status_changed"
|
||||||
|
STAGE_CHANGED = "stage_changed"
|
||||||
TASK_CREATED = "task_created"
|
TASK_CREATED = "task_created"
|
||||||
SYSTEM = "system"
|
SYSTEM = "system"
|
||||||
|
|
||||||
|
|
||||||
|
class JSONBCompat(TypeDecorator):
|
||||||
|
"""Uses JSONB on Postgres and plain JSON elsewhere for testability."""
|
||||||
|
|
||||||
|
impl = JSONB
|
||||||
|
cache_ok = True
|
||||||
|
|
||||||
|
def load_dialect_impl(self, dialect): # type: ignore[override]
|
||||||
|
if dialect.name == "sqlite":
|
||||||
|
from sqlalchemy.dialects.sqlite import JSON as SQLiteJSON # local import
|
||||||
|
|
||||||
|
return dialect.type_descriptor(SQLiteJSON())
|
||||||
|
return dialect.type_descriptor(JSONB())
|
||||||
|
|
||||||
|
|
||||||
class Activity(Base):
|
class Activity(Base):
|
||||||
"""Represents a timeline event for a deal."""
|
"""Represents a timeline event for a deal."""
|
||||||
|
|
||||||
|
|
@ -32,9 +48,9 @@ class Activity(Base):
|
||||||
)
|
)
|
||||||
type: Mapped[ActivityType] = mapped_column(SqlEnum(ActivityType, name="activity_type"), nullable=False)
|
type: Mapped[ActivityType] = mapped_column(SqlEnum(ActivityType, name="activity_type"), nullable=False)
|
||||||
payload: Mapped[dict[str, Any]] = mapped_column(
|
payload: Mapped[dict[str, Any]] = mapped_column(
|
||||||
JSONB,
|
JSONBCompat().with_variant(GenericJSON(), "sqlite"),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default=text("'{}'::jsonb"),
|
server_default=text("'{}'"),
|
||||||
)
|
)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,152 @@
|
||||||
|
"""Deal repository with access-aware CRUD helpers."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from decimal import Decimal
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import Select, asc, desc, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.deal import Deal, DealCreate, DealStage, DealStatus
|
||||||
|
from app.models.organization_member import OrganizationRole
|
||||||
|
|
||||||
|
|
||||||
|
ORDERABLE_COLUMNS: dict[str, Any] = {
|
||||||
|
"created_at": Deal.created_at,
|
||||||
|
"amount": Deal.amount,
|
||||||
|
"title": Deal.title,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DealAccessError(Exception):
|
||||||
|
"""Raised when a user attempts an operation without sufficient permissions."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class DealQueryParams:
|
||||||
|
"""Filters supported by list queries."""
|
||||||
|
|
||||||
|
organization_id: int
|
||||||
|
page: int = 1
|
||||||
|
page_size: int = 20
|
||||||
|
statuses: Sequence[DealStatus] | None = None
|
||||||
|
stage: DealStage | None = None
|
||||||
|
owner_id: int | None = None
|
||||||
|
min_amount: Decimal | None = None
|
||||||
|
max_amount: Decimal | None = None
|
||||||
|
order_by: str | None = None
|
||||||
|
order_desc: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class DealRepository:
|
||||||
|
"""Provides CRUD helpers for deals with role-aware filtering."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
|
self._session = session
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self) -> AsyncSession:
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
async def list(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
params: DealQueryParams,
|
||||||
|
role: OrganizationRole,
|
||||||
|
user_id: int,
|
||||||
|
) -> Sequence[Deal]:
|
||||||
|
stmt = select(Deal).where(Deal.organization_id == params.organization_id)
|
||||||
|
stmt = self._apply_filters(stmt, params, role, user_id)
|
||||||
|
stmt = self._apply_ordering(stmt, params)
|
||||||
|
|
||||||
|
offset = (max(params.page, 1) - 1) * params.page_size
|
||||||
|
stmt = stmt.offset(offset).limit(params.page_size)
|
||||||
|
result = await self._session.scalars(stmt)
|
||||||
|
return result.all()
|
||||||
|
|
||||||
|
async def get(
|
||||||
|
self,
|
||||||
|
deal_id: int,
|
||||||
|
*,
|
||||||
|
organization_id: int,
|
||||||
|
role: OrganizationRole,
|
||||||
|
user_id: int,
|
||||||
|
require_owner: bool = False,
|
||||||
|
) -> Deal | None:
|
||||||
|
stmt = select(Deal).where(Deal.id == deal_id, Deal.organization_id == organization_id)
|
||||||
|
stmt = self._apply_role_clause(stmt, role, user_id, require_owner=require_owner)
|
||||||
|
result = await self._session.scalars(stmt)
|
||||||
|
return result.first()
|
||||||
|
|
||||||
|
async def create(
|
||||||
|
self,
|
||||||
|
data: DealCreate,
|
||||||
|
*,
|
||||||
|
role: OrganizationRole,
|
||||||
|
user_id: int,
|
||||||
|
) -> Deal:
|
||||||
|
if role == OrganizationRole.MEMBER and data.owner_id != user_id:
|
||||||
|
raise DealAccessError("Members can only create deals they own")
|
||||||
|
deal = Deal(**data.model_dump())
|
||||||
|
self._session.add(deal)
|
||||||
|
await self._session.flush()
|
||||||
|
return deal
|
||||||
|
|
||||||
|
async def update(
|
||||||
|
self,
|
||||||
|
deal: Deal,
|
||||||
|
updates: Mapping[str, Any],
|
||||||
|
*,
|
||||||
|
role: OrganizationRole,
|
||||||
|
user_id: int,
|
||||||
|
) -> Deal:
|
||||||
|
if role == OrganizationRole.MEMBER and deal.owner_id != user_id:
|
||||||
|
raise DealAccessError("Members can only modify their own deals")
|
||||||
|
for field, value in updates.items():
|
||||||
|
if hasattr(deal, field):
|
||||||
|
setattr(deal, field, value)
|
||||||
|
await self._session.flush()
|
||||||
|
return deal
|
||||||
|
|
||||||
|
def _apply_filters(
|
||||||
|
self,
|
||||||
|
stmt: Select[tuple[Deal]],
|
||||||
|
params: DealQueryParams,
|
||||||
|
role: OrganizationRole,
|
||||||
|
user_id: int,
|
||||||
|
) -> Select[tuple[Deal]]:
|
||||||
|
if params.statuses:
|
||||||
|
stmt = stmt.where(Deal.status.in_(params.statuses))
|
||||||
|
if params.stage:
|
||||||
|
stmt = stmt.where(Deal.stage == params.stage)
|
||||||
|
if params.owner_id is not None:
|
||||||
|
if role == OrganizationRole.MEMBER and params.owner_id != user_id:
|
||||||
|
raise DealAccessError("Members cannot filter by other owners")
|
||||||
|
stmt = stmt.where(Deal.owner_id == params.owner_id)
|
||||||
|
if params.min_amount is not None:
|
||||||
|
stmt = stmt.where(Deal.amount >= params.min_amount)
|
||||||
|
if params.max_amount is not None:
|
||||||
|
stmt = stmt.where(Deal.amount <= params.max_amount)
|
||||||
|
|
||||||
|
return self._apply_role_clause(stmt, role, user_id)
|
||||||
|
|
||||||
|
def _apply_role_clause(
|
||||||
|
self,
|
||||||
|
stmt: Select[tuple[Deal]],
|
||||||
|
role: OrganizationRole,
|
||||||
|
user_id: int,
|
||||||
|
*,
|
||||||
|
require_owner: bool = False,
|
||||||
|
) -> Select[tuple[Deal]]:
|
||||||
|
if role in {OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MANAGER}:
|
||||||
|
return stmt
|
||||||
|
if require_owner:
|
||||||
|
return stmt.where(Deal.owner_id == user_id)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
def _apply_ordering(self, stmt: Select[tuple[Deal]], params: DealQueryParams) -> Select[tuple[Deal]]:
|
||||||
|
column = ORDERABLE_COLUMNS.get(params.order_by or "created_at", Deal.created_at)
|
||||||
|
order_func = desc if params.order_desc else asc
|
||||||
|
return stmt.order_by(order_func(column))
|
||||||
|
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.models.organization import Organization, OrganizationCreate
|
from app.models.organization import Organization, OrganizationCreate
|
||||||
|
|
@ -42,6 +43,18 @@ class OrganizationRepository:
|
||||||
result = await self._session.scalars(stmt)
|
result = await self._session.scalars(stmt)
|
||||||
return result.unique().all()
|
return result.unique().all()
|
||||||
|
|
||||||
|
async def get_membership(self, organization_id: int, user_id: int) -> OrganizationMember | None:
|
||||||
|
stmt = (
|
||||||
|
select(OrganizationMember)
|
||||||
|
.where(
|
||||||
|
OrganizationMember.organization_id == organization_id,
|
||||||
|
OrganizationMember.user_id == user_id,
|
||||||
|
)
|
||||||
|
.options(selectinload(OrganizationMember.organization))
|
||||||
|
)
|
||||||
|
result = await self._session.scalars(stmt)
|
||||||
|
return result.first()
|
||||||
|
|
||||||
async def create(self, data: OrganizationCreate) -> Organization:
|
async def create(self, data: OrganizationCreate) -> Organization:
|
||||||
organization = Organization(name=data.name)
|
organization = Organization(name=data.name)
|
||||||
self._session.add(organization)
|
self._session.add(organization)
|
||||||
|
|
|
||||||
|
|
@ -1 +1,11 @@
|
||||||
"""Business logic services."""
|
"""Business logic services."""
|
||||||
|
|
||||||
|
from .deal_service import DealService # noqa: F401
|
||||||
|
from .organization_service import ( # noqa: F401
|
||||||
|
OrganizationAccessDeniedError,
|
||||||
|
OrganizationContext,
|
||||||
|
OrganizationContextMissingError,
|
||||||
|
OrganizationService,
|
||||||
|
)
|
||||||
|
from .user_service import UserService # noqa: F401
|
||||||
|
from .auth_service import AuthService # noqa: F401
|
||||||
|
|
@ -0,0 +1,164 @@
|
||||||
|
"""Business logic for deals."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
|
||||||
|
from app.models.activity import Activity, ActivityType
|
||||||
|
from app.models.contact import Contact
|
||||||
|
from app.models.deal import Deal, DealCreate, DealStage, DealStatus
|
||||||
|
from app.models.organization_member import OrganizationRole
|
||||||
|
from app.repositories.deal_repo import DealRepository
|
||||||
|
from app.services.organization_service import OrganizationContext
|
||||||
|
|
||||||
|
|
||||||
|
STAGE_ORDER = {
|
||||||
|
stage: index
|
||||||
|
for index, stage in enumerate(
|
||||||
|
[
|
||||||
|
DealStage.QUALIFICATION,
|
||||||
|
DealStage.PROPOSAL,
|
||||||
|
DealStage.NEGOTIATION,
|
||||||
|
DealStage.CLOSED,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DealServiceError(Exception):
|
||||||
|
"""Base class for deal service errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class DealOrganizationMismatchError(DealServiceError):
|
||||||
|
"""Raised when attempting to use resources from another organization."""
|
||||||
|
|
||||||
|
|
||||||
|
class DealStageTransitionError(DealServiceError):
|
||||||
|
"""Raised when stage transition violates business rules."""
|
||||||
|
|
||||||
|
|
||||||
|
class DealStatusValidationError(DealServiceError):
|
||||||
|
"""Raised when invalid status transitions are requested."""
|
||||||
|
|
||||||
|
|
||||||
|
class ContactHasDealsError(DealServiceError):
|
||||||
|
"""Raised when attempting to delete a contact with active deals."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class DealUpdateData:
|
||||||
|
"""Structured container for deal update operations."""
|
||||||
|
|
||||||
|
status: DealStatus | None = None
|
||||||
|
stage: DealStage | None = None
|
||||||
|
amount: Decimal | None = None
|
||||||
|
currency: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DealService:
|
||||||
|
"""Encapsulates deal workflows and validations."""
|
||||||
|
|
||||||
|
def __init__(self, repository: DealRepository) -> None:
|
||||||
|
self._repository = repository
|
||||||
|
|
||||||
|
async def create_deal(self, data: DealCreate, *, context: OrganizationContext) -> Deal:
|
||||||
|
self._ensure_same_organization(data.organization_id, context)
|
||||||
|
await self._ensure_contact_in_organization(data.contact_id, context.organization_id)
|
||||||
|
return await self._repository.create(data=data, role=context.role, user_id=context.user_id)
|
||||||
|
|
||||||
|
async def update_deal(
|
||||||
|
self,
|
||||||
|
deal: Deal,
|
||||||
|
updates: DealUpdateData,
|
||||||
|
*,
|
||||||
|
context: OrganizationContext,
|
||||||
|
) -> Deal:
|
||||||
|
self._ensure_same_organization(deal.organization_id, context)
|
||||||
|
changes: dict[str, object] = {}
|
||||||
|
stage_activity: tuple[ActivityType, dict[str, str]] | None = None
|
||||||
|
status_activity: tuple[ActivityType, dict[str, str]] | None = None
|
||||||
|
|
||||||
|
if updates.amount is not None:
|
||||||
|
changes["amount"] = updates.amount
|
||||||
|
if updates.currency is not None:
|
||||||
|
changes["currency"] = updates.currency
|
||||||
|
|
||||||
|
if updates.stage is not None and updates.stage != deal.stage:
|
||||||
|
self._validate_stage_transition(deal.stage, updates.stage, context.role)
|
||||||
|
changes["stage"] = updates.stage
|
||||||
|
stage_activity = (
|
||||||
|
ActivityType.STAGE_CHANGED,
|
||||||
|
{"old_stage": deal.stage, "new_stage": updates.stage},
|
||||||
|
)
|
||||||
|
|
||||||
|
if updates.status is not None and updates.status != deal.status:
|
||||||
|
self._validate_status_transition(deal, updates)
|
||||||
|
changes["status"] = updates.status
|
||||||
|
status_activity = (
|
||||||
|
ActivityType.STATUS_CHANGED,
|
||||||
|
{"old_status": deal.status, "new_status": updates.status},
|
||||||
|
)
|
||||||
|
|
||||||
|
if not changes:
|
||||||
|
return deal
|
||||||
|
|
||||||
|
updated = await self._repository.update(deal, changes, role=context.role, user_id=context.user_id)
|
||||||
|
await self._log_activities(
|
||||||
|
deal_id=deal.id,
|
||||||
|
author_id=context.user_id,
|
||||||
|
activities=[activity for activity in [stage_activity, status_activity] if activity],
|
||||||
|
)
|
||||||
|
return updated
|
||||||
|
|
||||||
|
async def ensure_contact_can_be_deleted(self, contact_id: int) -> None:
|
||||||
|
stmt = select(func.count()).select_from(Deal).where(Deal.contact_id == contact_id)
|
||||||
|
count = await self._repository.session.scalar(stmt)
|
||||||
|
if count and count > 0:
|
||||||
|
raise ContactHasDealsError("Contact has related deals and cannot be deleted")
|
||||||
|
|
||||||
|
async def _log_activities(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
deal_id: int,
|
||||||
|
author_id: int,
|
||||||
|
activities: Iterable[tuple[ActivityType, dict[str, str]]],
|
||||||
|
) -> None:
|
||||||
|
entries = list(activities)
|
||||||
|
if not entries:
|
||||||
|
return
|
||||||
|
for activity_type, payload in entries:
|
||||||
|
activity = Activity(deal_id=deal_id, author_id=author_id, type=activity_type, payload=payload)
|
||||||
|
self._repository.session.add(activity)
|
||||||
|
await self._repository.session.flush()
|
||||||
|
|
||||||
|
def _ensure_same_organization(self, organization_id: int, context: OrganizationContext) -> None:
|
||||||
|
if organization_id != context.organization_id:
|
||||||
|
raise DealOrganizationMismatchError("Operation targets a different organization")
|
||||||
|
|
||||||
|
async def _ensure_contact_in_organization(self, contact_id: int, organization_id: int) -> Contact:
|
||||||
|
contact = await self._repository.session.get(Contact, contact_id)
|
||||||
|
if contact is None or contact.organization_id != organization_id:
|
||||||
|
raise DealOrganizationMismatchError("Contact belongs to another organization")
|
||||||
|
return contact
|
||||||
|
|
||||||
|
def _validate_stage_transition(
|
||||||
|
self,
|
||||||
|
current_stage: DealStage,
|
||||||
|
new_stage: DealStage,
|
||||||
|
role: OrganizationRole,
|
||||||
|
) -> None:
|
||||||
|
if STAGE_ORDER[new_stage] < STAGE_ORDER[current_stage] and role not in {
|
||||||
|
OrganizationRole.OWNER,
|
||||||
|
OrganizationRole.ADMIN,
|
||||||
|
}:
|
||||||
|
raise DealStageTransitionError("Stage rollback requires owner or admin role")
|
||||||
|
|
||||||
|
def _validate_status_transition(self, deal: Deal, updates: DealUpdateData) -> None:
|
||||||
|
if updates.status != DealStatus.WON:
|
||||||
|
return
|
||||||
|
effective_amount = updates.amount if updates.amount is not None else deal.amount
|
||||||
|
if effective_amount is None or Decimal(effective_amount) <= Decimal("0"):
|
||||||
|
raise DealStatusValidationError("Amount must be greater than zero to mark a deal as won")
|
||||||
|
|
@ -0,0 +1,87 @@
|
||||||
|
"""Organization-related business rules."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from app.models.organization import Organization
|
||||||
|
from app.models.organization_member import OrganizationMember, OrganizationRole
|
||||||
|
from app.repositories.org_repo import OrganizationRepository
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationServiceError(Exception):
|
||||||
|
"""Base class for organization service errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationContextMissingError(OrganizationServiceError):
|
||||||
|
"""Raised when the request lacks organization context."""
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationAccessDeniedError(OrganizationServiceError):
|
||||||
|
"""Raised when a user tries to work with a foreign organization."""
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationForbiddenError(OrganizationServiceError):
|
||||||
|
"""Raised when a user does not have enough privileges."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True, frozen=True)
|
||||||
|
class OrganizationContext:
|
||||||
|
"""Resolved organization and membership information for a request."""
|
||||||
|
|
||||||
|
organization: Organization
|
||||||
|
membership: OrganizationMember
|
||||||
|
|
||||||
|
@property
|
||||||
|
def organization_id(self) -> int:
|
||||||
|
return self.organization.id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def role(self) -> OrganizationRole:
|
||||||
|
return self.membership.role
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_id(self) -> int:
|
||||||
|
return self.membership.user_id
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationService:
|
||||||
|
"""Encapsulates organization-specific policies."""
|
||||||
|
|
||||||
|
def __init__(self, repository: OrganizationRepository) -> None:
|
||||||
|
self._repository = repository
|
||||||
|
|
||||||
|
async def get_context(self, *, user_id: int, organization_id: int | None) -> OrganizationContext:
|
||||||
|
"""Resolve request context ensuring the user belongs to the given organization."""
|
||||||
|
|
||||||
|
if organization_id is None:
|
||||||
|
raise OrganizationContextMissingError("X-Organization-Id header is required")
|
||||||
|
|
||||||
|
membership = await self._repository.get_membership(organization_id, user_id)
|
||||||
|
if membership is None or membership.organization is None:
|
||||||
|
raise OrganizationAccessDeniedError("Organization not found")
|
||||||
|
|
||||||
|
return OrganizationContext(organization=membership.organization, membership=membership)
|
||||||
|
|
||||||
|
def ensure_entity_in_context(self, *, entity_organization_id: int, context: OrganizationContext) -> None:
|
||||||
|
"""Make sure a resource belongs to the current organization."""
|
||||||
|
|
||||||
|
if entity_organization_id != context.organization_id:
|
||||||
|
raise OrganizationAccessDeniedError("Resource belongs to another organization")
|
||||||
|
|
||||||
|
def ensure_can_manage_settings(self, context: OrganizationContext) -> None:
|
||||||
|
"""Allow only owner/admin to change organization-level settings."""
|
||||||
|
|
||||||
|
if context.role not in {OrganizationRole.OWNER, OrganizationRole.ADMIN}:
|
||||||
|
raise OrganizationForbiddenError("Only owner/admin can modify organization settings")
|
||||||
|
|
||||||
|
def ensure_can_manage_entity(self, context: OrganizationContext) -> None:
|
||||||
|
"""Managers/admins/owners may manage entities; members are restricted."""
|
||||||
|
|
||||||
|
if context.role == OrganizationRole.MEMBER:
|
||||||
|
raise OrganizationForbiddenError("Members cannot manage shared entities")
|
||||||
|
|
||||||
|
def ensure_member_owns_entity(self, *, context: OrganizationContext, owner_id: int) -> None:
|
||||||
|
"""Members can only mutate entities they own (contacts/deals/tasks)."""
|
||||||
|
|
||||||
|
if context.role == OrganizationRole.MEMBER and owner_id != context.user_id:
|
||||||
|
raise OrganizationForbiddenError("Members can only modify their own records")
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""Add stage_changed activity type."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "20251127_0002_stage_changed"
|
||||||
|
down_revision: str | None = "20251122_0001"
|
||||||
|
branch_labels: tuple[str, ...] | None = None
|
||||||
|
depends_on: tuple[str, ...] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute("ALTER TYPE activity_type ADD VALUE IF NOT EXISTS 'stage_changed';")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("UPDATE activities SET type = 'status_changed' WHERE type = 'stage_changed';")
|
||||||
|
op.execute("ALTER TYPE activity_type RENAME TO activity_type_old;")
|
||||||
|
op.execute(
|
||||||
|
"CREATE TYPE activity_type AS ENUM ('comment','status_changed','task_created','system');"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE activities ALTER COLUMN type TYPE activity_type USING type::text::activity_type;"
|
||||||
|
)
|
||||||
|
op.execute("DROP TYPE activity_type_old;")
|
||||||
|
|
@ -0,0 +1,244 @@
|
||||||
|
"""Unit tests for DealService."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from decimal import Decimal
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest # type: ignore[import-not-found]
|
||||||
|
import pytest_asyncio # type: ignore[import-not-found]
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.models.activity import Activity, ActivityType
|
||||||
|
from app.models.base import Base
|
||||||
|
from app.models.contact import Contact
|
||||||
|
from app.models.deal import DealCreate, DealStage, DealStatus
|
||||||
|
from app.models.organization import Organization
|
||||||
|
from app.models.organization_member import OrganizationMember, OrganizationRole
|
||||||
|
from app.models.user import User
|
||||||
|
from app.repositories.deal_repo import DealRepository
|
||||||
|
from app.services.deal_service import (
|
||||||
|
ContactHasDealsError,
|
||||||
|
DealOrganizationMismatchError,
|
||||||
|
DealService,
|
||||||
|
DealStageTransitionError,
|
||||||
|
DealStatusValidationError,
|
||||||
|
DealUpdateData,
|
||||||
|
)
|
||||||
|
from app.services.organization_service import OrganizationContext
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture()
|
||||||
|
async def session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
engine = create_async_engine(
|
||||||
|
"sqlite+aiosqlite:///:memory:",
|
||||||
|
future=True,
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
async with async_session() as session:
|
||||||
|
yield session
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_organization(name: str) -> Organization:
|
||||||
|
org = Organization(name=name)
|
||||||
|
return org
|
||||||
|
|
||||||
|
|
||||||
|
def _make_user(email_suffix: str) -> User:
|
||||||
|
return User(
|
||||||
|
email=f"user-{email_suffix}@example.com",
|
||||||
|
hashed_password="hashed",
|
||||||
|
name="Test User",
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_context(org: Organization, user: User, role: OrganizationRole) -> OrganizationContext:
|
||||||
|
membership = OrganizationMember(organization_id=org.id, user_id=user.id, role=role)
|
||||||
|
return OrganizationContext(organization=org, membership=membership)
|
||||||
|
|
||||||
|
|
||||||
|
async def _persist_base(session: AsyncSession, *, role: OrganizationRole = OrganizationRole.MANAGER) -> tuple[
|
||||||
|
OrganizationContext,
|
||||||
|
Contact,
|
||||||
|
DealRepository,
|
||||||
|
]:
|
||||||
|
org = _make_organization(name=f"Org-{uuid.uuid4()}"[:8])
|
||||||
|
user = _make_user(email_suffix=str(uuid.uuid4())[:8])
|
||||||
|
session.add_all([org, user])
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
contact = Contact(
|
||||||
|
organization_id=org.id,
|
||||||
|
owner_id=user.id,
|
||||||
|
name="John Doe",
|
||||||
|
email="john@example.com",
|
||||||
|
)
|
||||||
|
session.add(contact)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
context = _make_context(org, user, role)
|
||||||
|
repo = DealRepository(session=session)
|
||||||
|
return context, contact, repo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_deal_rejects_foreign_contact(session: AsyncSession) -> None:
|
||||||
|
context, contact, repo = await _persist_base(session)
|
||||||
|
|
||||||
|
other_org = _make_organization(name="Other")
|
||||||
|
other_user = _make_user(email_suffix="other")
|
||||||
|
session.add_all([other_org, other_user])
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
service = DealService(repository=repo)
|
||||||
|
payload = DealCreate(
|
||||||
|
organization_id=other_org.id,
|
||||||
|
contact_id=contact.id,
|
||||||
|
owner_id=context.user_id,
|
||||||
|
title="Website Redesign",
|
||||||
|
amount=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
other_context = _make_context(other_org, other_user, OrganizationRole.MANAGER)
|
||||||
|
|
||||||
|
with pytest.raises(DealOrganizationMismatchError):
|
||||||
|
await service.create_deal(payload, context=other_context)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stage_rollback_requires_admin(session: AsyncSession) -> None:
|
||||||
|
context, contact, repo = await _persist_base(session, role=OrganizationRole.MANAGER)
|
||||||
|
service = DealService(repository=repo)
|
||||||
|
|
||||||
|
deal = await service.create_deal(
|
||||||
|
DealCreate(
|
||||||
|
organization_id=context.organization_id,
|
||||||
|
contact_id=contact.id,
|
||||||
|
owner_id=context.user_id,
|
||||||
|
title="Migration",
|
||||||
|
amount=Decimal("5000"),
|
||||||
|
),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
deal.stage = DealStage.PROPOSAL
|
||||||
|
|
||||||
|
with pytest.raises(DealStageTransitionError):
|
||||||
|
await service.update_deal(
|
||||||
|
deal,
|
||||||
|
DealUpdateData(stage=DealStage.QUALIFICATION),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stage_rollback_allowed_for_admin(session: AsyncSession) -> None:
|
||||||
|
context, contact, repo = await _persist_base(session, role=OrganizationRole.ADMIN)
|
||||||
|
service = DealService(repository=repo)
|
||||||
|
|
||||||
|
deal = await service.create_deal(
|
||||||
|
DealCreate(
|
||||||
|
organization_id=context.organization_id,
|
||||||
|
contact_id=contact.id,
|
||||||
|
owner_id=context.user_id,
|
||||||
|
title="Rollout",
|
||||||
|
amount=Decimal("1000"),
|
||||||
|
),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
deal.stage = DealStage.NEGOTIATION
|
||||||
|
|
||||||
|
updated = await service.update_deal(
|
||||||
|
deal,
|
||||||
|
DealUpdateData(stage=DealStage.PROPOSAL),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert updated.stage == DealStage.PROPOSAL
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_status_won_requires_positive_amount(session: AsyncSession) -> None:
|
||||||
|
context, contact, repo = await _persist_base(session)
|
||||||
|
service = DealService(repository=repo)
|
||||||
|
|
||||||
|
deal = await service.create_deal(
|
||||||
|
DealCreate(
|
||||||
|
organization_id=context.organization_id,
|
||||||
|
contact_id=contact.id,
|
||||||
|
owner_id=context.user_id,
|
||||||
|
title="Zero",
|
||||||
|
amount=None,
|
||||||
|
),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(DealStatusValidationError):
|
||||||
|
await service.update_deal(
|
||||||
|
deal,
|
||||||
|
DealUpdateData(status=DealStatus.WON),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_updates_create_activity_records(session: AsyncSession) -> None:
|
||||||
|
context, contact, repo = await _persist_base(session)
|
||||||
|
service = DealService(repository=repo)
|
||||||
|
|
||||||
|
deal = await service.create_deal(
|
||||||
|
DealCreate(
|
||||||
|
organization_id=context.organization_id,
|
||||||
|
contact_id=contact.id,
|
||||||
|
owner_id=context.user_id,
|
||||||
|
title="Activity",
|
||||||
|
amount=Decimal("100"),
|
||||||
|
),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
|
||||||
|
await service.update_deal(
|
||||||
|
deal,
|
||||||
|
DealUpdateData(
|
||||||
|
stage=DealStage.PROPOSAL,
|
||||||
|
status=DealStatus.WON,
|
||||||
|
amount=Decimal("5000"),
|
||||||
|
),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await session.scalars(select(Activity).where(Activity.deal_id == deal.id))
|
||||||
|
activity_types = {activity.type for activity in result.all()}
|
||||||
|
assert ActivityType.STAGE_CHANGED in activity_types
|
||||||
|
assert ActivityType.STATUS_CHANGED in activity_types
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_contact_delete_guard(session: AsyncSession) -> None:
|
||||||
|
context, contact, repo = await _persist_base(session)
|
||||||
|
service = DealService(repository=repo)
|
||||||
|
|
||||||
|
deal = await service.create_deal(
|
||||||
|
DealCreate(
|
||||||
|
organization_id=context.organization_id,
|
||||||
|
contact_id=contact.id,
|
||||||
|
owner_id=context.user_id,
|
||||||
|
title="To Delete",
|
||||||
|
amount=Decimal("100"),
|
||||||
|
),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ContactHasDealsError):
|
||||||
|
await service.ensure_contact_can_be_deleted(contact.id)
|
||||||
|
|
||||||
|
await session.delete(deal)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
await service.ensure_contact_can_be_deleted(contact.id)
|
||||||
|
|
@ -0,0 +1,99 @@
|
||||||
|
"""Unit tests for OrganizationService."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest # type: ignore[import-not-found]
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.organization import Organization
|
||||||
|
from app.models.organization_member import OrganizationMember, OrganizationRole
|
||||||
|
from app.repositories.org_repo import OrganizationRepository
|
||||||
|
from app.services.organization_service import (
|
||||||
|
OrganizationAccessDeniedError,
|
||||||
|
OrganizationContext,
|
||||||
|
OrganizationContextMissingError,
|
||||||
|
OrganizationForbiddenError,
|
||||||
|
OrganizationService,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StubOrganizationRepository(OrganizationRepository):
|
||||||
|
"""Simple in-memory stand-in for OrganizationRepository."""
|
||||||
|
|
||||||
|
def __init__(self, membership: OrganizationMember | None) -> None:
|
||||||
|
super().__init__(session=MagicMock(spec=AsyncSession))
|
||||||
|
self._membership = membership
|
||||||
|
|
||||||
|
async def get_membership(self, organization_id: int, user_id: int) -> OrganizationMember | None: # pragma: no cover - helper
|
||||||
|
if (
|
||||||
|
self._membership
|
||||||
|
and self._membership.organization_id == organization_id
|
||||||
|
and self._membership.user_id == user_id
|
||||||
|
):
|
||||||
|
return self._membership
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def make_membership(role: OrganizationRole, *, organization_id: int = 1, user_id: int = 10) -> OrganizationMember:
|
||||||
|
organization = Organization(name="Acme Inc")
|
||||||
|
organization.id = organization_id
|
||||||
|
membership = OrganizationMember(
|
||||||
|
organization_id=organization_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role=role,
|
||||||
|
)
|
||||||
|
membership.organization = organization
|
||||||
|
return membership
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_context_success() -> None:
|
||||||
|
membership = make_membership(OrganizationRole.MANAGER)
|
||||||
|
service = OrganizationService(StubOrganizationRepository(membership))
|
||||||
|
|
||||||
|
context = await service.get_context(user_id=membership.user_id, organization_id=membership.organization_id)
|
||||||
|
|
||||||
|
assert context.organization_id == membership.organization_id
|
||||||
|
assert context.role == OrganizationRole.MANAGER
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_context_missing_header() -> None:
|
||||||
|
service = OrganizationService(StubOrganizationRepository(None))
|
||||||
|
|
||||||
|
with pytest.raises(OrganizationContextMissingError):
|
||||||
|
await service.get_context(user_id=1, organization_id=None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_context_access_denied() -> None:
|
||||||
|
service = OrganizationService(StubOrganizationRepository(None))
|
||||||
|
|
||||||
|
with pytest.raises(OrganizationAccessDeniedError):
|
||||||
|
await service.get_context(user_id=1, organization_id=99)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_can_manage_settings_blocks_manager() -> None:
|
||||||
|
membership = make_membership(OrganizationRole.MANAGER)
|
||||||
|
organization = membership.organization
|
||||||
|
assert organization is not None
|
||||||
|
context = OrganizationContext(organization=organization, membership=membership)
|
||||||
|
service = OrganizationService(StubOrganizationRepository(membership))
|
||||||
|
|
||||||
|
with pytest.raises(OrganizationForbiddenError):
|
||||||
|
service.ensure_can_manage_settings(context)
|
||||||
|
|
||||||
|
|
||||||
|
def test_member_must_own_entity() -> None:
|
||||||
|
membership = make_membership(OrganizationRole.MEMBER)
|
||||||
|
organization = membership.organization
|
||||||
|
assert organization is not None
|
||||||
|
context = OrganizationContext(organization=organization, membership=membership)
|
||||||
|
service = OrganizationService(StubOrganizationRepository(membership))
|
||||||
|
|
||||||
|
with pytest.raises(OrganizationForbiddenError):
|
||||||
|
service.ensure_member_owns_entity(context=context, owner_id=999)
|
||||||
|
|
||||||
|
# Same owner should pass silently.
|
||||||
|
service.ensure_member_owns_entity(context=context, owner_id=membership.user_id)
|
||||||
Loading…
Reference in New Issue