130 lines
4.5 KiB
Python
130 lines
4.5 KiB
Python
"""Reusable FastAPI dependencies."""
|
|
from collections.abc import AsyncGenerator
|
|
|
|
import jwt
|
|
from fastapi import Depends, Header, HTTPException, status
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.core.database import get_session
|
|
from app.core.security import jwt_service, password_hasher
|
|
from app.models.user import User
|
|
from app.repositories.activity_repo import ActivityRepository
|
|
from app.repositories.deal_repo import DealRepository
|
|
from app.repositories.org_repo import OrganizationRepository
|
|
from app.repositories.task_repo import TaskRepository
|
|
from app.repositories.user_repo import UserRepository
|
|
from app.services.auth_service import AuthService
|
|
from app.services.activity_service import ActivityService
|
|
from app.services.deal_service import DealService
|
|
from app.services.organization_service import (
|
|
OrganizationAccessDeniedError,
|
|
OrganizationContext,
|
|
OrganizationContextMissingError,
|
|
OrganizationService,
|
|
)
|
|
from app.services.task_service import TaskService
|
|
from app.services.user_service import UserService
|
|
|
|
# Bearer-token dependency consumed by `get_current_user` via Depends(oauth2_scheme).
# `tokenUrl` names the token-issuing route under the versioned API prefix
# (presumably the login endpoint -- confirm against the auth router).
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl=f"{settings.api_v1_prefix}/auth/token",
)
|
|
|
|
|
|
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Provide a request-scoped database session.

    Delegates to ``get_session`` and yields the session it produces. The
    wrapped generator is driven as an async context manager rather than
    re-yielded from an ``async for`` loop, for two reasons:

    * an exception raised while the request uses the session is thrown back
      into ``get_session`` (so any rollback/except logic it defines runs);
    * the inner generator is always finished or closed before this one
      returns, instead of being left suspended until garbage collection.

    NOTE(review): assumes ``get_session`` is an async generator function that
    yields exactly one session -- confirm in ``app.core.database``.

    Yields:
        The session produced by ``get_session``.
    """
    from contextlib import asynccontextmanager

    async with asynccontextmanager(get_session)() as session:
        yield session
|
|
|
|
|
|
def get_user_repository(session: AsyncSession = Depends(get_db_session)) -> UserRepository:
    """Build a ``UserRepository`` bound to the request's database session.

    Args:
        session: Session injected by ``get_db_session``.

    Returns:
        A new ``UserRepository`` wrapping *session*.
    """
    user_repo = UserRepository(session=session)
    return user_repo
|
|
|
|
|
|
def get_organization_repository(
    session: AsyncSession = Depends(get_db_session),
) -> OrganizationRepository:
    """Build an ``OrganizationRepository`` bound to the request's database session.

    Args:
        session: Session injected by ``get_db_session``.

    Returns:
        A new ``OrganizationRepository`` wrapping *session*.
    """
    org_repo = OrganizationRepository(session=session)
    return org_repo
|
|
|
|
|
|
def get_deal_repository(
    session: AsyncSession = Depends(get_db_session),
) -> DealRepository:
    """Build a ``DealRepository`` bound to the request's database session.

    Args:
        session: Session injected by ``get_db_session``.

    Returns:
        A new ``DealRepository`` wrapping *session*.
    """
    deal_repo = DealRepository(session=session)
    return deal_repo
|
|
|
|
|
|
def get_task_repository(
    session: AsyncSession = Depends(get_db_session),
) -> TaskRepository:
    """Build a ``TaskRepository`` bound to the request's database session.

    Args:
        session: Session injected by ``get_db_session``.

    Returns:
        A new ``TaskRepository`` wrapping *session*.
    """
    task_repo = TaskRepository(session=session)
    return task_repo
|
|
|
|
|
|
def get_activity_repository(
    session: AsyncSession = Depends(get_db_session),
) -> ActivityRepository:
    """Build an ``ActivityRepository`` bound to the request's database session.

    Args:
        session: Session injected by ``get_db_session``.

    Returns:
        A new ``ActivityRepository`` wrapping *session*.
    """
    activity_repo = ActivityRepository(session=session)
    return activity_repo
|
|
|
|
|
|
def get_deal_service(
    repo: DealRepository = Depends(get_deal_repository),
) -> DealService:
    """Assemble a ``DealService`` on top of the injected deal repository.

    Args:
        repo: Repository injected by ``get_deal_repository``.

    Returns:
        A new ``DealService`` using *repo*.
    """
    service = DealService(repository=repo)
    return service
|
|
|
|
|
|
def get_user_service(
    repo: UserRepository = Depends(get_user_repository),
) -> UserService:
    """Assemble a ``UserService`` from the user repository and the shared hasher.

    Args:
        repo: Repository injected by ``get_user_repository``.

    Returns:
        A new ``UserService`` using *repo* and the module-level
        ``password_hasher`` from ``app.core.security``.
    """
    service = UserService(password_hasher=password_hasher, user_repository=repo)
    return service
|
|
|
|
|
|
def get_auth_service(
    repo: UserRepository = Depends(get_user_repository),
) -> AuthService:
    """Assemble an ``AuthService`` from the user repository and security helpers.

    Args:
        repo: Repository injected by ``get_user_repository``.

    Returns:
        A new ``AuthService`` wired with *repo* plus the shared
        ``password_hasher`` and ``jwt_service`` from ``app.core.security``.
    """
    service = AuthService(
        jwt_service=jwt_service,
        password_hasher=password_hasher,
        user_repository=repo,
    )
    return service
|
|
|
|
|
|
def get_organization_service(
    repo: OrganizationRepository = Depends(get_organization_repository),
) -> OrganizationService:
    """Assemble an ``OrganizationService`` on top of the organization repository.

    Args:
        repo: Repository injected by ``get_organization_repository``.

    Returns:
        A new ``OrganizationService`` using *repo*.
    """
    service = OrganizationService(repository=repo)
    return service
|
|
|
|
|
|
def get_activity_service(
    repo: ActivityRepository = Depends(get_activity_repository),
) -> ActivityService:
    """Assemble an ``ActivityService`` on top of the activity repository.

    Args:
        repo: Repository injected by ``get_activity_repository``.

    Returns:
        A new ``ActivityService`` using *repo*.
    """
    service = ActivityService(repository=repo)
    return service
|
|
|
|
|
|
def get_task_service(
    task_repo: TaskRepository = Depends(get_task_repository),
    activity_repo: ActivityRepository = Depends(get_activity_repository),
) -> TaskService:
    """Assemble a ``TaskService`` from the task and activity repositories.

    Args:
        task_repo: Repository injected by ``get_task_repository``.
        activity_repo: Repository injected by ``get_activity_repository``.

    Returns:
        A new ``TaskService`` using both repositories.
    """
    service = TaskService(
        activity_repository=activity_repo,
        task_repository=task_repo,
    )
    return service
|
|
|
|
|
|
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    repo: UserRepository = Depends(get_user_repository),
) -> User:
    """Resolve the authenticated user from the bearer token.

    The token is decoded with ``jwt_service``. Its ``sub`` claim must be
    present and convertible with ``int()``, and the user with that id must
    exist in *repo*.

    Args:
        token: Bearer token supplied by ``oauth2_scheme``.
        repo: Repository used to look the user up by id.

    Returns:
        The ``User`` identified by the token's ``sub`` claim.

    Raises:
        HTTPException: 401 "Could not validate credentials", with a
            ``WWW-Authenticate: Bearer`` challenge header, when the token fails
            to decode (``jwt.PyJWTError``), has no or a non-integer ``sub``
            claim, or refers to an unknown user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        # RFC 6750 section 3: a 401 from a bearer-protected resource must carry
        # this challenge header. OAuth2PasswordBearer already sends it when the
        # token is missing; send it for invalid tokens too, for consistency.
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt_service.decode(token)
        sub = payload.get("sub")
        if sub is None:
            raise credentials_exception
        user_id = int(sub)
    except (jwt.PyJWTError, TypeError, ValueError):
        # Decode failures and non-numeric ``sub`` values all map to 401. The
        # original error is suppressed so token details are not chained in.
        raise credentials_exception from None

    user = await repo.get_by_id(user_id)
    if user is None:
        raise credentials_exception
    return user
