"""Reusable FastAPI dependencies.""" from collections.abc import AsyncGenerator import jwt from fastapi import Depends, Header, HTTPException, status from fastapi.security import OAuth2PasswordBearer from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings from app.core.database import get_session from app.core.security import jwt_service, password_hasher from app.models.user import User from app.repositories.activity_repo import ActivityRepository from app.repositories.analytics_repo import AnalyticsRepository from app.repositories.contact_repo import ContactRepository from app.repositories.deal_repo import DealRepository from app.repositories.org_repo import OrganizationRepository from app.repositories.task_repo import TaskRepository from app.repositories.user_repo import UserRepository from app.services.analytics_service import AnalyticsService from app.services.auth_service import AuthService from app.services.activity_service import ActivityService from app.services.contact_service import ContactService from app.services.deal_service import DealService from app.services.organization_service import ( OrganizationAccessDeniedError, OrganizationContext, OrganizationContextMissingError, OrganizationService, ) from app.services.task_service import TaskService oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/token") async def get_db_session() -> AsyncGenerator[AsyncSession, None]: """Provide a scoped database session.""" async for session in get_session(): yield session def get_user_repository(session: AsyncSession = Depends(get_db_session)) -> UserRepository: return UserRepository(session=session) def get_organization_repository(session: AsyncSession = Depends(get_db_session)) -> OrganizationRepository: return OrganizationRepository(session=session) def get_deal_repository(session: AsyncSession = Depends(get_db_session)) -> DealRepository: return DealRepository(session=session) def get_contact_repository(session: AsyncSession = Depends(get_db_session)) -> ContactRepository: return ContactRepository(session=session) def get_task_repository(session: AsyncSession = Depends(get_db_session)) -> TaskRepository: return TaskRepository(session=session) def get_activity_repository(session: AsyncSession = Depends(get_db_session)) -> ActivityRepository: return ActivityRepository(session=session) def get_analytics_repository(session: AsyncSession = Depends(get_db_session)) -> AnalyticsRepository: return AnalyticsRepository(session=session) def get_deal_service(repo: DealRepository = Depends(get_deal_repository)) -> DealService: return DealService(repository=repo) def get_auth_service( repo: UserRepository = Depends(get_user_repository), ) -> AuthService: return AuthService( user_repository=repo, password_hasher=password_hasher, jwt_service=jwt_service, ) def get_organization_service( repo: OrganizationRepository = Depends(get_organization_repository), ) -> OrganizationService: return OrganizationService(repository=repo) def get_activity_service( repo: ActivityRepository = Depends(get_activity_repository), ) -> ActivityService: return ActivityService(repository=repo) def get_analytics_service( repo: AnalyticsRepository = Depends(get_analytics_repository), ) -> AnalyticsService: return AnalyticsService(repository=repo) def get_contact_service( repo: ContactRepository = Depends(get_contact_repository), ) -> ContactService: return ContactService(repository=repo) def get_task_service( task_repo: TaskRepository = Depends(get_task_repository), activity_repo: ActivityRepository = Depends(get_activity_repository), ) -> TaskService: return TaskService(task_repository=task_repo, activity_repository=activity_repo) async def get_current_user( token: str = Depends(oauth2_scheme), repo: UserRepository = Depends(get_user_repository), ) -> User: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", ) try: payload = jwt_service.decode(token) sub = payload.get("sub") if sub is None: raise credentials_exception scope = payload.get("scope", "access") if scope != "access": raise credentials_exception user_id = int(sub) except (jwt.PyJWTError, TypeError, ValueError): raise credentials_exception from None user = await repo.get_by_id(user_id) if user is None: raise credentials_exception return user async def get_organization_context( x_organization_id: int | None = Header(default=None, alias="X-Organization-Id"), current_user: User = Depends(get_current_user), service: OrganizationService = Depends(get_organization_service), ) -> OrganizationContext: try: return await service.get_context(user_id=current_user.id, organization_id=x_organization_id) except OrganizationContextMissingError as exc: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc except OrganizationAccessDeniedError as exc: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc