Refactor code for improved readability and consistency
Test / test (push) Successful in 15s Details

- Reformatted function signatures in `organization_service.py` and `task_service.py` for better alignment.
- Updated import statements across multiple files for consistency and organization.
- Enhanced test files by improving formatting and ensuring consistent use of async session factories.
- Added type hints and improved type safety in various service and test files.
- Adjusted `pyproject.toml` to include configuration for isort, mypy, and ruff for better code quality checks.
- Cleaned up unused imports and organized existing ones in several test files.
This commit is contained in:
Artem Kashaev 2025-12-01 16:18:03 +05:00
parent eecb74c523
commit 5fcb574aca
62 changed files with 765 additions and 476 deletions

View File

@ -4,7 +4,7 @@
## Стек и особенности ## Стек и особенности
- Python 3.10+, FastAPI, SQLAlchemy Async ORM, Alembic. - Python 3.14, FastAPI, SQLAlchemy Async ORM, Alembic.
- Pydantic Settings для конфигурации, JWT access/refresh токены, кеш аналитики в Redis. - Pydantic Settings для конфигурации, JWT access/refresh токены, кеш аналитики в Redis.
- Frontend: Vite + React + TypeScript (см. `frontend/`). - Frontend: Vite + React + TypeScript (см. `frontend/`).
- Докер-окружение для разработки (`docker-compose-dev.yml`) и деплоя (`docker-compose-ci.yml`). - Докер-окружение для разработки (`docker-compose-dev.yml`) и деплоя (`docker-compose-ci.yml`).
@ -80,10 +80,10 @@ cp .env.example .env
uv sync uv sync
# 3. Применяем миграции # 3. Применяем миграции
uv run alembic upgrade head uvx alembic upgrade head
# 4. Запускаем API # 4. Запускаем API
uv run uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 uvx uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
``` ```
PostgreSQL/Redis можно поднять вручную или командой `docker compose -f docker-compose-dev.yml up postgres redis -d`. PostgreSQL/Redis можно поднять вручную или командой `docker compose -f docker-compose-dev.yml up postgres redis -d`.
@ -147,14 +147,23 @@ docker compose -f docker-compose-dev.yml up --build -d
Все тесты находятся в каталоге `tests/` (unit на бизнес-правила и интеграционные сценарии API). Запуск: Все тесты находятся в каталоге `tests/` (unit на бизнес-правила и интеграционные сценарии API). Запуск:
```bash ```bash
uv run pytest uvx pytest
``` ```
Полезные варианты: Полезные варианты:
- Запустить только юнит-тесты сервисов: `uv run pytest tests/services -k service`. - Запустить только юнит-тесты сервисов: `uvx pytest tests/services -k service`.
- Запустить конкретный сценарий API: `uv run pytest tests/api/v1/test_deals.py -k won`. - Запустить конкретный сценарий API: `uvx pytest tests/api/v1/test_deals.py -k won`.
Перед деплоем рекомендуется прогонять миграции на чистой БД и выполнять `uv run pytest` для проверки правил ролей/стадий. Перед деплоем рекомендуется прогонять миграции на чистой БД и выполнять `uvx pytest` для проверки правил ролей/стадий.
## Линтинг и статический анализ
- `uvx ruff check app tests` — основной линтер (PEP8, сортировка импортов, дополнительные правила).
- `uvx ruff format app tests` — автоформатирование (аналог black) для единообразного стиля.
- `uvx isort .` — отдельная сортировка импортов (профиль `black`).
- `uvx mypy app services tests` — статическая проверка типов (строгий режим + плагин pydantic).
В CI/PR рекомендуется запускать команды именно в этом порядке, чтобы быстрее находить проблемы.

View File

@ -1,9 +1,11 @@
"""Reusable FastAPI dependencies.""" """Reusable FastAPI dependencies."""
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
import jwt import jwt
from fastapi import Depends, Header, HTTPException, status from fastapi import Depends, Header, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from redis.asyncio.client import Redis
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.cache import get_cache_client from app.core.cache import get_cache_client
@ -18,9 +20,9 @@ from app.repositories.deal_repo import DealRepository
from app.repositories.org_repo import OrganizationRepository from app.repositories.org_repo import OrganizationRepository
from app.repositories.task_repo import TaskRepository from app.repositories.task_repo import TaskRepository
from app.repositories.user_repo import UserRepository from app.repositories.user_repo import UserRepository
from app.services.activity_service import ActivityService
from app.services.analytics_service import AnalyticsService from app.services.analytics_service import AnalyticsService
from app.services.auth_service import AuthService from app.services.auth_service import AuthService
from app.services.activity_service import ActivityService
from app.services.contact_service import ContactService from app.services.contact_service import ContactService
from app.services.deal_service import DealService from app.services.deal_service import DealService
from app.services.organization_service import ( from app.services.organization_service import (
@ -30,7 +32,6 @@ from app.services.organization_service import (
OrganizationService, OrganizationService,
) )
from app.services.task_service import TaskService from app.services.task_service import TaskService
from redis.asyncio.client import Redis
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/token")
@ -45,7 +46,9 @@ def get_user_repository(session: AsyncSession = Depends(get_db_session)) -> User
return UserRepository(session=session) return UserRepository(session=session)
def get_organization_repository(session: AsyncSession = Depends(get_db_session)) -> OrganizationRepository: def get_organization_repository(
session: AsyncSession = Depends(get_db_session),
) -> OrganizationRepository:
return OrganizationRepository(session=session) return OrganizationRepository(session=session)
@ -65,7 +68,9 @@ def get_activity_repository(session: AsyncSession = Depends(get_db_session)) ->
return ActivityRepository(session=session) return ActivityRepository(session=session)
def get_analytics_repository(session: AsyncSession = Depends(get_db_session)) -> AnalyticsRepository: def get_analytics_repository(
session: AsyncSession = Depends(get_db_session),
) -> AnalyticsRepository:
return AnalyticsRepository(session=session) return AnalyticsRepository(session=session)

View File

@ -1,14 +1,15 @@
"""Root API router that aggregates versioned routers.""" """Root API router that aggregates versioned routers."""
from fastapi import APIRouter from fastapi import APIRouter
from app.api.v1 import ( from app.api.v1 import (
activities, activities,
analytics, analytics,
auth, auth,
contacts, contacts,
deals, deals,
organizations, organizations,
tasks, tasks,
) )
from app.core.config import settings from app.core.config import settings

View File

@ -1,20 +1,21 @@
"""Version 1 API routers.""" """Version 1 API routers."""
from . import ( from . import (
activities, activities,
analytics, analytics,
auth, auth,
contacts, contacts,
deals, deals,
organizations, organizations,
tasks, tasks,
) )
__all__ = [ __all__ = [
"activities", "activities",
"analytics", "analytics",
"auth", "auth",
"contacts", "contacts",
"deals", "deals",
"organizations", "organizations",
"tasks", "tasks",
] ]

View File

@ -1,4 +1,5 @@
"""Activity timeline endpoints and payload schemas.""" """Activity timeline endpoints and payload schemas."""
from __future__ import annotations from __future__ import annotations
from typing import Literal from typing import Literal

View File

@ -1,4 +1,5 @@
"""Analytics API endpoints for summaries and funnels.""" """Analytics API endpoints for summaries and funnels."""
from __future__ import annotations from __future__ import annotations
from decimal import Decimal from decimal import Decimal
@ -16,6 +17,7 @@ def _decimal_to_str(value: Decimal) -> str:
normalized = value.normalize() normalized = value.normalize()
return format(normalized, "f") return format(normalized, "f")
router = APIRouter(prefix="/analytics", tags=["analytics"]) router = APIRouter(prefix="/analytics", tags=["analytics"])
@ -92,4 +94,6 @@ async def deals_funnel(
"""Return funnel breakdown by stages and statuses.""" """Return funnel breakdown by stages and statuses."""
breakdowns: list[StageBreakdown] = await service.get_deal_funnel(context.organization_id) breakdowns: list[StageBreakdown] = await service.get_deal_funnel(context.organization_id)
return DealFunnelResponse(stages=[StageBreakdownModel.model_validate(item) for item in breakdowns]) return DealFunnelResponse(
stages=[StageBreakdownModel.model_validate(item) for item in breakdowns]
)

View File

@ -1,8 +1,9 @@
"""Authentication API endpoints and payloads.""" """Authentication API endpoints and payloads."""
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel, EmailStr
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, EmailStr
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@ -41,7 +42,7 @@ async def register_user(
organization: Organization | None = None organization: Organization | None = None
if payload.organization_name: if payload.organization_name:
existing_org = await repo.session.scalar( existing_org = await repo.session.scalar(
select(Organization).where(Organization.name == payload.organization_name) select(Organization).where(Organization.name == payload.organization_name),
) )
if existing_org is not None: if existing_org is not None:
raise HTTPException( raise HTTPException(

View File

@ -1,4 +1,5 @@
"""Contact API endpoints.""" """Contact API endpoints."""
from __future__ import annotations from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, status
@ -81,7 +82,9 @@ async def create_contact(
context: OrganizationContext = Depends(get_organization_context), context: OrganizationContext = Depends(get_organization_context),
service: ContactService = Depends(get_contact_service), service: ContactService = Depends(get_contact_service),
) -> ContactRead: ) -> ContactRead:
data = payload.to_domain(organization_id=context.organization_id, fallback_owner=context.user_id) data = payload.to_domain(
organization_id=context.organization_id, fallback_owner=context.user_id
)
try: try:
contact = await service.create_contact(data, context=context) contact = await service.create_contact(data, context=context)
except ContactForbiddenError as exc: except ContactForbiddenError as exc:

View File

@ -1,4 +1,5 @@
"""Deal API endpoints backed by DealService with inline payload schemas.""" """Deal API endpoints backed by DealService with inline payload schemas."""
from __future__ import annotations from __future__ import annotations
from decimal import Decimal from decimal import Decimal
@ -8,7 +9,7 @@ from pydantic import BaseModel
from app.api.deps import get_deal_repository, get_deal_service, get_organization_context from app.api.deps import get_deal_repository, get_deal_service, get_organization_context
from app.models.deal import DealCreate, DealRead, DealStage, DealStatus from app.models.deal import DealCreate, DealRead, DealStage, DealStatus
from app.repositories.deal_repo import DealRepository, DealAccessError, DealQueryParams from app.repositories.deal_repo import DealAccessError, DealQueryParams, DealRepository
from app.services.deal_service import ( from app.services.deal_service import (
DealService, DealService,
DealStageTransitionError, DealStageTransitionError,
@ -66,7 +67,9 @@ async def list_deals(
statuses_value = [DealStatus(value) for value in status_filter] if status_filter else None statuses_value = [DealStatus(value) for value in status_filter] if status_filter else None
stage_value = DealStage(stage) if stage else None stage_value = DealStage(stage) if stage else None
except ValueError as exc: except ValueError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid deal filter") from exc raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid deal filter"
) from exc
params = DealQueryParams( params = DealQueryParams(
organization_id=context.organization_id, organization_id=context.organization_id,
@ -96,7 +99,9 @@ async def create_deal(
) -> DealRead: ) -> DealRead:
"""Create a new deal within the current organization.""" """Create a new deal within the current organization."""
data = payload.to_domain(organization_id=context.organization_id, fallback_owner=context.user_id) data = payload.to_domain(
organization_id=context.organization_id, fallback_owner=context.user_id
)
try: try:
deal = await service.create_deal(data, context=context) deal = await service.create_deal(data, context=context)
except DealAccessError as exc: except DealAccessError as exc:

View File

@ -1,4 +1,5 @@
"""Organization-related API endpoints.""" """Organization-related API endpoints."""
from __future__ import annotations from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status

View File

@ -1,4 +1,5 @@
"""Task API endpoints with inline schemas.""" """Task API endpoints with inline schemas."""
from __future__ import annotations from __future__ import annotations
from datetime import date, datetime, time, timezone from datetime import date, datetime, time, timezone

View File

@ -1,17 +1,18 @@
"""Redis cache utilities and availability tracking.""" """Redis cache utilities and availability tracking."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
from typing import Any, Awaitable, Callable, Optional from collections.abc import Awaitable, Callable
from typing import Any
import redis.asyncio as redis import redis.asyncio as redis
from app.core.config import settings
from redis.asyncio.client import Redis from redis.asyncio.client import Redis
from redis.exceptions import RedisError from redis.exceptions import RedisError
from app.core.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,7 +45,9 @@ class RedisCacheManager:
async with self._lock: async with self._lock:
if self._client is not None: if self._client is not None:
return return
self._client = redis.from_url(settings.redis_url, encoding="utf-8", decode_responses=False) self._client = redis.from_url(
settings.redis_url, encoding="utf-8", decode_responses=False
)
await self._refresh_availability() await self._refresh_availability()
async def shutdown(self) -> None: async def shutdown(self) -> None:
@ -59,7 +62,9 @@ class RedisCacheManager:
return return
async with self._lock: async with self._lock:
if self._client is None: if self._client is None:
self._client = redis.from_url(settings.redis_url, encoding="utf-8", decode_responses=False) self._client = redis.from_url(
settings.redis_url, encoding="utf-8", decode_responses=False
)
await self._refresh_availability() await self._refresh_availability()
async def _refresh_availability(self) -> None: async def _refresh_availability(self) -> None:
@ -95,7 +100,7 @@ async def shutdown_cache() -> None:
await cache_manager.shutdown() await cache_manager.shutdown()
def get_cache_client() -> Optional[Redis]: def get_cache_client() -> Redis | None:
"""Expose the active Redis client for dependency injection.""" """Expose the active Redis client for dependency injection."""
return cache_manager.get_client() return cache_manager.get_client()
@ -113,12 +118,17 @@ async def read_json(client: Redis, key: str) -> Any | None:
cache_manager.mark_available() cache_manager.mark_available()
try: try:
return json.loads(raw.decode("utf-8")) return json.loads(raw.decode("utf-8"))
except (UnicodeDecodeError, json.JSONDecodeError) as exc: # pragma: no cover - malformed payloads except (
UnicodeDecodeError,
json.JSONDecodeError,
) as exc: # pragma: no cover - malformed payloads
logger.warning("Discarding malformed cache entry %s: %s", key, exc) logger.warning("Discarding malformed cache entry %s: %s", key, exc)
return None return None
async def write_json(client: Redis, key: str, value: Any, ttl_seconds: int, backoff_ms: int) -> None: async def write_json(
client: Redis, key: str, value: Any, ttl_seconds: int, backoff_ms: int
) -> None:
"""Serialize data to JSON and store it with TTL using retry/backoff.""" """Serialize data to JSON and store it with TTL using retry/backoff."""
payload = json.dumps(value, separators=(",", ":"), ensure_ascii=True).encode("utf-8") payload = json.dumps(value, separators=(",", ":"), ensure_ascii=True).encode("utf-8")

View File

@ -1,4 +1,5 @@
"""Application settings using Pydantic Settings.""" """Application settings using Pydantic Settings."""
from pydantic import Field, SecretStr from pydantic import Field, SecretStr
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
@ -15,7 +16,9 @@ class Settings(BaseSettings):
db_port: int = Field(default=5432, description="Database port") db_port: int = Field(default=5432, description="Database port")
db_name: str = Field(default="test_task_crm", description="Database name") db_name: str = Field(default="test_task_crm", description="Database name")
db_user: str = Field(default="postgres", description="Database user") db_user: str = Field(default="postgres", description="Database user")
db_password: SecretStr = Field(default=SecretStr("postgres"), description="Database user password") db_password: SecretStr = Field(
default=SecretStr("postgres"), description="Database user password"
)
database_url_override: str | None = Field( database_url_override: str | None = Field(
default=None, default=None,
alias="DATABASE_URL", alias="DATABASE_URL",
@ -28,7 +31,9 @@ class Settings(BaseSettings):
refresh_token_expire_days: int = 7 refresh_token_expire_days: int = 7
redis_enabled: bool = Field(default=False, description="Toggle Redis-backed cache usage") redis_enabled: bool = Field(default=False, description="Toggle Redis-backed cache usage")
redis_url: str = Field(default="redis://localhost:6379/0", description="Redis connection URL") redis_url: str = Field(default="redis://localhost:6379/0", description="Redis connection URL")
analytics_cache_ttl_seconds: int = Field(default=120, ge=1, description="TTL for cached analytics responses") analytics_cache_ttl_seconds: int = Field(
default=120, ge=1, description="TTL for cached analytics responses"
)
analytics_cache_backoff_ms: int = Field( analytics_cache_backoff_ms: int = Field(
default=200, default=200,
ge=0, ge=0,

View File

@ -1,11 +1,11 @@
"""Database utilities for async SQLAlchemy engine and sessions.""" """Database utilities for async SQLAlchemy engine and sessions."""
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.core.config import settings from app.core.config import settings
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
engine = create_async_engine(settings.database_url, echo=settings.sqlalchemy_echo) engine = create_async_engine(settings.database_url, echo=settings.sqlalchemy_echo)
AsyncSessionMaker = async_sessionmaker(bind=engine, expire_on_commit=False) AsyncSessionMaker = async_sessionmaker(bind=engine, expire_on_commit=False)

View File

@ -1,11 +1,12 @@
"""Middleware that logs cache availability transitions.""" """Middleware that logs cache availability transitions."""
from __future__ import annotations from __future__ import annotations
import logging import logging
from starlette.types import ASGIApp, Receive, Scope, Send
from app.core.cache import cache_manager from app.core.cache import cache_manager
from app.core.config import settings from app.core.config import settings
from starlette.types import ASGIApp, Receive, Scope, Send
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,13 +1,14 @@
"""Security helpers for hashing passwords and issuing JWT tokens.""" """Security helpers for hashing passwords and issuing JWT tokens."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Mapping from typing import Any
import jwt import jwt
from passlib.context import CryptContext # type: ignore
from app.core.config import settings from app.core.config import settings
from passlib.context import CryptContext # type: ignore
class PasswordHasher: class PasswordHasher:

View File

@ -7,6 +7,7 @@ from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
@ -14,13 +15,12 @@ from app.api.routes import api_router
from app.core.cache import init_cache, shutdown_cache from app.core.cache import init_cache, shutdown_cache
from app.core.config import settings from app.core.config import settings
from app.core.middleware.cache_monitor import CacheAvailabilityMiddleware from app.core.middleware.cache_monitor import CacheAvailabilityMiddleware
from fastapi.middleware.cors import CORSMiddleware
PROJECT_ROOT = Path(__file__).resolve().parent.parent PROJECT_ROOT = Path(__file__).resolve().parent.parent
FRONTEND_DIST = PROJECT_ROOT / "frontend" / "dist" FRONTEND_DIST = PROJECT_ROOT / "frontend" / "dist"
FRONTEND_INDEX = FRONTEND_DIST / "index.html" FRONTEND_INDEX = FRONTEND_DIST / "index.html"
def create_app() -> FastAPI: def create_app() -> FastAPI:
"""Build FastAPI application instance.""" """Build FastAPI application instance."""
@ -43,7 +43,7 @@ def create_app() -> FastAPI:
# "http://localhost:8000", # "http://localhost:8000",
# "http://0.0.0.0:8000", # "http://0.0.0.0:8000",
# "http://127.0.0.1:8000", # "http://127.0.0.1:8000",
"*" # ! TODO: Убрать "*", # ! TODO: Убрать
], ],
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], # Разрешить все HTTP-методы allow_methods=["*"], # Разрешить все HTTP-методы
@ -59,7 +59,9 @@ def create_app() -> FastAPI:
return FileResponse(FRONTEND_INDEX) return FileResponse(FRONTEND_INDEX)
@application.get("/{path:path}", include_in_schema=False) @application.get("/{path:path}", include_in_schema=False)
async def serve_frontend_path(path: str) -> FileResponse: # pragma: no cover - simple file response async def serve_frontend_path(
path: str,
) -> FileResponse: # pragma: no cover - simple file response
if path == "" or path.startswith("api"): if path == "" or path.startswith("api"):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)

View File

@ -1,4 +1,5 @@
"""Model exports for Alembic discovery.""" """Model exports for Alembic discovery."""
from app.models.activity import Activity, ActivityType from app.models.activity import Activity, ActivityType
from app.models.base import Base from app.models.base import Base
from app.models.contact import Contact from app.models.contact import Contact
@ -9,16 +10,16 @@ from app.models.task import Task
from app.models.user import User from app.models.user import User
__all__ = [ __all__ = [
"Activity", "Activity",
"ActivityType", "ActivityType",
"Base", "Base",
"Contact", "Contact",
"Deal", "Deal",
"DealStage", "DealStage",
"DealStatus", "DealStatus",
"Organization", "Organization",
"OrganizationMember", "OrganizationMember",
"OrganizationRole", "OrganizationRole",
"Task", "Task",
"User", "User",
] ]

View File

@ -1,4 +1,5 @@
"""Activity timeline ORM model and schemas.""" """Activity timeline ORM model and schemas."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
@ -6,10 +7,12 @@ from enum import StrEnum
from typing import Any from typing import Any
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import DateTime, Enum as SqlEnum, ForeignKey, Integer, func, text from sqlalchemy import DateTime, ForeignKey, Integer, func, text
from sqlalchemy import Enum as SqlEnum
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import JSON as GenericJSON, TypeDecorator
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.types import JSON as GenericJSON
from sqlalchemy.types import TypeDecorator
from app.models.base import Base, enum_values from app.models.base import Base, enum_values
@ -44,10 +47,12 @@ class Activity(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True) id: Mapped[int] = mapped_column(Integer, primary_key=True)
deal_id: Mapped[int] = mapped_column(ForeignKey("deals.id", ondelete="CASCADE")) deal_id: Mapped[int] = mapped_column(ForeignKey("deals.id", ondelete="CASCADE"))
author_id: Mapped[int | None] = mapped_column( author_id: Mapped[int | None] = mapped_column(
ForeignKey("users.id", ondelete="SET NULL"), nullable=True ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
) )
type: Mapped[ActivityType] = mapped_column( type: Mapped[ActivityType] = mapped_column(
SqlEnum(ActivityType, name="activity_type", values_callable=enum_values), nullable=False SqlEnum(ActivityType, name="activity_type", values_callable=enum_values),
nullable=False,
) )
payload: Mapped[dict[str, Any]] = mapped_column( payload: Mapped[dict[str, Any]] = mapped_column(
JSONBCompat().with_variant(GenericJSON(), "sqlite"), JSONBCompat().with_variant(GenericJSON(), "sqlite"),
@ -55,7 +60,9 @@ class Activity(Base):
server_default=text("'{}'"), server_default=text("'{}'"),
) )
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False DateTime(timezone=True),
server_default=func.now(),
nullable=False,
) )
deal = relationship("Deal", back_populates="activities") deal = relationship("Deal", back_populates="activities")

View File

@ -1,4 +1,5 @@
"""Declarative base for SQLAlchemy models.""" """Declarative base for SQLAlchemy models."""
from __future__ import annotations from __future__ import annotations
from enum import StrEnum from enum import StrEnum

View File

@ -1,4 +1,5 @@
"""Contact ORM model and schemas.""" """Contact ORM model and schemas."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
@ -22,7 +23,9 @@ class Contact(Base):
email: Mapped[str | None] = mapped_column(String(320), nullable=True) email: Mapped[str | None] = mapped_column(String(320), nullable=True)
phone: Mapped[str | None] = mapped_column(String(64), nullable=True) phone: Mapped[str | None] = mapped_column(String(64), nullable=True)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False DateTime(timezone=True),
server_default=func.now(),
nullable=False,
) )
organization = relationship("Organization", back_populates="contacts") organization = relationship("Organization", back_populates="contacts")

View File

@ -1,4 +1,5 @@
"""Deal ORM model and schemas.""" """Deal ORM model and schemas."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
@ -6,7 +7,8 @@ from decimal import Decimal
from enum import StrEnum from enum import StrEnum
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import DateTime, Enum as SqlEnum, ForeignKey, Integer, Numeric, String, func from sqlalchemy import DateTime, ForeignKey, Integer, Numeric, String, func
from sqlalchemy import Enum as SqlEnum
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, enum_values from app.models.base import Base, enum_values
@ -49,10 +51,15 @@ class Deal(Base):
default=DealStage.QUALIFICATION, default=DealStage.QUALIFICATION,
) )
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False DateTime(timezone=True),
server_default=func.now(),
nullable=False,
) )
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
) )
organization = relationship("Organization", back_populates="deals") organization = relationship("Organization", back_populates="deals")

View File

@ -1,4 +1,5 @@
"""Organization ORM model and schemas.""" """Organization ORM model and schemas."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
@ -18,7 +19,9 @@ class Organization(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False DateTime(timezone=True),
server_default=func.now(),
nullable=False,
) )
members = relationship( members = relationship(

View File

@ -1,11 +1,13 @@
"""Organization member ORM model.""" """Organization member ORM model."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import DateTime, Enum as SqlEnum, ForeignKey, Integer, UniqueConstraint, func from sqlalchemy import DateTime, ForeignKey, Integer, UniqueConstraint, func
from sqlalchemy import Enum as SqlEnum
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, enum_values from app.models.base import Base, enum_values
@ -39,7 +41,9 @@ class OrganizationMember(Base):
default=OrganizationRole.MEMBER, default=OrganizationRole.MEMBER,
) )
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False DateTime(timezone=True),
server_default=func.now(),
nullable=False,
) )
organization = relationship("Organization", back_populates="members") organization = relationship("Organization", back_populates="members")

View File

@ -1,4 +1,5 @@
"""Task ORM model and schemas.""" """Task ORM model and schemas."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
@ -22,7 +23,9 @@ class Task(Base):
due_date: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) due_date: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
is_done: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) is_done: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False DateTime(timezone=True),
server_default=func.now(),
nullable=False,
) )
deal = relationship("Deal", back_populates="tasks") deal = relationship("Deal", back_populates="tasks")

View File

@ -1,4 +1,5 @@
"""Token-related Pydantic schemas.""" """Token-related Pydantic schemas."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime

View File

@ -1,4 +1,5 @@
"""User ORM model and Pydantic schemas.""" """User ORM model and Pydantic schemas."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
@ -25,13 +26,20 @@ class User(Base):
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False DateTime(timezone=True),
server_default=func.now(),
nullable=False,
) )
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
) )
memberships = relationship("OrganizationMember", back_populates="user", cascade="all, delete-orphan") memberships = relationship(
"OrganizationMember", back_populates="user", cascade="all, delete-orphan"
)
owned_contacts = relationship("Contact", back_populates="owner") owned_contacts = relationship("Contact", back_populates="owner")
owned_deals = relationship("Deal", back_populates="owner") owned_deals = relationship("Deal", back_populates="owner")
activities = relationship("Activity", back_populates="author") activities = relationship("Activity", back_populates="author")

View File

@ -1,4 +1,5 @@
"""Repository helpers for deal activities.""" """Repository helpers for deal activities."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
@ -39,7 +40,9 @@ class ActivityRepository:
stmt = ( stmt = (
select(Activity) select(Activity)
.join(Deal, Deal.id == Activity.deal_id) .join(Deal, Deal.id == Activity.deal_id)
.where(Activity.deal_id == params.deal_id, Deal.organization_id == params.organization_id) .where(
Activity.deal_id == params.deal_id, Deal.organization_id == params.organization_id
)
.order_by(Activity.created_at) .order_by(Activity.created_at)
) )
stmt = self._apply_window(stmt, params) stmt = self._apply_window(stmt, params)

View File

@ -1,4 +1,5 @@
"""Analytics-specific data access helpers.""" """Analytics-specific data access helpers."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
@ -58,7 +59,7 @@ class AnalyticsRepository:
deal_count=int(count or 0), deal_count=int(count or 0),
amount_sum=_to_decimal(amount_sum), amount_sum=_to_decimal(amount_sum),
amount_count=int(amount_count or 0), amount_count=int(amount_count or 0),
) ),
) )
return rollup return rollup

View File

@ -1,4 +1,5 @@
"""Repository helpers for contacts with role-aware access.""" """Repository helpers for contacts with role-aware access."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
@ -44,7 +45,9 @@ class ContactRepository:
role: OrganizationRole, role: OrganizationRole,
user_id: int, user_id: int,
) -> Sequence[Contact]: ) -> Sequence[Contact]:
stmt: Select[tuple[Contact]] = select(Contact).where(Contact.organization_id == params.organization_id) stmt: Select[tuple[Contact]] = select(Contact).where(
Contact.organization_id == params.organization_id
)
stmt = self._apply_filters(stmt, params, role, user_id) stmt = self._apply_filters(stmt, params, role, user_id)
offset = (max(params.page, 1) - 1) * params.page_size offset = (max(params.page, 1) - 1) * params.page_size
stmt = stmt.order_by(Contact.created_at.desc()).offset(offset).limit(params.page_size) stmt = stmt.order_by(Contact.created_at.desc()).offset(offset).limit(params.page_size)
@ -59,7 +62,9 @@ class ContactRepository:
role: OrganizationRole, role: OrganizationRole,
user_id: int, user_id: int,
) -> Contact | None: ) -> Contact | None:
stmt = select(Contact).where(Contact.id == contact_id, Contact.organization_id == organization_id) stmt = select(Contact).where(
Contact.id == contact_id, Contact.organization_id == organization_id
)
result = await self._session.scalars(stmt) result = await self._session.scalars(stmt)
return result.first() return result.first()
@ -117,7 +122,7 @@ class ContactRepository:
pattern = f"%{params.search.lower()}%" pattern = f"%{params.search.lower()}%"
stmt = stmt.where( stmt = stmt.where(
func.lower(Contact.name).like(pattern) func.lower(Contact.name).like(pattern)
| func.lower(func.coalesce(Contact.email, "")).like(pattern) | func.lower(func.coalesce(Contact.email, "")).like(pattern),
) )
if params.owner_id is not None: if params.owner_id is not None:
if role == OrganizationRole.MEMBER: if role == OrganizationRole.MEMBER:

View File

@ -1,4 +1,5 @@
"""Deal repository with access-aware CRUD helpers.""" """Deal repository with access-aware CRUD helpers."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
@ -12,142 +13,143 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.models.deal import Deal, DealCreate, DealStage, DealStatus from app.models.deal import Deal, DealCreate, DealStage, DealStatus
from app.models.organization_member import OrganizationRole from app.models.organization_member import OrganizationRole
ORDERABLE_COLUMNS: dict[str, Any] = { ORDERABLE_COLUMNS: dict[str, Any] = {
"created_at": Deal.created_at, "created_at": Deal.created_at,
"amount": Deal.amount, "amount": Deal.amount,
"title": Deal.title, "title": Deal.title,
} }
class DealAccessError(Exception): class DealAccessError(Exception):
"""Raised when a user attempts an operation without sufficient permissions.""" """Raised when a user attempts an operation without sufficient permissions."""
@dataclass(slots=True) @dataclass(slots=True)
class DealQueryParams: class DealQueryParams:
"""Filters supported by list queries.""" """Filters supported by list queries."""
organization_id: int organization_id: int
page: int = 1 page: int = 1
page_size: int = 20 page_size: int = 20
statuses: Sequence[DealStatus] | None = None statuses: Sequence[DealStatus] | None = None
stage: DealStage | None = None stage: DealStage | None = None
owner_id: int | None = None owner_id: int | None = None
min_amount: Decimal | None = None min_amount: Decimal | None = None
max_amount: Decimal | None = None max_amount: Decimal | None = None
order_by: str | None = None order_by: str | None = None
order_desc: bool = True order_desc: bool = True
class DealRepository: class DealRepository:
"""Provides CRUD helpers for deals with role-aware filtering.""" """Provides CRUD helpers for deals with role-aware filtering."""
def __init__(self, session: AsyncSession) -> None: def __init__(self, session: AsyncSession) -> None:
self._session = session self._session = session
@property @property
def session(self) -> AsyncSession: def session(self) -> AsyncSession:
return self._session return self._session
async def list( async def list(
self, self,
*, *,
params: DealQueryParams, params: DealQueryParams,
role: OrganizationRole, role: OrganizationRole,
user_id: int, user_id: int,
) -> Sequence[Deal]: ) -> Sequence[Deal]:
stmt = select(Deal).where(Deal.organization_id == params.organization_id) stmt = select(Deal).where(Deal.organization_id == params.organization_id)
stmt = self._apply_filters(stmt, params, role, user_id) stmt = self._apply_filters(stmt, params, role, user_id)
stmt = self._apply_ordering(stmt, params) stmt = self._apply_ordering(stmt, params)
offset = (max(params.page, 1) - 1) * params.page_size offset = (max(params.page, 1) - 1) * params.page_size
stmt = stmt.offset(offset).limit(params.page_size) stmt = stmt.offset(offset).limit(params.page_size)
result = await self._session.scalars(stmt) result = await self._session.scalars(stmt)
return result.all() return result.all()
async def get( async def get(
self, self,
deal_id: int, deal_id: int,
*, *,
organization_id: int, organization_id: int,
role: OrganizationRole, role: OrganizationRole,
user_id: int, user_id: int,
require_owner: bool = False, require_owner: bool = False,
) -> Deal | None: ) -> Deal | None:
stmt = select(Deal).where(Deal.id == deal_id, Deal.organization_id == organization_id) stmt = select(Deal).where(Deal.id == deal_id, Deal.organization_id == organization_id)
stmt = self._apply_role_clause(stmt, role, user_id, require_owner=require_owner) stmt = self._apply_role_clause(stmt, role, user_id, require_owner=require_owner)
result = await self._session.scalars(stmt) result = await self._session.scalars(stmt)
return result.first() return result.first()
async def create( async def create(
self, self,
data: DealCreate, data: DealCreate,
*, *,
role: OrganizationRole, role: OrganizationRole,
user_id: int, user_id: int,
) -> Deal: ) -> Deal:
if role == OrganizationRole.MEMBER and data.owner_id != user_id: if role == OrganizationRole.MEMBER and data.owner_id != user_id:
raise DealAccessError("Members can only create deals they own") raise DealAccessError("Members can only create deals they own")
deal = Deal(**data.model_dump()) deal = Deal(**data.model_dump())
self._session.add(deal) self._session.add(deal)
await self._session.flush() await self._session.flush()
return deal return deal
async def update( async def update(
self, self,
deal: Deal, deal: Deal,
updates: Mapping[str, Any], updates: Mapping[str, Any],
*, *,
role: OrganizationRole, role: OrganizationRole,
user_id: int, user_id: int,
) -> Deal: ) -> Deal:
if role == OrganizationRole.MEMBER and deal.owner_id != user_id: if role == OrganizationRole.MEMBER and deal.owner_id != user_id:
raise DealAccessError("Members can only modify their own deals") raise DealAccessError("Members can only modify their own deals")
for field, value in updates.items(): for field, value in updates.items():
if hasattr(deal, field): if hasattr(deal, field):
setattr(deal, field, value) setattr(deal, field, value)
await self._session.flush() await self._session.flush()
await self._session.refresh(deal) await self._session.refresh(deal)
return deal return deal
def _apply_filters( def _apply_filters(
self, self,
stmt: Select[tuple[Deal]], stmt: Select[tuple[Deal]],
params: DealQueryParams, params: DealQueryParams,
role: OrganizationRole, role: OrganizationRole,
user_id: int, user_id: int,
) -> Select[tuple[Deal]]: ) -> Select[tuple[Deal]]:
if params.statuses: if params.statuses:
stmt = stmt.where(Deal.status.in_(params.statuses)) stmt = stmt.where(Deal.status.in_(params.statuses))
if params.stage: if params.stage:
stmt = stmt.where(Deal.stage == params.stage) stmt = stmt.where(Deal.stage == params.stage)
if params.owner_id is not None: if params.owner_id is not None:
if role == OrganizationRole.MEMBER and params.owner_id != user_id: if role == OrganizationRole.MEMBER and params.owner_id != user_id:
raise DealAccessError("Members cannot filter by other owners") raise DealAccessError("Members cannot filter by other owners")
stmt = stmt.where(Deal.owner_id == params.owner_id) stmt = stmt.where(Deal.owner_id == params.owner_id)
if params.min_amount is not None: if params.min_amount is not None:
stmt = stmt.where(Deal.amount >= params.min_amount) stmt = stmt.where(Deal.amount >= params.min_amount)
if params.max_amount is not None: if params.max_amount is not None:
stmt = stmt.where(Deal.amount <= params.max_amount) stmt = stmt.where(Deal.amount <= params.max_amount)
return self._apply_role_clause(stmt, role, user_id) return self._apply_role_clause(stmt, role, user_id)
def _apply_role_clause( def _apply_role_clause(
self, self,
stmt: Select[tuple[Deal]], stmt: Select[tuple[Deal]],
role: OrganizationRole, role: OrganizationRole,
user_id: int, user_id: int,
*, *,
require_owner: bool = False, require_owner: bool = False,
) -> Select[tuple[Deal]]: ) -> Select[tuple[Deal]]:
if role in {OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MANAGER}: if role in {OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MANAGER}:
return stmt return stmt
if require_owner: if require_owner:
return stmt.where(Deal.owner_id == user_id) return stmt.where(Deal.owner_id == user_id)
return stmt return stmt
def _apply_ordering(self, stmt: Select[tuple[Deal]], params: DealQueryParams) -> Select[tuple[Deal]]: def _apply_ordering(
column = ORDERABLE_COLUMNS.get(params.order_by or "created_at", Deal.created_at) self, stmt: Select[tuple[Deal]], params: DealQueryParams
order_func = desc if params.order_desc else asc ) -> Select[tuple[Deal]]:
return stmt.order_by(order_func(column)) column = ORDERABLE_COLUMNS.get(params.order_by or "created_at", Deal.created_at)
order_func = desc if params.order_desc else asc
return stmt.order_by(order_func(column))

View File

@ -1,11 +1,12 @@
"""Organization repository for database operations.""" """Organization repository for database operations."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models.organization import Organization, OrganizationCreate from app.models.organization import Organization, OrganizationCreate
from app.models.organization_member import OrganizationMember from app.models.organization_member import OrganizationMember

View File

@ -1,4 +1,5 @@
"""Task repository providing role-aware CRUD helpers.""" """Task repository providing role-aware CRUD helpers."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
@ -105,7 +106,9 @@ class TaskRepository:
await self._session.flush() await self._session.flush()
return task return task
def _apply_filters(self, stmt: Select[tuple[Task]], params: TaskQueryParams) -> Select[tuple[Task]]: def _apply_filters(
self, stmt: Select[tuple[Task]], params: TaskQueryParams
) -> Select[tuple[Task]]:
if params.deal_id is not None: if params.deal_id is not None:
stmt = stmt.where(Task.deal_id == params.deal_id) stmt = stmt.where(Task.deal_id == params.deal_id)
if params.only_open: if params.only_open:

View File

@ -1,4 +1,5 @@
"""User repository handling database operations.""" """User repository handling database operations."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence

View File

@ -1,4 +1,5 @@
"""Business logic services.""" """Business logic services."""
from .activity_service import ( # noqa: F401 from .activity_service import ( # noqa: F401
ActivityForbiddenError, ActivityForbiddenError,
ActivityListFilters, ActivityListFilters,
@ -22,4 +23,4 @@ from .task_service import ( # noqa: F401
TaskService, TaskService,
TaskServiceError, TaskServiceError,
TaskUpdateData, TaskUpdateData,
) )

View File

@ -1,4 +1,5 @@
"""Business logic for timeline activities.""" """Business logic for timeline activities."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence

View File

@ -1,11 +1,13 @@
"""Analytics-related business logic.""" """Analytics-related business logic."""
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from decimal import Decimal, InvalidOperation from decimal import Decimal, InvalidOperation
from typing import Any, Iterable from typing import Any
from redis.asyncio.client import Redis from redis.asyncio.client import Redis
from redis.exceptions import RedisError from redis.exceptions import RedisError
@ -105,9 +107,7 @@ class AnalyticsService:
won_amount_count = row.amount_count won_amount_count = row.amount_count
won_count = row.deal_count won_count = row.deal_count
won_average = ( won_average = (won_amount_sum / won_amount_count) if won_amount_count > 0 else Decimal("0")
(won_amount_sum / won_amount_count) if won_amount_count > 0 else Decimal("0")
)
window_threshold = _threshold_from_days(days) window_threshold = _threshold_from_days(days)
new_deals = await self._repository.count_new_deals_since(organization_id, window_threshold) new_deals = await self._repository.count_new_deals_since(organization_id, window_threshold)
@ -137,7 +137,7 @@ class AnalyticsService:
breakdowns: list[StageBreakdown] = [] breakdowns: list[StageBreakdown] = []
totals = {stage: sum(by_status.values()) for stage, by_status in stage_map.items()} totals = {stage: sum(by_status.values()) for stage, by_status in stage_map.items()}
for index, stage in enumerate(_STAGE_ORDER): for index, stage in enumerate(_STAGE_ORDER):
by_status = stage_map.get(stage, {status: 0 for status in DealStatus}) by_status = stage_map.get(stage, dict.fromkeys(DealStatus, 0))
total = totals.get(stage, 0) total = totals.get(stage, 0)
conversion = None conversion = None
if index < len(_STAGE_ORDER) - 1: if index < len(_STAGE_ORDER) - 1:
@ -151,7 +151,7 @@ class AnalyticsService:
total=total, total=total,
by_status=by_status, by_status=by_status,
conversion_to_next=conversion, conversion_to_next=conversion,
) ),
) )
await self._store_funnel_cache(organization_id, breakdowns) await self._store_funnel_cache(organization_id, breakdowns)
return breakdowns return breakdowns
@ -168,7 +168,9 @@ class AnalyticsService:
return None return None
return _deserialize_summary(payload) return _deserialize_summary(payload)
async def _store_summary_cache(self, organization_id: int, days: int, summary: DealSummary) -> None: async def _store_summary_cache(
self, organization_id: int, days: int, summary: DealSummary
) -> None:
if not self._is_cache_enabled() or self._cache is None: if not self._is_cache_enabled() or self._cache is None:
return return
key = _summary_cache_key(organization_id, days) key = _summary_cache_key(organization_id, days)
@ -184,7 +186,9 @@ class AnalyticsService:
return None return None
return _deserialize_funnel(payload) return _deserialize_funnel(payload)
async def _store_funnel_cache(self, organization_id: int, breakdowns: list[StageBreakdown]) -> None: async def _store_funnel_cache(
self, organization_id: int, breakdowns: list[StageBreakdown]
) -> None:
if not self._is_cache_enabled() or self._cache is None: if not self._is_cache_enabled() or self._cache is None:
return return
key = _funnel_cache_key(organization_id) key = _funnel_cache_key(organization_id)
@ -198,11 +202,10 @@ def _threshold_from_days(days: int) -> datetime:
def _build_stage_map(rollup: Iterable[StageStatusRollup]) -> dict[DealStage, dict[DealStatus, int]]: def _build_stage_map(rollup: Iterable[StageStatusRollup]) -> dict[DealStage, dict[DealStatus, int]]:
stage_map: dict[DealStage, dict[DealStatus, int]] = { stage_map: dict[DealStage, dict[DealStatus, int]] = {
stage: {status: 0 for status in DealStatus} stage: dict.fromkeys(DealStatus, 0) for stage in _STAGE_ORDER
for stage in _STAGE_ORDER
} }
for item in rollup: for item in rollup:
stage_map.setdefault(item.stage, {status: 0 for status in DealStatus}) stage_map.setdefault(item.stage, dict.fromkeys(DealStatus, 0))
stage_map[item.stage][item.status] = item.deal_count stage_map[item.stage][item.status] = item.deal_count
return stage_map return stage_map
@ -263,7 +266,7 @@ def _deserialize_summary(payload: Any) -> DealSummary | None:
status=DealStatus(item["status"]), status=DealStatus(item["status"]),
count=int(item["count"]), count=int(item["count"]),
amount_sum=Decimal(item["amount_sum"]), amount_sum=Decimal(item["amount_sum"]),
) ),
) )
won = WonStatistics( won = WonStatistics(
count=int(won_payload["count"]), count=int(won_payload["count"]),
@ -289,7 +292,7 @@ def _serialize_funnel(breakdowns: list[StageBreakdown]) -> list[dict[str, Any]]:
"total": item.total, "total": item.total,
"by_status": {status.value: count for status, count in item.by_status.items()}, "by_status": {status.value: count for status, count in item.by_status.items()},
"conversion_to_next": item.conversion_to_next, "conversion_to_next": item.conversion_to_next,
} },
) )
return serialized return serialized
@ -307,15 +310,19 @@ def _deserialize_funnel(payload: Any) -> list[StageBreakdown] | None:
stage=DealStage(item["stage"]), stage=DealStage(item["stage"]),
total=int(item["total"]), total=int(item["total"]),
by_status=by_status, by_status=by_status,
conversion_to_next=float(item["conversion_to_next"]) if item["conversion_to_next"] is not None else None, conversion_to_next=float(item["conversion_to_next"])
) if item["conversion_to_next"] is not None
else None,
),
) )
except (KeyError, TypeError, ValueError): except (KeyError, TypeError, ValueError):
return None return None
return breakdowns return breakdowns
async def invalidate_analytics_cache(cache: Redis | None, organization_id: int, backoff_ms: int) -> None: async def invalidate_analytics_cache(
cache: Redis | None, organization_id: int, backoff_ms: int
) -> None:
"""Remove cached analytics payloads for the organization.""" """Remove cached analytics payloads for the organization."""
if cache is None: if cache is None:

View File

@ -1,4 +1,5 @@
"""Authentication workflows.""" """Authentication workflows."""
from __future__ import annotations from __future__ import annotations
from datetime import timedelta from datetime import timedelta

View File

@ -1,4 +1,5 @@
"""Business logic for contact workflows.""" """Business logic for contact workflows."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence from collections.abc import Sequence
@ -78,7 +79,9 @@ class ContactService:
owner_id=filters.owner_id, owner_id=filters.owner_id,
) )
try: try:
return await self._repository.list(params=params, role=context.role, user_id=context.user_id) return await self._repository.list(
params=params, role=context.role, user_id=context.user_id
)
except ContactAccessError as exc: except ContactAccessError as exc:
raise ContactForbiddenError(str(exc)) from exc raise ContactForbiddenError(str(exc)) from exc
@ -122,7 +125,9 @@ class ContactService:
if not payload: if not payload:
return contact return contact
try: try:
return await self._repository.update(contact, payload, role=context.role, user_id=context.user_id) return await self._repository.update(
contact, payload, role=context.role, user_id=context.user_id
)
except ContactAccessError as exc: except ContactAccessError as exc:
raise ContactForbiddenError(str(exc)) from exc raise ContactForbiddenError(str(exc)) from exc

View File

@ -1,4 +1,5 @@
"""Business logic for deals.""" """Business logic for deals."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
@ -16,162 +17,173 @@ from app.repositories.deal_repo import DealRepository
from app.services.analytics_service import invalidate_analytics_cache from app.services.analytics_service import invalidate_analytics_cache
from app.services.organization_service import OrganizationContext from app.services.organization_service import OrganizationContext
STAGE_ORDER = { STAGE_ORDER = {
stage: index stage: index
for index, stage in enumerate( for index, stage in enumerate(
[ [
DealStage.QUALIFICATION, DealStage.QUALIFICATION,
DealStage.PROPOSAL, DealStage.PROPOSAL,
DealStage.NEGOTIATION, DealStage.NEGOTIATION,
DealStage.CLOSED, DealStage.CLOSED,
] ],
) )
} }
class DealServiceError(Exception): class DealServiceError(Exception):
"""Base class for deal service errors.""" """Base class for deal service errors."""
class DealOrganizationMismatchError(DealServiceError): class DealOrganizationMismatchError(DealServiceError):
"""Raised when attempting to use resources from another organization.""" """Raised when attempting to use resources from another organization."""
class DealStageTransitionError(DealServiceError): class DealStageTransitionError(DealServiceError):
"""Raised when stage transition violates business rules.""" """Raised when stage transition violates business rules."""
class DealStatusValidationError(DealServiceError): class DealStatusValidationError(DealServiceError):
"""Raised when invalid status transitions are requested.""" """Raised when invalid status transitions are requested."""
class ContactHasDealsError(DealServiceError): class ContactHasDealsError(DealServiceError):
"""Raised when attempting to delete a contact with active deals.""" """Raised when attempting to delete a contact with active deals."""
@dataclass(slots=True) @dataclass(slots=True)
class DealUpdateData: class DealUpdateData:
"""Structured container for deal update operations.""" """Structured container for deal update operations."""
status: DealStatus | None = None status: DealStatus | None = None
stage: DealStage | None = None stage: DealStage | None = None
amount: Decimal | None = None amount: Decimal | None = None
currency: str | None = None currency: str | None = None
class DealService: class DealService:
"""Encapsulates deal workflows and validations.""" """Encapsulates deal workflows and validations."""
def __init__( def __init__(
self, self,
repository: DealRepository, repository: DealRepository,
cache: Redis | None = None, cache: Redis | None = None,
*, *,
cache_backoff_ms: int = 0, cache_backoff_ms: int = 0,
) -> None: ) -> None:
self._repository = repository self._repository = repository
self._cache = cache self._cache = cache
self._cache_backoff_ms = cache_backoff_ms self._cache_backoff_ms = cache_backoff_ms
async def create_deal(self, data: DealCreate, *, context: OrganizationContext) -> Deal: async def create_deal(self, data: DealCreate, *, context: OrganizationContext) -> Deal:
self._ensure_same_organization(data.organization_id, context) self._ensure_same_organization(data.organization_id, context)
await self._ensure_contact_in_organization(data.contact_id, context.organization_id) await self._ensure_contact_in_organization(data.contact_id, context.organization_id)
deal = await self._repository.create(data=data, role=context.role, user_id=context.user_id) deal = await self._repository.create(data=data, role=context.role, user_id=context.user_id)
await invalidate_analytics_cache(self._cache, context.organization_id, self._cache_backoff_ms) await invalidate_analytics_cache(
return deal self._cache, context.organization_id, self._cache_backoff_ms
)
return deal
async def update_deal( async def update_deal(
self, self,
deal: Deal, deal: Deal,
updates: DealUpdateData, updates: DealUpdateData,
*, *,
context: OrganizationContext, context: OrganizationContext,
) -> Deal: ) -> Deal:
self._ensure_same_organization(deal.organization_id, context) self._ensure_same_organization(deal.organization_id, context)
changes: dict[str, object] = {} changes: dict[str, object] = {}
stage_activity: tuple[ActivityType, dict[str, str]] | None = None stage_activity: tuple[ActivityType, dict[str, str]] | None = None
status_activity: tuple[ActivityType, dict[str, str]] | None = None status_activity: tuple[ActivityType, dict[str, str]] | None = None
if updates.amount is not None: if updates.amount is not None:
changes["amount"] = updates.amount changes["amount"] = updates.amount
if updates.currency is not None: if updates.currency is not None:
changes["currency"] = updates.currency changes["currency"] = updates.currency
if updates.stage is not None and updates.stage != deal.stage: if updates.stage is not None and updates.stage != deal.stage:
self._validate_stage_transition(deal.stage, updates.stage, context.role) self._validate_stage_transition(deal.stage, updates.stage, context.role)
changes["stage"] = updates.stage changes["stage"] = updates.stage
stage_activity = ( stage_activity = (
ActivityType.STAGE_CHANGED, ActivityType.STAGE_CHANGED,
{"old_stage": deal.stage, "new_stage": updates.stage}, {"old_stage": deal.stage, "new_stage": updates.stage},
) )
if updates.status is not None and updates.status != deal.status: if updates.status is not None and updates.status != deal.status:
self._validate_status_transition(deal, updates) self._validate_status_transition(deal, updates)
changes["status"] = updates.status changes["status"] = updates.status
status_activity = ( status_activity = (
ActivityType.STATUS_CHANGED, ActivityType.STATUS_CHANGED,
{"old_status": deal.status, "new_status": updates.status}, {"old_status": deal.status, "new_status": updates.status},
) )
if not changes: if not changes:
return deal return deal
updated = await self._repository.update(deal, changes, role=context.role, user_id=context.user_id) updated = await self._repository.update(
await self._log_activities( deal, changes, role=context.role, user_id=context.user_id
deal_id=deal.id, )
author_id=context.user_id, await self._log_activities(
activities=[activity for activity in [stage_activity, status_activity] if activity], deal_id=deal.id,
) author_id=context.user_id,
await invalidate_analytics_cache(self._cache, context.organization_id, self._cache_backoff_ms) activities=[activity for activity in [stage_activity, status_activity] if activity],
return updated )
await invalidate_analytics_cache(
self._cache, context.organization_id, self._cache_backoff_ms
)
return updated
async def ensure_contact_can_be_deleted(self, contact_id: int) -> None: async def ensure_contact_can_be_deleted(self, contact_id: int) -> None:
stmt = select(func.count()).select_from(Deal).where(Deal.contact_id == contact_id) stmt = select(func.count()).select_from(Deal).where(Deal.contact_id == contact_id)
count = await self._repository.session.scalar(stmt) count = await self._repository.session.scalar(stmt)
if count and count > 0: if count and count > 0:
raise ContactHasDealsError("Contact has related deals and cannot be deleted") raise ContactHasDealsError("Contact has related deals and cannot be deleted")
async def _log_activities( async def _log_activities(
self, self,
*, *,
deal_id: int, deal_id: int,
author_id: int, author_id: int,
activities: Iterable[tuple[ActivityType, dict[str, str]]], activities: Iterable[tuple[ActivityType, dict[str, str]]],
) -> None: ) -> None:
entries = list(activities) entries = list(activities)
if not entries: if not entries:
return return
for activity_type, payload in entries: for activity_type, payload in entries:
activity = Activity(deal_id=deal_id, author_id=author_id, type=activity_type, payload=payload) activity = Activity(
self._repository.session.add(activity) deal_id=deal_id, author_id=author_id, type=activity_type, payload=payload
await self._repository.session.flush() )
self._repository.session.add(activity)
await self._repository.session.flush()
def _ensure_same_organization(self, organization_id: int, context: OrganizationContext) -> None: def _ensure_same_organization(self, organization_id: int, context: OrganizationContext) -> None:
if organization_id != context.organization_id: if organization_id != context.organization_id:
raise DealOrganizationMismatchError("Operation targets a different organization") raise DealOrganizationMismatchError("Operation targets a different organization")
async def _ensure_contact_in_organization(self, contact_id: int, organization_id: int) -> Contact: async def _ensure_contact_in_organization(
contact = await self._repository.session.get(Contact, contact_id) self, contact_id: int, organization_id: int
if contact is None or contact.organization_id != organization_id: ) -> Contact:
raise DealOrganizationMismatchError("Contact belongs to another organization") contact = await self._repository.session.get(Contact, contact_id)
return contact if contact is None or contact.organization_id != organization_id:
raise DealOrganizationMismatchError("Contact belongs to another organization")
return contact
def _validate_stage_transition( def _validate_stage_transition(
self, self,
current_stage: DealStage, current_stage: DealStage,
new_stage: DealStage, new_stage: DealStage,
role: OrganizationRole, role: OrganizationRole,
) -> None: ) -> None:
if STAGE_ORDER[new_stage] < STAGE_ORDER[current_stage] and role not in { if STAGE_ORDER[new_stage] < STAGE_ORDER[current_stage] and role not in {
OrganizationRole.OWNER, OrganizationRole.OWNER,
OrganizationRole.ADMIN, OrganizationRole.ADMIN,
}: }:
raise DealStageTransitionError("Stage rollback requires owner or admin role") raise DealStageTransitionError("Stage rollback requires owner or admin role")
def _validate_status_transition(self, deal: Deal, updates: DealUpdateData) -> None: def _validate_status_transition(self, deal: Deal, updates: DealUpdateData) -> None:
if updates.status != DealStatus.WON: if updates.status != DealStatus.WON:
return return
effective_amount = updates.amount if updates.amount is not None else deal.amount effective_amount = updates.amount if updates.amount is not None else deal.amount
if effective_amount is None or Decimal(effective_amount) <= Decimal("0"): if effective_amount is None or Decimal(effective_amount) <= Decimal("0"):
raise DealStatusValidationError("Amount must be greater than zero to mark a deal as won") raise DealStatusValidationError(
"Amount must be greater than zero to mark a deal as won"
)

View File

@ -1,4 +1,5 @@
"""Organization-related business rules.""" """Organization-related business rules."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
@ -54,7 +55,9 @@ class OrganizationService:
def __init__(self, repository: OrganizationRepository) -> None: def __init__(self, repository: OrganizationRepository) -> None:
self._repository = repository self._repository = repository
async def get_context(self, *, user_id: int, organization_id: int | None) -> OrganizationContext: async def get_context(
self, *, user_id: int, organization_id: int | None
) -> OrganizationContext:
"""Resolve request context ensuring the user belongs to the given organization.""" """Resolve request context ensuring the user belongs to the given organization."""
if organization_id is None: if organization_id is None:
@ -66,7 +69,9 @@ class OrganizationService:
return OrganizationContext(organization=membership.organization, membership=membership) return OrganizationContext(organization=membership.organization, membership=membership)
def ensure_entity_in_context(self, *, entity_organization_id: int, context: OrganizationContext) -> None: def ensure_entity_in_context(
self, *, entity_organization_id: int, context: OrganizationContext
) -> None:
"""Make sure a resource belongs to the current organization.""" """Make sure a resource belongs to the current organization."""
if entity_organization_id != context.organization_id: if entity_organization_id != context.organization_id:
@ -113,4 +118,4 @@ class OrganizationService:
self._repository.session.add(membership) self._repository.session.add(membership)
await self._repository.session.commit() await self._repository.session.commit()
await self._repository.session.refresh(membership) await self._repository.session.refresh(membership)
return membership return membership

View File

@ -1,4 +1,5 @@
"""Business logic for tasks linked to deals.""" """Business logic for tasks linked to deals."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
@ -9,10 +10,14 @@ from typing import Any
from app.models.activity import ActivityCreate, ActivityType from app.models.activity import ActivityCreate, ActivityType
from app.models.organization_member import OrganizationRole from app.models.organization_member import OrganizationRole
from app.models.task import Task, TaskCreate from app.models.task import Task, TaskCreate
from app.repositories.activity_repo import ActivityRepository, ActivityOrganizationMismatchError from app.repositories.activity_repo import ActivityOrganizationMismatchError, ActivityRepository
from app.repositories.task_repo import ( from app.repositories.task_repo import (
TaskAccessError as RepoTaskAccessError, TaskAccessError as RepoTaskAccessError,
)
from app.repositories.task_repo import (
TaskOrganizationMismatchError as RepoTaskOrganizationMismatchError, TaskOrganizationMismatchError as RepoTaskOrganizationMismatchError,
)
from app.repositories.task_repo import (
TaskQueryParams, TaskQueryParams,
TaskRepository, TaskRepository,
) )

View File

@ -24,3 +24,64 @@ dev = [
"pytest-asyncio>=0.25.0", "pytest-asyncio>=0.25.0",
"aiosqlite>=0.20.0", "aiosqlite>=0.20.0",
] ]
[tool.isort]
profile = "black"
line_length = 100
combine_as_imports = true
default_section = "THIRDPARTY"
known_first_party = ["app", "tests"]
skip_glob = ["migrations/*"]
[tool.mypy]
python_version = "3.14"
plugins = ["pydantic.mypy"]
warn_unused_configs = true
warn_return_any = true
warn_unused_ignores = true
disallow_untyped_defs = true
disallow_untyped_calls = true
disallow_any_unimported = true
no_implicit_optional = true
strict_optional = true
show_error_codes = true
exclude = ["migrations/"]
[[tool.mypy.overrides]]
module = ["tests.*"]
ignore_missing_imports = true
allow_untyped_defs = true
[tool.ruff]
line-length = 100
target-version = "py310"
src = ["app", "migrations", "tests"]
[tool.ruff.lint]
select = [
"E",
"F",
"W",
"B",
"UP",
"I",
"N",
"S",
"Q",
"C4",
"COM",
"DTZ",
"G",
"TID",
]
ignore = ["E203", "E266", "E501", "S101"]
[tool.ruff.lint.per-file-ignores]
"tests/**/*" = ["S311"]
"migrations/*" = ["B008", "DTZ001", "TID252"]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"

