"""Reusable FastAPI dependencies."""

from collections.abc import AsyncGenerator

import jwt
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.core.database import get_session
from app.core.security import jwt_service, password_hasher
from app.models.user import User
from app.repositories.deal_repo import DealRepository
from app.repositories.org_repo import OrganizationRepository
from app.repositories.user_repo import UserRepository
from app.services.auth_service import AuthService
from app.services.deal_service import DealService
from app.services.organization_service import (
    OrganizationAccessDeniedError,
    OrganizationContext,
    OrganizationContextMissingError,
    OrganizationService,
)
from app.services.user_service import UserService

# Extracts the bearer token from the Authorization header; ``tokenUrl`` is the
# password-flow token endpoint advertised in the OpenAPI schema.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/token")


async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Provide a scoped database session.

    Re-yields each session produced by ``app.core.database.get_session`` so
    that generator's own setup/teardown wraps the request.
    """
    async for session in get_session():
        yield session


def get_user_repository(session: AsyncSession = Depends(get_db_session)) -> UserRepository:
    """Return a ``UserRepository`` bound to the request's database session."""
    return UserRepository(session=session)


def get_organization_repository(session: AsyncSession = Depends(get_db_session)) -> OrganizationRepository:
    """Return an ``OrganizationRepository`` bound to the request's database session."""
    return OrganizationRepository(session=session)


def get_deal_repository(session: AsyncSession = Depends(get_db_session)) -> DealRepository:
    """Return a ``DealRepository`` bound to the request's database session."""
    return DealRepository(session=session)


def get_deal_service(repo: DealRepository = Depends(get_deal_repository)) -> DealService:
    """Return a ``DealService`` wired to the request-scoped deal repository."""
    return DealService(repository=repo)


def get_user_service(repo: UserRepository = Depends(get_user_repository)) -> UserService:
    """Return a ``UserService`` wired to the user repository and shared password hasher."""
    return UserService(user_repository=repo, password_hasher=password_hasher)


def get_auth_service(
    repo: UserRepository = Depends(get_user_repository),
) -> AuthService:
    """Return an ``AuthService`` wired to the user repository and shared security helpers."""
    return AuthService(
        user_repository=repo,
        password_hasher=password_hasher,
        jwt_service=jwt_service,
    )


def get_organization_service(
    repo: OrganizationRepository = Depends(get_organization_repository),
) -> OrganizationService:
    """Return an ``OrganizationService`` wired to the request-scoped organization repository."""
    return OrganizationService(repository=repo)


async def get_current_user(
    token: str = Depends(oauth2_scheme),
    repo: UserRepository = Depends(get_user_repository),
) -> User:
    """Resolve the authenticated user from the bearer token.

    The token is decoded with ``jwt_service``; its ``sub`` claim is parsed as
    an integer user id and looked up through ``repo``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            the token fails to decode, has no ``sub`` claim, ``sub`` cannot be
            converted to ``int``, or no user exists with that id.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        # RFC 7235 §3.1: a 401 response must include a WWW-Authenticate
        # challenge; "Bearer" matches the OAuth2 bearer scheme used above.
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt_service.decode(token)
        sub = payload.get("sub")
        if sub is None:
            # HTTPException is not in the except tuple below, so this
            # propagates directly.
            raise credentials_exception
        user_id = int(sub)
    except (jwt.PyJWTError, TypeError, ValueError):
        # Every decode/parse failure becomes the same 401; ``from None``
        # suppresses the implicit exception context.
        raise credentials_exception from None

    user = await repo.get_by_id(user_id)
    if user is None:
        raise credentials_exception
    return user


async def get_organization_context(
    x_organization_id: int | None = Header(default=None, alias="X-Organization-Id"),
    current_user: User = Depends(get_current_user),
    service: OrganizationService = Depends(get_organization_service),
) -> OrganizationContext:
    """Build the organization context for the current user.

    The optional ``X-Organization-Id`` header is passed to
    ``OrganizationService.get_context``; how a missing header is resolved is
    decided by the service.

    Raises:
        HTTPException: 400 when the service raises
            ``OrganizationContextMissingError``; 404 when it raises
            ``OrganizationAccessDeniedError``.
    """
    try:
        return await service.get_context(user_id=current_user.id, organization_id=x_organization_id)
    except OrganizationContextMissingError as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
    except OrganizationAccessDeniedError as exc:
        # NOTE(review): 404 rather than 403, presumably so callers cannot probe
        # which organization ids exist — confirm this is intentional.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc