"""Reusable FastAPI dependencies.""" from collections.abc import AsyncGenerator import jwt from fastapi import Depends, Header, HTTPException, status from fastapi.security import OAuth2PasswordBearer from sqlalchemy.ext.asyncio import AsyncSession from app.core.cache import get_cache_client from app.core.config import settings from app.core.database import get_session from app.core.security import jwt_service, password_hasher from app.models.user import User from app.repositories.activity_repo import ActivityRepository from app.repositories.analytics_repo import AnalyticsRepository from app.repositories.contact_repo import ContactRepository from app.repositories.deal_repo import DealRepository from app.repositories.org_repo import OrganizationRepository from app.repositories.task_repo import TaskRepository from app.repositories.user_repo import UserRepository from app.services.analytics_service import AnalyticsService from app.services.auth_service import AuthService from app.services.activity_service import ActivityService from app.services.contact_service import ContactService from app.services.deal_service import DealService from app.services.organization_service import ( OrganizationAccessDeniedError, OrganizationContext, OrganizationContextMissingError, OrganizationService, ) from app.services.task_service import TaskService from redis.asyncio.client import Redis oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/token") async def get_db_session() -> AsyncGenerator[AsyncSession, None]: """Provide a scoped database session.""" async for session in get_session(): yield session def get_user_repository(session: AsyncSession = Depends(get_db_session)) -> UserRepository: return UserRepository(session=session) def get_organization_repository(session: AsyncSession = Depends(get_db_session)) -> OrganizationRepository: return OrganizationRepository(session=session) def get_deal_repository(session: AsyncSession = Depends(get_db_session)) -> DealRepository: return DealRepository(session=session) def get_contact_repository(session: AsyncSession = Depends(get_db_session)) -> ContactRepository: return ContactRepository(session=session) def get_task_repository(session: AsyncSession = Depends(get_db_session)) -> TaskRepository: return TaskRepository(session=session) def get_activity_repository(session: AsyncSession = Depends(get_db_session)) -> ActivityRepository: return ActivityRepository(session=session) def get_analytics_repository(session: AsyncSession = Depends(get_db_session)) -> AnalyticsRepository: return AnalyticsRepository(session=session) def get_cache_backend() -> Redis | None: return get_cache_client() def get_deal_service( repo: DealRepository = Depends(get_deal_repository), cache: Redis | None = Depends(get_cache_backend), ) -> DealService: return DealService( repository=repo, cache=cache, cache_backoff_ms=settings.analytics_cache_backoff_ms, ) def get_auth_service( repo: UserRepository = Depends(get_user_repository), ) -> AuthService: return AuthService( user_repository=repo, password_hasher=password_hasher, jwt_service=jwt_service, ) def get_organization_service( repo: OrganizationRepository = Depends(get_organization_repository), ) -> OrganizationService: return OrganizationService(repository=repo) def get_activity_service( repo: ActivityRepository = Depends(get_activity_repository), ) -> ActivityService: return ActivityService(repository=repo) def get_analytics_service( repo: AnalyticsRepository = Depends(get_analytics_repository), cache: Redis | None = Depends(get_cache_backend), ) -> AnalyticsService: return AnalyticsService( repository=repo, cache=cache, ttl_seconds=settings.analytics_cache_ttl_seconds, backoff_ms=settings.analytics_cache_backoff_ms, ) def get_contact_service( repo: ContactRepository = Depends(get_contact_repository), ) -> ContactService: return ContactService(repository=repo) def get_task_service( task_repo: TaskRepository = Depends(get_task_repository), activity_repo: ActivityRepository = Depends(get_activity_repository), ) -> TaskService: return TaskService(task_repository=task_repo, activity_repository=activity_repo) async def get_current_user( token: str = Depends(oauth2_scheme), repo: UserRepository = Depends(get_user_repository), ) -> User: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", ) try: payload = jwt_service.decode(token) sub = payload.get("sub") if sub is None: raise credentials_exception scope = payload.get("scope", "access") if scope != "access": raise credentials_exception user_id = int(sub) except (jwt.PyJWTError, TypeError, ValueError): raise credentials_exception from None user = await repo.get_by_id(user_id) if user is None: raise credentials_exception return user async def get_organization_context( x_organization_id: int | None = Header(default=None, alias="X-Organization-Id"), current_user: User = Depends(get_current_user), service: OrganizationService = Depends(get_organization_service), ) -> OrganizationContext: try: return await service.get_context(user_id=current_user.id, organization_id=x_organization_id) except OrganizationContextMissingError as exc: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc except OrganizationAccessDeniedError as exc: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc