# test_task_crm/app/api/deps.py
# (source-viewer metadata at capture time: 130 lines, 4.5 KiB, Python)

"""Reusable FastAPI dependencies."""
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

import jwt
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.core.database import get_session
from app.core.security import jwt_service, password_hasher
from app.models.user import User
from app.repositories.activity_repo import ActivityRepository
from app.repositories.deal_repo import DealRepository
from app.repositories.org_repo import OrganizationRepository
from app.repositories.task_repo import TaskRepository
from app.repositories.user_repo import UserRepository
from app.services.activity_service import ActivityService
from app.services.auth_service import AuthService
from app.services.deal_service import DealService
from app.services.organization_service import (
    OrganizationAccessDeniedError,
    OrganizationContext,
    OrganizationContextMissingError,
    OrganizationService,
)
from app.services.task_service import TaskService
from app.services.user_service import UserService
# Bearer-token extractor used by ``get_current_user``. ``tokenUrl`` names the
# password-flow token endpoint beneath the versioned API prefix.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl=f"{settings.api_v1_prefix}/auth/token",
)
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Provide a scoped database session.

    Delegates to ``app.core.database.get_session`` and re-yields the session
    it produces.

    ``get_session()`` is driven through ``contextlib.asynccontextmanager``
    rather than an ``async for`` loop. With ``async for``, an exception thrown
    into this generator at ``yield`` (FastAPI does this when the endpoint
    raises) never reaches ``get_session``. The inner generator is simply left
    suspended and is only closed later by the event loop's async-generator
    finalizer. Wrapping it as a context manager forwards the exception to
    ``get_session`` via ``athrow`` and closes it deterministically on every
    exit path.

    NOTE(review): assumes ``get_session`` is an async generator function that
    yields exactly one session and re-raises exceptions thrown into it.
    Confirm this in ``app.core.database``.

    Yields:
        The session produced by ``get_session``.
    """
    async with asynccontextmanager(get_session)() as session:
        yield session
def get_user_repository(
    session: AsyncSession = Depends(get_db_session),
) -> UserRepository:
    """Build a ``UserRepository`` bound to the session from ``get_db_session``."""
    repository = UserRepository(session=session)
    return repository
def get_organization_repository(
    session: AsyncSession = Depends(get_db_session),
) -> OrganizationRepository:
    """Build an ``OrganizationRepository`` bound to the session from ``get_db_session``."""
    repository = OrganizationRepository(session=session)
    return repository
def get_deal_repository(
    session: AsyncSession = Depends(get_db_session),
) -> DealRepository:
    """Build a ``DealRepository`` bound to the session from ``get_db_session``."""
    repository = DealRepository(session=session)
    return repository
def get_task_repository(
    session: AsyncSession = Depends(get_db_session),
) -> TaskRepository:
    """Build a ``TaskRepository`` bound to the session from ``get_db_session``."""
    repository = TaskRepository(session=session)
    return repository
def get_activity_repository(
    session: AsyncSession = Depends(get_db_session),
) -> ActivityRepository:
    """Build an ``ActivityRepository`` bound to the session from ``get_db_session``."""
    repository = ActivityRepository(session=session)
    return repository
def get_deal_service(
    repo: DealRepository = Depends(get_deal_repository),
) -> DealService:
    """Build a ``DealService`` on top of the injected ``DealRepository``."""
    service = DealService(repository=repo)
    return service
def get_user_service(
    repo: UserRepository = Depends(get_user_repository),
) -> UserService:
    """Build a ``UserService`` from the injected repository.

    The module-level ``password_hasher`` imported from ``app.core.security``
    is passed in alongside the repository.
    """
    service = UserService(
        user_repository=repo,
        password_hasher=password_hasher,
    )
    return service
def get_auth_service(
    repo: UserRepository = Depends(get_user_repository),
) -> AuthService:
    """Build an ``AuthService`` for login/token operations.

    Combines the injected ``UserRepository`` with the module-level
    ``password_hasher`` and ``jwt_service`` from ``app.core.security``.
    """
    auth_service = AuthService(
        user_repository=repo,
        password_hasher=password_hasher,
        jwt_service=jwt_service,
    )
    return auth_service
def get_organization_service(
    repo: OrganizationRepository = Depends(get_organization_repository),
) -> OrganizationService:
    """Build an ``OrganizationService`` on top of the injected repository."""
    service = OrganizationService(repository=repo)
    return service
def get_activity_service(
    repo: ActivityRepository = Depends(get_activity_repository),
) -> ActivityService:
    """Build an ``ActivityService`` on top of the injected repository."""
    service = ActivityService(repository=repo)
    return service
def get_task_service(
    task_repo: TaskRepository = Depends(get_task_repository),
    activity_repo: ActivityRepository = Depends(get_activity_repository),
) -> TaskService:
    """Build a ``TaskService`` wired to both task and activity repositories.

    Both repositories are produced by providers that depend on
    ``get_db_session``.
    """
    service = TaskService(
        task_repository=task_repo,
        activity_repository=activity_repo,
    )
    return service
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    repo: UserRepository = Depends(get_user_repository),
) -> User:
    """Resolve the authenticated ``User`` from a bearer access token.

    The token is decoded with ``jwt_service``. Its ``sub`` claim must be
    present and convertible with ``int()``; the resulting value is looked up
    as a user id in ``repo``.

    Raises:
        HTTPException: 401 "Could not validate credentials" when the token
            fails to decode, has no usable ``sub`` claim, or names a user that
            does not exist. The response carries ``WWW-Authenticate: Bearer``.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        # RFC 6750 section 3: a 401 for bearer-token auth should name the
        # scheme. OAuth2PasswordBearer already sends this header when the
        # Authorization header is missing, so our 401s now match it.
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt_service.decode(token)
        sub = payload.get("sub")
        # int(None) would raise TypeError anyway; the explicit check just
        # makes the "claim missing" case obvious.
        if sub is None:
            raise credentials_exception
        user_id = int(sub)
    except (jwt.PyJWTError, TypeError, ValueError):
        # Hide decoding/parsing details from the client: every failure
        # collapses to the same generic 401.
        raise credentials_exception from None

    user = await repo.get_by_id(user_id)
    if user is None:
        raise credentials_exception
    return user
async def get_organization_context(
    x_organization_id: int | None = Header(default=None, alias="X-Organization-Id"),
    current_user: User = Depends(get_current_user),
    service: OrganizationService = Depends(get_organization_service),
) -> OrganizationContext:
    """Resolve the organization context for the authenticated user.

    The target organization comes from the optional ``X-Organization-Id``
    header (``None`` when absent). It is passed, together with the current
    user's id, to ``OrganizationService.get_context``.

    Raises:
        HTTPException: 400 when the service raises
            ``OrganizationContextMissingError``, or 404 when it raises
            ``OrganizationAccessDeniedError``. In both cases the service's
            message becomes ``detail``.
    """
    try:
        context = await service.get_context(
            user_id=current_user.id,
            organization_id=x_organization_id,
        )
    except OrganizationContextMissingError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc
    except OrganizationAccessDeniedError as exc:
        # NOTE(review): reported as 404 rather than 403, presumably so callers
        # cannot probe which organizations exist. Confirm this is intended.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc),
        ) from exc
    return context