View File

@ -1,17 +1,17 @@
"""Pytest fixtures shared across API v1 tests.""" """Pytest fixtures shared across API v1 tests."""
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.api.deps import get_cache_backend, get_db_session from app.api.deps import get_cache_backend, get_db_session
from app.core.security import password_hasher from app.core.security import password_hasher
from app.main import create_app from app.main import create_app
from app.models import Base from app.models import Base
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from tests.utils.fake_redis import InMemoryRedis from tests.utils.fake_redis import InMemoryRedis

View File

@ -1,17 +1,17 @@
"""Shared helpers for task and activity API tests.""" """Shared helpers for task and activity API tests."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.core.security import jwt_service from app.core.security import jwt_service
from app.models.contact import Contact from app.models.contact import Contact
from app.models.deal import Deal from app.models.deal import Deal
from app.models.organization import Organization from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole from app.models.organization_member import OrganizationMember, OrganizationRole
from app.models.user import User from app.models.user import User
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@dataclass(slots=True) @dataclass(slots=True)
@ -27,7 +27,9 @@ class Scenario:
async def prepare_scenario(session_factory: async_sessionmaker[AsyncSession]) -> Scenario: async def prepare_scenario(session_factory: async_sessionmaker[AsyncSession]) -> Scenario:
async with session_factory() as session: async with session_factory() as session:
user = User(email="owner@example.com", hashed_password="hashed", name="Owner", is_active=True) user = User(
email="owner@example.com", hashed_password="hashed", name="Owner", is_active=True
)
org = Organization(name="Acme LLC") org = Organization(name="Acme LLC")
session.add_all([user, org]) session.add_all([user, org])
await session.flush() await session.flush()

View File

@ -1,20 +1,20 @@
"""API tests for activity endpoints.""" """API tests for activity endpoints."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
import pytest import pytest
from app.models.activity import Activity, ActivityType
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.models.activity import Activity, ActivityType
from tests.api.v1.task_activity_shared import auth_headers, make_token, prepare_scenario from tests.api.v1.task_activity_shared import auth_headers, make_token, prepare_scenario
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_activity_comment_endpoint( async def test_create_activity_comment_endpoint(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_scenario(session_factory) scenario = await prepare_scenario(session_factory)
token = make_token(scenario.user_id, scenario.user_email) token = make_token(scenario.user_id, scenario.user_email)
@ -33,7 +33,8 @@ async def test_create_activity_comment_endpoint(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_activities_endpoint_supports_pagination( async def test_list_activities_endpoint_supports_pagination(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_scenario(session_factory) scenario = await prepare_scenario(session_factory)
token = make_token(scenario.user_id, scenario.user_email) token = make_token(scenario.user_id, scenario.user_email)

View File

@ -1,4 +1,5 @@
"""API tests for analytics endpoints.""" """API tests for analytics endpoints."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
@ -6,15 +7,14 @@ from datetime import datetime, timedelta, timezone
from decimal import Decimal from decimal import Decimal
import pytest import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.core.security import jwt_service from app.core.security import jwt_service
from app.models.contact import Contact from app.models.contact import Contact
from app.models.deal import Deal, DealStage, DealStatus from app.models.deal import Deal, DealStage, DealStatus
from app.models.organization import Organization from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole from app.models.organization_member import OrganizationMember, OrganizationRole
from app.models.user import User from app.models.user import User
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@dataclass(slots=True) @dataclass(slots=True)
@ -26,10 +26,14 @@ class AnalyticsScenario:
in_progress_deal_id: int in_progress_deal_id: int
async def prepare_analytics_scenario(session_factory: async_sessionmaker[AsyncSession]) -> AnalyticsScenario: async def prepare_analytics_scenario(
session_factory: async_sessionmaker[AsyncSession],
) -> AnalyticsScenario:
async with session_factory() as session: async with session_factory() as session:
org = Organization(name="Analytics Org") org = Organization(name="Analytics Org")
user = User(email="analytics@example.com", hashed_password="hashed", name="Analyst", is_active=True) user = User(
email="analytics@example.com", hashed_password="hashed", name="Analyst", is_active=True
)
session.add_all([org, user]) session.add_all([org, user])
await session.flush() await session.flush()
@ -103,7 +107,9 @@ async def prepare_analytics_scenario(session_factory: async_sessionmaker[AsyncSe
user_id=user.id, user_id=user.id,
user_email=user.email, user_email=user.email,
token=token, token=token,
in_progress_deal_id=next(deal.id for deal in deals if deal.status is DealStatus.IN_PROGRESS), in_progress_deal_id=next(
deal.id for deal in deals if deal.status is DealStatus.IN_PROGRESS
),
) )
@ -113,7 +119,8 @@ def _headers(token: str, organization_id: int) -> dict[str, str]:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_deals_summary_endpoint_returns_metrics( async def test_deals_summary_endpoint_returns_metrics(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_analytics_scenario(session_factory) scenario = await prepare_analytics_scenario(session_factory)
@ -134,7 +141,8 @@ async def test_deals_summary_endpoint_returns_metrics(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_deals_summary_respects_days_filter( async def test_deals_summary_respects_days_filter(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_analytics_scenario(session_factory) scenario = await prepare_analytics_scenario(session_factory)
@ -150,7 +158,8 @@ async def test_deals_summary_respects_days_filter(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_deals_funnel_returns_breakdown( async def test_deals_funnel_returns_breakdown(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_analytics_scenario(session_factory) scenario = await prepare_analytics_scenario(session_factory)
@ -162,7 +171,9 @@ async def test_deals_funnel_returns_breakdown(
assert response.status_code == 200 assert response.status_code == 200
payload = response.json() payload = response.json()
assert len(payload["stages"]) == 4 assert len(payload["stages"]) == 4
qualification = next(item for item in payload["stages"] if item["stage"] == DealStage.QUALIFICATION.value) qualification = next(
item for item in payload["stages"] if item["stage"] == DealStage.QUALIFICATION.value
)
assert qualification["total"] == 1 assert qualification["total"] == 1
proposal = next(item for item in payload["stages"] if item["stage"] == DealStage.PROPOSAL.value) proposal = next(item for item in payload["stages"] if item["stage"] == DealStage.PROPOSAL.value)
assert proposal["conversion_to_next"] == 100.0 assert proposal["conversion_to_next"] == 100.0
@ -198,4 +209,4 @@ async def test_deal_update_invalidates_cached_summary(
) )
assert refreshed.status_code == 200 assert refreshed.status_code == 200
payload = refreshed.json() payload = refreshed.json()
assert payload["won"]["count"] == 2 assert payload["won"]["count"] == 2

View File

@ -1,15 +1,15 @@
"""API tests for authentication endpoints.""" """API tests for authentication endpoints."""
from __future__ import annotations from __future__ import annotations
import pytest import pytest
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.core.security import password_hasher from app.core.security import password_hasher
from app.models.organization import Organization from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole from app.models.organization_member import OrganizationMember, OrganizationRole
from app.models.user import User from app.models.user import User
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@pytest.mark.asyncio @pytest.mark.asyncio
@ -37,7 +37,7 @@ async def test_register_user_creates_organization_membership(
assert user is not None assert user is not None
organization = await session.scalar( organization = await session.scalar(
select(Organization).where(Organization.name == payload["organization_name"]) select(Organization).where(Organization.name == payload["organization_name"]),
) )
assert organization is not None assert organization is not None
@ -45,7 +45,7 @@ async def test_register_user_creates_organization_membership(
select(OrganizationMember).where( select(OrganizationMember).where(
OrganizationMember.organization_id == organization.id, OrganizationMember.organization_id == organization.id,
OrganizationMember.user_id == user.id, OrganizationMember.user_id == user.id,
) ),
) )
assert membership is not None assert membership is not None
assert membership.role == OrganizationRole.OWNER assert membership.role == OrganizationRole.OWNER
@ -71,7 +71,7 @@ async def test_register_user_without_organization_succeeds(
assert user is not None assert user is not None
membership = await session.scalar( membership = await session.scalar(
select(OrganizationMember).where(OrganizationMember.user_id == user.id) select(OrganizationMember).where(OrganizationMember.user_id == user.id),
) )
assert membership is None assert membership is None

View File

@ -1,21 +1,21 @@
"""API tests for contact endpoints.""" """API tests for contact endpoints."""
from __future__ import annotations from __future__ import annotations
import pytest import pytest
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.models.contact import Contact from app.models.contact import Contact
from app.models.organization_member import OrganizationMember, OrganizationRole from app.models.organization_member import OrganizationMember, OrganizationRole
from app.models.user import User from app.models.user import User
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from tests.api.v1.task_activity_shared import auth_headers, make_token, prepare_scenario from tests.api.v1.task_activity_shared import auth_headers, make_token, prepare_scenario
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_contacts_supports_search_and_pagination( async def test_list_contacts_supports_search_and_pagination(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_scenario(session_factory) scenario = await prepare_scenario(session_factory)
token = make_token(scenario.user_id, scenario.user_email) token = make_token(scenario.user_id, scenario.user_email)
@ -37,7 +37,7 @@ async def test_list_contacts_supports_search_and_pagination(
email="beta@example.com", email="beta@example.com",
phone=None, phone=None,
), ),
] ],
) )
await session.commit() await session.commit()
@ -54,7 +54,8 @@ async def test_list_contacts_supports_search_and_pagination(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_contact_returns_created_payload( async def test_create_contact_returns_created_payload(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_scenario(session_factory) scenario = await prepare_scenario(session_factory)
token = make_token(scenario.user_id, scenario.user_email) token = make_token(scenario.user_id, scenario.user_email)
@ -78,7 +79,8 @@ async def test_create_contact_returns_created_payload(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_member_cannot_assign_foreign_owner( async def test_member_cannot_assign_foreign_owner(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_scenario(session_factory) scenario = await prepare_scenario(session_factory)
token = make_token(scenario.user_id, scenario.user_email) token = make_token(scenario.user_id, scenario.user_email)
@ -88,7 +90,7 @@ async def test_member_cannot_assign_foreign_owner(
select(OrganizationMember).where( select(OrganizationMember).where(
OrganizationMember.organization_id == scenario.organization_id, OrganizationMember.organization_id == scenario.organization_id,
OrganizationMember.user_id == scenario.user_id, OrganizationMember.user_id == scenario.user_id,
) ),
) )
assert membership is not None assert membership is not None
membership.role = OrganizationRole.MEMBER membership.role = OrganizationRole.MEMBER
@ -107,7 +109,7 @@ async def test_member_cannot_assign_foreign_owner(
organization_id=scenario.organization_id, organization_id=scenario.organization_id,
user_id=other_user.id, user_id=other_user.id,
role=OrganizationRole.ADMIN, role=OrganizationRole.ADMIN,
) ),
) )
await session.commit() await session.commit()
@ -126,7 +128,8 @@ async def test_member_cannot_assign_foreign_owner(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_member_can_view_foreign_contacts( async def test_member_can_view_foreign_contacts(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_scenario(session_factory) scenario = await prepare_scenario(session_factory)
token = make_token(scenario.user_id, scenario.user_email) token = make_token(scenario.user_id, scenario.user_email)
@ -136,7 +139,7 @@ async def test_member_can_view_foreign_contacts(
select(OrganizationMember).where( select(OrganizationMember).where(
OrganizationMember.organization_id == scenario.organization_id, OrganizationMember.organization_id == scenario.organization_id,
OrganizationMember.user_id == scenario.user_id, OrganizationMember.user_id == scenario.user_id,
) ),
) )
assert membership is not None assert membership is not None
membership.role = OrganizationRole.MEMBER membership.role = OrganizationRole.MEMBER
@ -155,7 +158,7 @@ async def test_member_can_view_foreign_contacts(
organization_id=scenario.organization_id, organization_id=scenario.organization_id,
user_id=other_user.id, user_id=other_user.id,
role=OrganizationRole.MANAGER, role=OrganizationRole.MANAGER,
) ),
) )
session.add( session.add(
@ -165,7 +168,7 @@ async def test_member_can_view_foreign_contacts(
name="Foreign Owner", name="Foreign Owner",
email="foreign@example.com", email="foreign@example.com",
phone=None, phone=None,
) ),
) )
await session.commit() await session.commit()
@ -181,7 +184,8 @@ async def test_member_can_view_foreign_contacts(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_member_patch_foreign_contact_forbidden( async def test_member_patch_foreign_contact_forbidden(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_scenario(session_factory) scenario = await prepare_scenario(session_factory)
token = make_token(scenario.user_id, scenario.user_email) token = make_token(scenario.user_id, scenario.user_email)
@ -191,7 +195,7 @@ async def test_member_patch_foreign_contact_forbidden(
select(OrganizationMember).where( select(OrganizationMember).where(
OrganizationMember.organization_id == scenario.organization_id, OrganizationMember.organization_id == scenario.organization_id,
OrganizationMember.user_id == scenario.user_id, OrganizationMember.user_id == scenario.user_id,
) ),
) )
assert membership is not None assert membership is not None
membership.role = OrganizationRole.MEMBER membership.role = OrganizationRole.MEMBER
@ -210,7 +214,7 @@ async def test_member_patch_foreign_contact_forbidden(
organization_id=scenario.organization_id, organization_id=scenario.organization_id,
user_id=other_user.id, user_id=other_user.id,
role=OrganizationRole.MANAGER, role=OrganizationRole.MANAGER,
) ),
) )
contact = Contact( contact = Contact(
@ -235,7 +239,8 @@ async def test_member_patch_foreign_contact_forbidden(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_patch_contact_updates_fields( async def test_patch_contact_updates_fields(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_scenario(session_factory) scenario = await prepare_scenario(session_factory)
token = make_token(scenario.user_id, scenario.user_email) token = make_token(scenario.user_id, scenario.user_email)
@ -266,7 +271,8 @@ async def test_patch_contact_updates_fields(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_contact_with_deals_returns_conflict( async def test_delete_contact_with_deals_returns_conflict(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_scenario(session_factory) scenario = await prepare_scenario(session_factory)
token = make_token(scenario.user_id, scenario.user_email) token = make_token(scenario.user_id, scenario.user_email)

View File

@ -1,16 +1,15 @@
"""API tests for deal endpoints.""" """API tests for deal endpoints."""
from __future__ import annotations from __future__ import annotations
from decimal import Decimal from decimal import Decimal
import pytest import pytest
from app.models.activity import Activity, ActivityType
from app.models.deal import Deal, DealStage, DealStatus
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.models.activity import Activity, ActivityType
from app.models.deal import Deal, DealStage, DealStatus
from tests.api.v1.task_activity_shared import auth_headers, make_token, prepare_scenario from tests.api.v1.task_activity_shared import auth_headers, make_token, prepare_scenario
@ -105,7 +104,7 @@ async def test_update_deal_endpoint_updates_stage_and_logs_activity(
async with session_factory() as session: async with session_factory() as session:
activity_types = await session.scalars( activity_types = await session.scalars(
select(Activity.type).where(Activity.deal_id == scenario.deal_id) select(Activity.type).where(Activity.deal_id == scenario.deal_id),
) )
collected = set(activity_types.all()) collected = set(activity_types.all())

View File

@ -1,16 +1,13 @@
"""API tests for organization endpoints.""" """API tests for organization endpoints."""
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator, Sequence
from datetime import timedelta from datetime import timedelta
from typing import AsyncGenerator, Sequence, cast from typing import cast
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.schema import Table
from app.api.deps import get_db_session from app.api.deps import get_db_session
from app.core.security import jwt_service from app.core.security import jwt_service
from app.main import create_app from app.main import create_app
@ -18,6 +15,10 @@ from app.models import Base
from app.models.organization import Organization from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole from app.models.organization_member import OrganizationMember, OrganizationRole
from app.models.user import User from app.models.user import User
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.schema import Table
@pytest_asyncio.fixture() @pytest_asyncio.fixture()
@ -55,10 +56,13 @@ async def client(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_user_organizations_returns_memberships( async def test_list_user_organizations_returns_memberships(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
async with session_factory() as session: async with session_factory() as session:
user = User(email="owner@example.com", hashed_password="hashed", name="Owner", is_active=True) user = User(
email="owner@example.com", hashed_password="hashed", name="Owner", is_active=True
)
session.add(user) session.add(user)
await session.flush() await session.flush()
@ -110,8 +114,12 @@ async def test_owner_can_add_member_to_organization(
client: AsyncClient, client: AsyncClient,
) -> None: ) -> None:
async with session_factory() as session: async with session_factory() as session:
owner = User(email="owner-add@example.com", hashed_password="hashed", name="Owner", is_active=True) owner = User(
invitee = User(email="new-member@example.com", hashed_password="hashed", name="Member", is_active=True) email="owner-add@example.com", hashed_password="hashed", name="Owner", is_active=True
)
invitee = User(
email="new-member@example.com", hashed_password="hashed", name="Member", is_active=True
)
session.add_all([owner, invitee]) session.add_all([owner, invitee])
await session.flush() await session.flush()
@ -153,7 +161,7 @@ async def test_owner_can_add_member_to_organization(
select(OrganizationMember).where( select(OrganizationMember).where(
OrganizationMember.organization_id == organization.id, OrganizationMember.organization_id == organization.id,
OrganizationMember.user_id == invitee.id, OrganizationMember.user_id == invitee.id,
) ),
) )
assert new_membership is not None assert new_membership is not None
assert new_membership.role == OrganizationRole.MANAGER assert new_membership.role == OrganizationRole.MANAGER
@ -165,7 +173,12 @@ async def test_add_member_requires_existing_user(
client: AsyncClient, client: AsyncClient,
) -> None: ) -> None:
async with session_factory() as session: async with session_factory() as session:
owner = User(email="owner-missing@example.com", hashed_password="hashed", name="Owner", is_active=True) owner = User(
email="owner-missing@example.com",
hashed_password="hashed",
name="Owner",
is_active=True,
)
session.add(owner) session.add(owner)
await session.flush() await session.flush()
@ -206,8 +219,12 @@ async def test_member_role_cannot_add_users(
client: AsyncClient, client: AsyncClient,
) -> None: ) -> None:
async with session_factory() as session: async with session_factory() as session:
member_user = User(email="member@example.com", hashed_password="hashed", name="Member", is_active=True) member_user = User(
invitee = User(email="invitee@example.com", hashed_password="hashed", name="Invitee", is_active=True) email="member@example.com", hashed_password="hashed", name="Member", is_active=True
)
invitee = User(
email="invitee@example.com", hashed_password="hashed", name="Invitee", is_active=True
)
session.add_all([member_user, invitee]) session.add_all([member_user, invitee])
await session.flush() await session.flush()
@ -248,8 +265,12 @@ async def test_cannot_add_duplicate_member(
client: AsyncClient, client: AsyncClient,
) -> None: ) -> None:
async with session_factory() as session: async with session_factory() as session:
owner = User(email="dup-owner@example.com", hashed_password="hashed", name="Owner", is_active=True) owner = User(
invitee = User(email="dup-member@example.com", hashed_password="hashed", name="Invitee", is_active=True) email="dup-owner@example.com", hashed_password="hashed", name="Owner", is_active=True
)
invitee = User(
email="dup-member@example.com", hashed_password="hashed", name="Invitee", is_active=True
)
session.add_all([owner, invitee]) session.add_all([owner, invitee])
await session.flush() await session.flush()

View File

@ -1,20 +1,25 @@
"""API tests for task endpoints.""" """API tests for task endpoints."""
from __future__ import annotations from __future__ import annotations
from datetime import date, datetime, timedelta, timezone from datetime import date, datetime, timedelta, timezone
import pytest import pytest
from app.models.task import Task
from httpx import AsyncClient from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from tests.api.v1.task_activity_shared import (
from app.models.task import Task auth_headers,
create_deal,
from tests.api.v1.task_activity_shared import auth_headers, create_deal, make_token, prepare_scenario make_token,
prepare_scenario,
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_task_endpoint_creates_task_and_activity( async def test_create_task_endpoint_creates_task_and_activity(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_scenario(session_factory) scenario = await prepare_scenario(session_factory)
token = make_token(scenario.user_id, scenario.user_email) token = make_token(scenario.user_id, scenario.user_email)
@ -40,7 +45,8 @@ async def test_create_task_endpoint_creates_task_and_activity(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_tasks_endpoint_filters_by_deal( async def test_list_tasks_endpoint_filters_by_deal(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None: ) -> None:
scenario = await prepare_scenario(session_factory) scenario = await prepare_scenario(session_factory)
token = make_token(scenario.user_id, scenario.user_email) token = make_token(scenario.user_id, scenario.user_email)
@ -63,7 +69,7 @@ async def test_list_tasks_endpoint_filters_by_deal(
due_date=datetime.now(timezone.utc) + timedelta(days=3), due_date=datetime.now(timezone.utc) + timedelta(days=3),
is_done=False, is_done=False,
), ),
] ],
) )
await session.commit() await session.commit()

View File

@ -1,4 +1,5 @@
"""Pytest configuration & shared fixtures.""" """Pytest configuration & shared fixtures."""
from __future__ import annotations from __future__ import annotations
import sys import sys

View File

@ -1,4 +1,5 @@
"""Regression tests ensuring Enum mappings store lowercase values.""" """Regression tests ensuring Enum mappings store lowercase values."""
from __future__ import annotations from __future__ import annotations
from enum import StrEnum from enum import StrEnum

View File

@ -1,14 +1,12 @@
"""Unit tests for ActivityService.""" """Unit tests for ActivityService."""
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator
import uuid import uuid
from collections.abc import AsyncGenerator
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.models.activity import Activity, ActivityType from app.models.activity import Activity, ActivityType
from app.models.base import Base from app.models.base import Base
from app.models.contact import Contact from app.models.contact import Contact
@ -24,6 +22,8 @@ from app.services.activity_service import (
ActivityValidationError, ActivityValidationError,
) )
from app.services.organization_service import OrganizationContext from app.services.organization_service import OrganizationContext
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
@pytest_asyncio.fixture() @pytest_asyncio.fixture()
@ -91,9 +91,16 @@ async def test_list_activities_returns_only_current_deal(session: AsyncSession)
session.add_all( session.add_all(
[ [
Activity(deal_id=deal_id, author_id=context.user_id, type=ActivityType.COMMENT, payload={"text": "hi"}), Activity(
Activity(deal_id=deal_id + 1, author_id=context.user_id, type=ActivityType.SYSTEM, payload={}), deal_id=deal_id,
] author_id=context.user_id,
type=ActivityType.COMMENT,
payload={"text": "hi"},
),
Activity(
deal_id=deal_id + 1, author_id=context.user_id, type=ActivityType.SYSTEM, payload={}
),
],
) )
await session.flush() await session.flush()
@ -112,7 +119,9 @@ async def test_add_comment_rejects_empty_text(session: AsyncSession) -> None:
service = ActivityService(repository=repo) service = ActivityService(repository=repo)
with pytest.raises(ActivityValidationError): with pytest.raises(ActivityValidationError):
await service.add_comment(deal_id=deal_id, author_id=context.user_id, text=" ", context=context) await service.add_comment(
deal_id=deal_id, author_id=context.user_id, text=" ", context=context
)
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -1,4 +1,5 @@
"""Unit tests for AnalyticsService.""" """Unit tests for AnalyticsService."""
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
@ -7,9 +8,6 @@ from decimal import Decimal
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.models import Base from app.models import Base
from app.models.contact import Contact from app.models.contact import Contact
from app.models.deal import Deal, DealStage, DealStatus from app.models.deal import Deal, DealStage, DealStatus
@ -18,13 +16,17 @@ from app.models.organization_member import OrganizationMember, OrganizationRole
from app.models.user import User from app.models.user import User
from app.repositories.analytics_repo import AnalyticsRepository from app.repositories.analytics_repo import AnalyticsRepository
from app.services.analytics_service import AnalyticsService, invalidate_analytics_cache from app.services.analytics_service import AnalyticsService, invalidate_analytics_cache
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from tests.utils.fake_redis import InMemoryRedis from tests.utils.fake_redis import InMemoryRedis
@pytest_asyncio.fixture() @pytest_asyncio.fixture()
async def session() -> AsyncGenerator[AsyncSession, None]: async def session() -> AsyncGenerator[AsyncSession, None]:
engine = create_async_engine( engine = create_async_engine(
"sqlite+aiosqlite:///:memory:", future=True, poolclass=StaticPool "sqlite+aiosqlite:///:memory:",
future=True,
poolclass=StaticPool,
) )
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
@ -36,12 +38,18 @@ async def session() -> AsyncGenerator[AsyncSession, None]:
async def _seed_data(session: AsyncSession) -> tuple[int, int, int]: async def _seed_data(session: AsyncSession) -> tuple[int, int, int]:
org = Organization(name="Analytics Org") org = Organization(name="Analytics Org")
user = User(email="analytics@example.com", hashed_password="hashed", name="Analyst", is_active=True) user = User(
email="analytics@example.com", hashed_password="hashed", name="Analyst", is_active=True
)
session.add_all([org, user]) session.add_all([org, user])
await session.flush() await session.flush()
member = OrganizationMember(organization_id=org.id, user_id=user.id, role=OrganizationRole.OWNER) member = OrganizationMember(
contact = Contact(organization_id=org.id, owner_id=user.id, name="Client", email="client@example.com") organization_id=org.id, user_id=user.id, role=OrganizationRole.OWNER
)
contact = Contact(
organization_id=org.id, owner_id=user.id, name="Client", email="client@example.com"
)
session.add_all([member, contact]) session.add_all([member, contact])
await session.flush() await session.flush()
@ -231,4 +239,4 @@ async def test_funnel_reads_from_cache_when_available(session: AsyncSession) ->
service._repository = _ExplodingRepository(session) service._repository = _ExplodingRepository(session)
cached = await service.get_deal_funnel(org_id) cached = await service.get_deal_funnel(org_id)
assert len(cached) == 4 assert len(cached) == 4

View File

@ -1,16 +1,16 @@
"""Unit tests for AuthService.""" """Unit tests for AuthService."""
from __future__ import annotations from __future__ import annotations
from typing import cast from typing import cast
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest # type: ignore[import-not-found] import pytest # type: ignore[import-not-found]
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import JWTService, PasswordHasher from app.core.security import JWTService, PasswordHasher
from app.models.user import User from app.models.user import User
from app.repositories.user_repo import UserRepository from app.repositories.user_repo import UserRepository
from app.services.auth_service import AuthService, InvalidCredentialsError, InvalidRefreshTokenError from app.services.auth_service import AuthService, InvalidCredentialsError, InvalidRefreshTokenError
from sqlalchemy.ext.asyncio import AsyncSession
class StubUserRepository(UserRepository): class StubUserRepository(UserRepository):
@ -49,7 +49,9 @@ def jwt_service() -> JWTService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_success(password_hasher: PasswordHasher, jwt_service: JWTService) -> None: async def test_authenticate_success(
password_hasher: PasswordHasher, jwt_service: JWTService
) -> None:
hashed = password_hasher.hash("StrongPass123") hashed = password_hasher.hash("StrongPass123")
user = User(email="user@example.com", hashed_password=hashed, name="Alice", is_active=True) user = User(email="user@example.com", hashed_password=hashed, name="Alice", is_active=True)
user.id = 1 user.id = 1
@ -100,7 +102,9 @@ async def test_refresh_tokens_returns_new_pair(
password_hasher: PasswordHasher, password_hasher: PasswordHasher,
jwt_service: JWTService, jwt_service: JWTService,
) -> None: ) -> None:
user = User(email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True) user = User(
email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True
)
user.id = 7 user.id = 7
service = AuthService(StubUserRepository(user), password_hasher, jwt_service) service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
@ -116,7 +120,9 @@ async def test_refresh_tokens_rejects_access_token(
password_hasher: PasswordHasher, password_hasher: PasswordHasher,
jwt_service: JWTService, jwt_service: JWTService,
) -> None: ) -> None:
user = User(email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True) user = User(
email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True
)
user.id = 9 user.id = 9
service = AuthService(StubUserRepository(user), password_hasher, jwt_service) service = AuthService(StubUserRepository(user), password_hasher, jwt_service)

View File

@ -1,15 +1,12 @@
"""Unit tests for ContactService.""" """Unit tests for ContactService."""
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator
import uuid import uuid
from collections.abc import AsyncGenerator
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.models.base import Base from app.models.base import Base
from app.models.contact import Contact, ContactCreate from app.models.contact import Contact, ContactCreate
from app.models.deal import Deal from app.models.deal import Deal
@ -25,6 +22,9 @@ from app.services.contact_service import (
ContactUpdateData, ContactUpdateData,
) )
from app.services.organization_service import OrganizationContext from app.services.organization_service import OrganizationContext
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
@pytest_asyncio.fixture() @pytest_asyncio.fixture()
@ -244,7 +244,7 @@ async def test_delete_contact_blocks_when_deals_exist(session: AsyncSession) ->
owner_id=contact.owner_id, owner_id=contact.owner_id,
title="Pending", title="Pending",
amount=None, amount=None,
) ),
) )
await session.flush() await session.flush()

View File

@ -1,16 +1,13 @@
"""Unit tests for DealService.""" """Unit tests for DealService."""
from __future__ import annotations from __future__ import annotations
import uuid
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from decimal import Decimal from decimal import Decimal
import uuid
import pytest # type: ignore[import-not-found] import pytest # type: ignore[import-not-found]
import pytest_asyncio # type: ignore[import-not-found] import pytest_asyncio # type: ignore[import-not-found]
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.models.activity import Activity, ActivityType from app.models.activity import Activity, ActivityType
from app.models.base import Base from app.models.base import Base
from app.models.contact import Contact from app.models.contact import Contact
@ -28,6 +25,9 @@ from app.services.deal_service import (
DealUpdateData, DealUpdateData,
) )
from app.services.organization_service import OrganizationContext from app.services.organization_service import OrganizationContext
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
@pytest_asyncio.fixture() @pytest_asyncio.fixture()
@ -64,7 +64,9 @@ def _make_context(org: Organization, user: User, role: OrganizationRole) -> Orga
return OrganizationContext(organization=org, membership=membership) return OrganizationContext(organization=org, membership=membership)
async def _persist_base(session: AsyncSession, *, role: OrganizationRole = OrganizationRole.MANAGER) -> tuple[ async def _persist_base(
session: AsyncSession, *, role: OrganizationRole = OrganizationRole.MANAGER
) -> tuple[
OrganizationContext, OrganizationContext,
Contact, Contact,
DealRepository, DealRepository,

View File

@ -1,12 +1,11 @@
"""Unit tests for OrganizationService.""" """Unit tests for OrganizationService."""
from __future__ import annotations from __future__ import annotations
from typing import cast from typing import cast
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest # type: ignore[import-not-found] import pytest # type: ignore[import-not-found]
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.organization import Organization from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole from app.models.organization_member import OrganizationMember, OrganizationRole
from app.repositories.org_repo import OrganizationRepository from app.repositories.org_repo import OrganizationRepository
@ -18,6 +17,7 @@ from app.services.organization_service import (
OrganizationMemberAlreadyExistsError, OrganizationMemberAlreadyExistsError,
OrganizationService, OrganizationService,
) )
from sqlalchemy.ext.asyncio import AsyncSession
class StubOrganizationRepository(OrganizationRepository): class StubOrganizationRepository(OrganizationRepository):
@ -27,7 +27,9 @@ class StubOrganizationRepository(OrganizationRepository):
super().__init__(session=MagicMock(spec=AsyncSession)) super().__init__(session=MagicMock(spec=AsyncSession))
self._membership = membership self._membership = membership
async def get_membership(self, organization_id: int, user_id: int) -> OrganizationMember | None: # pragma: no cover - helper async def get_membership(
self, organization_id: int, user_id: int
) -> OrganizationMember | None: # pragma: no cover - helper
if ( if (
self._membership self._membership
and self._membership.organization_id == organization_id and self._membership.organization_id == organization_id
@ -37,7 +39,9 @@ class StubOrganizationRepository(OrganizationRepository):
return None return None
def make_membership(role: OrganizationRole, *, organization_id: int = 1, user_id: int = 10) -> OrganizationMember: def make_membership(
role: OrganizationRole, *, organization_id: int = 1, user_id: int = 10
) -> OrganizationMember:
organization = Organization(name="Acme Inc") organization = Organization(name="Acme Inc")
organization.id = organization_id organization.id = organization_id
membership = OrganizationMember( membership = OrganizationMember(
@ -70,7 +74,9 @@ class SessionStub:
class MembershipRepositoryStub(OrganizationRepository): class MembershipRepositoryStub(OrganizationRepository):
"""Repository stub that can emulate duplicate checks for add_member.""" """Repository stub that can emulate duplicate checks for add_member."""
def __init__(self, memberships: dict[tuple[int, int], OrganizationMember] | None = None) -> None: def __init__(
self, memberships: dict[tuple[int, int], OrganizationMember] | None = None
) -> None:
self._session_stub = SessionStub() self._session_stub = SessionStub()
super().__init__(session=cast(AsyncSession, self._session_stub)) super().__init__(session=cast(AsyncSession, self._session_stub))
self._memberships = memberships or {} self._memberships = memberships or {}
@ -88,7 +94,9 @@ async def test_get_context_success() -> None:
membership = make_membership(OrganizationRole.MANAGER) membership = make_membership(OrganizationRole.MANAGER)
service = OrganizationService(StubOrganizationRepository(membership)) service = OrganizationService(StubOrganizationRepository(membership))
context = await service.get_context(user_id=membership.user_id, organization_id=membership.organization_id) context = await service.get_context(
user_id=membership.user_id, organization_id=membership.organization_id
)
assert context.organization_id == membership.organization_id assert context.organization_id == membership.organization_id
assert context.role == OrganizationRole.MANAGER assert context.role == OrganizationRole.MANAGER
@ -174,7 +182,9 @@ async def test_add_member_rejects_duplicate_membership() -> None:
service = OrganizationService(repo) service = OrganizationService(repo)
with pytest.raises(OrganizationMemberAlreadyExistsError): with pytest.raises(OrganizationMemberAlreadyExistsError):
await service.add_member(context=context, user_id=duplicate_user_id, role=OrganizationRole.MANAGER) await service.add_member(
context=context, user_id=duplicate_user_id, role=OrganizationRole.MANAGER
)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -191,4 +201,4 @@ async def test_add_member_requires_privileged_role() -> None:
await service.add_member(context=context, user_id=99, role=OrganizationRole.MANAGER) await service.add_member(context=context, user_id=99, role=OrganizationRole.MANAGER)
# Ensure DB work not attempted when permissions fail. # Ensure DB work not attempted when permissions fail.
assert repo.session_stub.committed is False assert repo.session_stub.committed is False

View File

@ -1,16 +1,13 @@
"""Unit tests for TaskService.""" """Unit tests for TaskService."""
from __future__ import annotations from __future__ import annotations
import uuid
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
import uuid
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.models.activity import Activity, ActivityType from app.models.activity import Activity, ActivityType
from app.models.base import Base from app.models.base import Base
from app.models.contact import Contact from app.models.contact import Contact
@ -28,6 +25,9 @@ from app.services.task_service import (
TaskService, TaskService,
TaskUpdateData, TaskUpdateData,
) )
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
@pytest_asyncio.fixture() @pytest_asyncio.fixture()
@ -189,7 +189,9 @@ async def test_member_cannot_update_foreign_task(session: AsyncSession) -> None:
user_id=member.id, user_id=member.id,
role=OrganizationRole.MEMBER, role=OrganizationRole.MEMBER,
) )
member_context = OrganizationContext(organization=context_owner.organization, membership=membership) member_context = OrganizationContext(
organization=context_owner.organization, membership=membership
)
with pytest.raises(TaskForbiddenError): with pytest.raises(TaskForbiddenError):
await service.update_task( await service.update_task(

View File

@ -1,4 +1,5 @@
"""Simple in-memory Redis replacement for tests.""" """Simple in-memory Redis replacement for tests."""
from __future__ import annotations from __future__ import annotations
import fnmatch import fnmatch