|
|
|
|
|
|
async def get_organization_context(
    x_organization_id: int | None = Header(default=None, alias="X-Organization-Id"),
    current_user: User = Depends(get_current_user),
    service: OrganizationService = Depends(get_organization_service),
) -> OrganizationContext:
    """Resolve the organization context for the authenticated user.

    The organization id is read from the optional ``X-Organization-Id`` header
    and passed, together with the current user's id, to
    ``OrganizationService.get_context``.

    Args:
        x_organization_id: Value of the ``X-Organization-Id`` header, or
            ``None`` when the header is absent.
        current_user: User resolved by ``get_current_user``.
        service: Service injected by ``get_organization_service``.

    Returns:
        The ``OrganizationContext`` returned by the service.

    Raises:
        HTTPException: 400 when the service raises
            ``OrganizationContextMissingError``. 404 when it raises
            ``OrganizationAccessDeniedError``. NOTE(review): 404 rather than
            403 is presumably deliberate, to avoid revealing that the
            organization exists -- confirm. In both cases the service's
            message is the response detail.
    """
    try:
        context = await service.get_context(
            user_id=current_user.id,
            organization_id=x_organization_id,
        )
    except (OrganizationContextMissingError, OrganizationAccessDeniedError) as exc:
        # Check the "missing" error first so it takes precedence if one
        # exception type subclasses the other.
        if isinstance(exc, OrganizationContextMissingError):
            http_status = status.HTTP_400_BAD_REQUEST
        else:
            http_status = status.HTTP_404_NOT_FOUND
        raise HTTPException(status_code=http_status, detail=str(exc)) from exc
    return context
|