organizations #3
|
|
@ -1,15 +1,22 @@
|
||||||
"""Reusable FastAPI dependencies."""
|
"""Reusable FastAPI dependencies."""
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
from fastapi import Depends
|
import jwt
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
from app.core.database import get_session
|
from app.core.database import get_session
|
||||||
from app.core.security import jwt_service, password_hasher
|
from app.core.security import jwt_service, password_hasher
|
||||||
|
from app.models.user import User
|
||||||
|
from app.repositories.org_repo import OrganizationRepository
|
||||||
from app.repositories.user_repo import UserRepository
|
from app.repositories.user_repo import UserRepository
|
||||||
from app.services.auth_service import AuthService
|
from app.services.auth_service import AuthService
|
||||||
from app.services.user_service import UserService
|
from app.services.user_service import UserService
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/token")
|
||||||
|
|
||||||
|
|
||||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""Provide a scoped database session."""
|
"""Provide a scoped database session."""
|
||||||
|
|
@ -21,6 +28,10 @@ def get_user_repository(session: AsyncSession = Depends(get_db_session)) -> User
|
||||||
return UserRepository(session=session)
|
return UserRepository(session=session)
|
||||||
|
|
||||||
|
|
||||||
|
def get_organization_repository(session: AsyncSession = Depends(get_db_session)) -> OrganizationRepository:
|
||||||
|
return OrganizationRepository(session=session)
|
||||||
|
|
||||||
|
|
||||||
def get_user_service(repo: UserRepository = Depends(get_user_repository)) -> UserService:
|
def get_user_service(repo: UserRepository = Depends(get_user_repository)) -> UserService:
|
||||||
return UserService(user_repository=repo, password_hasher=password_hasher)
|
return UserService(user_repository=repo, password_hasher=password_hasher)
|
||||||
|
|
||||||
|
|
@ -33,3 +44,26 @@ def get_auth_service(
|
||||||
password_hasher=password_hasher,
|
password_hasher=password_hasher,
|
||||||
jwt_service=jwt_service,
|
jwt_service=jwt_service,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(
|
||||||
|
token: str = Depends(oauth2_scheme),
|
||||||
|
repo: UserRepository = Depends(get_user_repository),
|
||||||
|
) -> User:
|
||||||
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload = jwt_service.decode(token)
|
||||||
|
sub = payload.get("sub")
|
||||||
|
if sub is None:
|
||||||
|
raise credentials_exception
|
||||||
|
user_id = int(sub)
|
||||||
|
except (jwt.PyJWTError, TypeError, ValueError):
|
||||||
|
raise credentials_exception from None
|
||||||
|
|
||||||
|
user = await repo.get_by_id(user_id)
|
||||||
|
if user is None:
|
||||||
|
raise credentials_exception
|
||||||
|
return user
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ async def list_deals(
|
||||||
stage: str | None = None,
|
stage: str | None = None,
|
||||||
owner_id: int | None = None,
|
owner_id: int | None = None,
|
||||||
order_by: str | None = None,
|
order_by: str | None = None,
|
||||||
order: str | None = Query(default=None, regex="^(asc|desc)$"),
|
order: str | None = Query(default=None, pattern="^(asc|desc)$"),
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""Placeholder for deal filtering endpoint."""
|
"""Placeholder for deal filtering endpoint."""
|
||||||
_ = (status_filter,)
|
_ = (status_filter,)
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,22 @@
|
||||||
"""Organization-related API stubs."""
|
"""Organization-related API endpoints."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user, get_organization_repository
|
||||||
|
from app.models.organization import OrganizationRead
|
||||||
|
from app.models.user import User
|
||||||
|
from app.repositories.org_repo import OrganizationRepository
|
||||||
|
|
||||||
router = APIRouter(prefix="/organizations", tags=["organizations"])
|
router = APIRouter(prefix="/organizations", tags=["organizations"])
|
||||||
|
|
||||||
|
|
||||||
def _stub(endpoint: str) -> dict[str, str]:
|
@router.get("/me", response_model=list[OrganizationRead])
|
||||||
return {"detail": f"{endpoint} is not implemented yet"}
|
async def list_user_organizations(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
repo: OrganizationRepository = Depends(get_organization_repository),
|
||||||
|
) -> list[OrganizationRead]:
|
||||||
|
"""Return organizations the authenticated user belongs to."""
|
||||||
|
|
||||||
|
organizations = await repo.list_for_user(current_user.id)
|
||||||
@router.get("/me", status_code=status.HTTP_501_NOT_IMPLEMENTED)
|
return [OrganizationRead.model_validate(org) for org in organizations]
|
||||||
async def list_user_organizations() -> dict[str, str]:
|
|
||||||
"""Placeholder for returning organizations linked to the current user."""
|
|
||||||
return _stub("GET /organizations/me")
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
"""Organization repository for database operations."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.organization import Organization, OrganizationCreate
|
||||||
|
from app.models.organization_member import OrganizationMember
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationRepository:
|
||||||
|
"""Provides CRUD helpers for Organization model."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
|
self._session = session
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self) -> AsyncSession:
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
async def list(self) -> Sequence[Organization]:
|
||||||
|
result = await self._session.scalars(select(Organization))
|
||||||
|
return result.all()
|
||||||
|
|
||||||
|
async def get_by_id(self, organization_id: int) -> Organization | None:
|
||||||
|
return await self._session.get(Organization, organization_id)
|
||||||
|
|
||||||
|
async def get_by_name(self, name: str) -> Organization | None:
|
||||||
|
stmt = select(Organization).where(Organization.name == name)
|
||||||
|
result = await self._session.scalars(stmt)
|
||||||
|
return result.first()
|
||||||
|
|
||||||
|
async def list_for_user(self, user_id: int) -> Sequence[Organization]:
|
||||||
|
stmt = (
|
||||||
|
select(Organization)
|
||||||
|
.join(OrganizationMember, OrganizationMember.organization_id == Organization.id)
|
||||||
|
.where(OrganizationMember.user_id == user_id)
|
||||||
|
.order_by(Organization.id)
|
||||||
|
)
|
||||||
|
result = await self._session.scalars(stmt)
|
||||||
|
return result.unique().all()
|
||||||
|
|
||||||
|
async def create(self, data: OrganizationCreate) -> Organization:
|
||||||
|
organization = Organization(name=data.name)
|
||||||
|
self._session.add(organization)
|
||||||
|
await self._session.flush()
|
||||||
|
return organization
|
||||||
|
|
@ -21,4 +21,5 @@ dev = [
|
||||||
"ruff>=0.14.6",
|
"ruff>=0.14.6",
|
||||||
"pytest>=8.3.3",
|
"pytest>=8.3.3",
|
||||||
"pytest-asyncio>=0.25.0",
|
"pytest-asyncio>=0.25.0",
|
||||||
|
"aiosqlite>=0.20.0",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,103 @@
|
||||||
|
"""API tests for organization endpoints."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import AsyncGenerator, Sequence, cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.schema import Table
|
||||||
|
|
||||||
|
from app.api.deps import get_db_session
|
||||||
|
from app.core.security import jwt_service
|
||||||
|
from app.main import create_app
|
||||||
|
from app.models import Base
|
||||||
|
from app.models.organization import Organization
|
||||||
|
from app.models.organization_member import OrganizationMember, OrganizationRole
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture()
|
||||||
|
async def session_factory() -> AsyncGenerator[async_sessionmaker[AsyncSession], None]:
|
||||||
|
engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True)
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
tables: Sequence[Table] = cast(
|
||||||
|
Sequence[Table],
|
||||||
|
(User.__table__, Organization.__table__, OrganizationMember.__table__),
|
||||||
|
)
|
||||||
|
await conn.run_sync(Base.metadata.create_all, tables=tables)
|
||||||
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
yield SessionLocal
|
||||||
|
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture()
|
||||||
|
async def client(
|
||||||
|
session_factory: async_sessionmaker[AsyncSession],
|
||||||
|
) -> AsyncGenerator[AsyncClient, None]:
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
|
async def _get_session_override() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
async with session_factory() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db_session] = _get_session_override
|
||||||
|
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://testserver") as test_client:
|
||||||
|
yield test_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_user_organizations_returns_memberships(
|
||||||
|
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient
|
||||||
|
) -> None:
|
||||||
|
async with session_factory() as session:
|
||||||
|
user = User(email="owner@example.com", hashed_password="hashed", name="Owner", is_active=True)
|
||||||
|
session.add(user)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
org_1 = Organization(name="Alpha LLC")
|
||||||
|
org_2 = Organization(name="Beta LLC")
|
||||||
|
session.add_all([org_1, org_2])
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
membership = OrganizationMember(
|
||||||
|
organization_id=org_1.id,
|
||||||
|
user_id=user.id,
|
||||||
|
role=OrganizationRole.OWNER,
|
||||||
|
)
|
||||||
|
other_member = OrganizationMember(
|
||||||
|
organization_id=org_2.id,
|
||||||
|
user_id=user.id + 1,
|
||||||
|
role=OrganizationRole.MEMBER,
|
||||||
|
)
|
||||||
|
session.add_all([membership, other_member])
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
token = jwt_service.create_access_token(
|
||||||
|
subject=str(user.id),
|
||||||
|
expires_delta=timedelta(minutes=30),
|
||||||
|
claims={"email": user.email},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/organizations/me",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
payload = response.json()
|
||||||
|
assert len(payload) == 1
|
||||||
|
assert payload[0]["id"] == org_1.id
|
||||||
|
assert payload[0]["name"] == org_1.name
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_user_organizations_requires_token(client: AsyncClient) -> None:
|
||||||
|
response = await client.get("/api/v1/organizations/me")
|
||||||
|
assert response.status_code == 401
|
||||||
14
uv.lock
14
uv.lock
|
|
@ -2,6 +2,18 @@ version = 1
|
||||||
revision = 3
|
revision = 3
|
||||||
requires-python = ">=3.14"
|
requires-python = ">=3.14"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "aiosqlite"
|
||||||
|
version = "0.21.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "typing-extensions" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3", size = 13454, upload-time = "2025-02-03T07:30:16.235Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "alembic"
|
name = "alembic"
|
||||||
version = "1.17.2"
|
version = "1.17.2"
|
||||||
|
|
@ -843,6 +855,7 @@ dependencies = [
|
||||||
|
|
||||||
[package.dev-dependencies]
|
[package.dev-dependencies]
|
||||||
dev = [
|
dev = [
|
||||||
|
{ name = "aiosqlite" },
|
||||||
{ name = "isort" },
|
{ name = "isort" },
|
||||||
{ name = "mypy" },
|
{ name = "mypy" },
|
||||||
{ name = "pytest" },
|
{ name = "pytest" },
|
||||||
|
|
@ -863,6 +876,7 @@ requires-dist = [
|
||||||
|
|
||||||
[package.metadata.requires-dev]
|
[package.metadata.requires-dev]
|
||||||
dev = [
|
dev = [
|
||||||
|
{ name = "aiosqlite", specifier = ">=0.20.0" },
|
||||||
{ name = "isort", specifier = ">=7.0.0" },
|
{ name = "isort", specifier = ">=7.0.0" },
|
||||||
{ name = "mypy", specifier = ">=1.18.2" },
|
{ name = "mypy", specifier = ">=1.18.2" },
|
||||||
{ name = "pytest", specifier = ">=8.3.3" },
|
{ name = "pytest", specifier = ">=8.3.3" },
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue