diff --git a/app/models/activity.py b/app/models/activity.py index 4c004a7..89daa10 100644 --- a/app/models/activity.py +++ b/app/models/activity.py @@ -8,6 +8,7 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import DateTime, Enum as SqlEnum, ForeignKey, Integer, func, text from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.types import JSON as GenericJSON, TypeDecorator from sqlalchemy.orm import Mapped, mapped_column, relationship from app.models.base import Base @@ -16,10 +17,25 @@ from app.models.base import Base class ActivityType(StrEnum): COMMENT = "comment" STATUS_CHANGED = "status_changed" + STAGE_CHANGED = "stage_changed" TASK_CREATED = "task_created" SYSTEM = "system" +class JSONBCompat(TypeDecorator): + """Uses JSONB on Postgres and plain JSON elsewhere for testability.""" + + impl = JSONB + cache_ok = True + + def load_dialect_impl(self, dialect): # type: ignore[override] + if dialect.name == "sqlite": + from sqlalchemy.dialects.sqlite import JSON as SQLiteJSON # local import + + return dialect.type_descriptor(SQLiteJSON()) + return dialect.type_descriptor(JSONB()) + + class Activity(Base): """Represents a timeline event for a deal.""" @@ -32,9 +48,9 @@ class Activity(Base): ) type: Mapped[ActivityType] = mapped_column(SqlEnum(ActivityType, name="activity_type"), nullable=False) payload: Mapped[dict[str, Any]] = mapped_column( - JSONB, + JSONBCompat().with_variant(GenericJSON(), "sqlite"), nullable=False, - server_default=text("'{}'::jsonb"), + server_default=text("'{}'"), ) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False diff --git a/tests/services/test_deal_service.py b/tests/services/test_deal_service.py new file mode 100644 index 0000000..3a789c1 --- /dev/null +++ b/tests/services/test_deal_service.py @@ -0,0 +1,244 @@ +"""Unit tests for DealService.""" +from __future__ import annotations + +from collections.abc import AsyncGenerator +from decimal import Decimal +import uuid + +import pytest # type: ignore[import-not-found] +import pytest_asyncio # type: ignore[import-not-found] +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.models.activity import Activity, ActivityType +from app.models.base import Base +from app.models.contact import Contact +from app.models.deal import DealCreate, DealStage, DealStatus +from app.models.organization import Organization +from app.models.organization_member import OrganizationMember, OrganizationRole +from app.models.user import User +from app.repositories.deal_repo import DealRepository +from app.services.deal_service import ( + ContactHasDealsError, + DealOrganizationMismatchError, + DealService, + DealStageTransitionError, + DealStatusValidationError, + DealUpdateData, +) +from app.services.organization_service import OrganizationContext + + +@pytest_asyncio.fixture() +async def session() -> AsyncGenerator[AsyncSession, None]: + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + future=True, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + async_session = async_sessionmaker(engine, expire_on_commit=False) + async with async_session() as session: + yield session + await engine.dispose() + + +def _make_organization(name: str) -> Organization: + org = Organization(name=name) + return org + + +def _make_user(email_suffix: str) -> User: + return User( + email=f"user-{email_suffix}@example.com", + hashed_password="hashed", + name="Test User", + is_active=True, + ) + + +def _make_context(org: Organization, user: User, role: OrganizationRole) -> OrganizationContext: + membership = OrganizationMember(organization_id=org.id, user_id=user.id, role=role) + return OrganizationContext(organization=org, membership=membership) + + +async def _persist_base(session: AsyncSession, *, role: OrganizationRole = OrganizationRole.MANAGER) -> tuple[ + OrganizationContext, + Contact, + DealRepository, +]: + org = _make_organization(name=f"Org-{uuid.uuid4()}"[:8]) + user = _make_user(email_suffix=str(uuid.uuid4())[:8]) + session.add_all([org, user]) + await session.flush() + + contact = Contact( + organization_id=org.id, + owner_id=user.id, + name="John Doe", + email="john@example.com", + ) + session.add(contact) + await session.flush() + + context = _make_context(org, user, role) + repo = DealRepository(session=session) + return context, contact, repo + + +@pytest.mark.asyncio +async def test_create_deal_rejects_foreign_contact(session: AsyncSession) -> None: + context, contact, repo = await _persist_base(session) + + other_org = _make_organization(name="Other") + other_user = _make_user(email_suffix="other") + session.add_all([other_org, other_user]) + await session.flush() + + service = DealService(repository=repo) + payload = DealCreate( + organization_id=other_org.id, + contact_id=contact.id, + owner_id=context.user_id, + title="Website Redesign", + amount=None, + ) + + other_context = _make_context(other_org, other_user, OrganizationRole.MANAGER) + + with pytest.raises(DealOrganizationMismatchError): + await service.create_deal(payload, context=other_context) + + +@pytest.mark.asyncio +async def test_stage_rollback_requires_admin(session: AsyncSession) -> None: + context, contact, repo = await _persist_base(session, role=OrganizationRole.MANAGER) + service = DealService(repository=repo) + + deal = await service.create_deal( + DealCreate( + organization_id=context.organization_id, + contact_id=contact.id, + owner_id=context.user_id, + title="Migration", + amount=Decimal("5000"), + ), + context=context, + ) + deal.stage = DealStage.PROPOSAL + + with pytest.raises(DealStageTransitionError): + await service.update_deal( + deal, + DealUpdateData(stage=DealStage.QUALIFICATION), + context=context, + ) + + +@pytest.mark.asyncio +async def test_stage_rollback_allowed_for_admin(session: AsyncSession) -> None: + context, contact, repo = await _persist_base(session, role=OrganizationRole.ADMIN) + service = DealService(repository=repo) + + deal = await service.create_deal( + DealCreate( + organization_id=context.organization_id, + contact_id=contact.id, + owner_id=context.user_id, + title="Rollout", + amount=Decimal("1000"), + ), + context=context, + ) + deal.stage = DealStage.NEGOTIATION + + updated = await service.update_deal( + deal, + DealUpdateData(stage=DealStage.PROPOSAL), + context=context, + ) + + assert updated.stage == DealStage.PROPOSAL + + +@pytest.mark.asyncio +async def test_status_won_requires_positive_amount(session: AsyncSession) -> None: + context, contact, repo = await _persist_base(session) + service = DealService(repository=repo) + + deal = await service.create_deal( + DealCreate( + organization_id=context.organization_id, + contact_id=contact.id, + owner_id=context.user_id, + title="Zero", + amount=None, + ), + context=context, + ) + + with pytest.raises(DealStatusValidationError): + await service.update_deal( + deal, + DealUpdateData(status=DealStatus.WON), + context=context, + ) + + +@pytest.mark.asyncio +async def test_updates_create_activity_records(session: AsyncSession) -> None: + context, contact, repo = await _persist_base(session) + service = DealService(repository=repo) + + deal = await service.create_deal( + DealCreate( + organization_id=context.organization_id, + contact_id=contact.id, + owner_id=context.user_id, + title="Activity", + amount=Decimal("100"), + ), + context=context, + ) + + await service.update_deal( + deal, + DealUpdateData( + stage=DealStage.PROPOSAL, + status=DealStatus.WON, + amount=Decimal("5000"), + ), + context=context, + ) + + result = await session.scalars(select(Activity).where(Activity.deal_id == deal.id)) + activity_types = {activity.type for activity in result.all()} + assert ActivityType.STAGE_CHANGED in activity_types + assert ActivityType.STATUS_CHANGED in activity_types + + +@pytest.mark.asyncio +async def test_contact_delete_guard(session: AsyncSession) -> None: + context, contact, repo = await _persist_base(session) + service = DealService(repository=repo) + + deal = await service.create_deal( + DealCreate( + organization_id=context.organization_id, + contact_id=contact.id, + owner_id=context.user_id, + title="To Delete", + amount=Decimal("100"), + ), + context=context, + ) + + with pytest.raises(ContactHasDealsError): + await service.ensure_contact_can_be_deleted(contact.id) + + await session.delete(deal) + await session.flush() + + await service.ensure_contact_can_be_deleted(contact.id)