organizations #3

Merged
k1nq merged 3 commits from organizations into dev 2025-11-27 10:10:51 +00:00
7 changed files with 218 additions and 11 deletions

View File

@ -1,15 +1,22 @@
"""Reusable FastAPI dependencies.""" """Reusable FastAPI dependencies."""
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from fastapi import Depends import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.database import get_session from app.core.database import get_session
from app.core.security import jwt_service, password_hasher from app.core.security import jwt_service, password_hasher
from app.models.user import User
from app.repositories.org_repo import OrganizationRepository
from app.repositories.user_repo import UserRepository from app.repositories.user_repo import UserRepository
from app.services.auth_service import AuthService from app.services.auth_service import AuthService
from app.services.user_service import UserService from app.services.user_service import UserService
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/token")
async def get_db_session() -> AsyncGenerator[AsyncSession, None]: async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""Provide a scoped database session.""" """Provide a scoped database session."""
@ -21,6 +28,10 @@ def get_user_repository(session: AsyncSession = Depends(get_db_session)) -> User
return UserRepository(session=session) return UserRepository(session=session)
def get_organization_repository(session: AsyncSession = Depends(get_db_session)) -> OrganizationRepository:
return OrganizationRepository(session=session)
def get_user_service(repo: UserRepository = Depends(get_user_repository)) -> UserService: def get_user_service(repo: UserRepository = Depends(get_user_repository)) -> UserService:
return UserService(user_repository=repo, password_hasher=password_hasher) return UserService(user_repository=repo, password_hasher=password_hasher)
@ -33,3 +44,26 @@ def get_auth_service(
password_hasher=password_hasher, password_hasher=password_hasher,
jwt_service=jwt_service, jwt_service=jwt_service,
) )
async def get_current_user(
token: str = Depends(oauth2_scheme),
repo: UserRepository = Depends(get_user_repository),
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
)
try:
payload = jwt_service.decode(token)
sub = payload.get("sub")
if sub is None:
raise credentials_exception
user_id = int(sub)
except (jwt.PyJWTError, TypeError, ValueError):
raise credentials_exception from None
user = await repo.get_by_id(user_id)
if user is None:
raise credentials_exception
return user

View File

@ -24,7 +24,7 @@ async def list_deals(
stage: str | None = None, stage: str | None = None,
owner_id: int | None = None, owner_id: int | None = None,
order_by: str | None = None, order_by: str | None = None,
order: str | None = Query(default=None, regex="^(asc|desc)$"), order: str | None = Query(default=None, pattern="^(asc|desc)$"),
) -> dict[str, str]: ) -> dict[str, str]:
"""Placeholder for deal filtering endpoint.""" """Placeholder for deal filtering endpoint."""
_ = (status_filter,) _ = (status_filter,)

View File

@ -1,16 +1,22 @@
"""Organization-related API stubs.""" """Organization-related API endpoints."""
from __future__ import annotations from __future__ import annotations
from fastapi import APIRouter, status from fastapi import APIRouter, Depends
from app.api.deps import get_current_user, get_organization_repository
from app.models.organization import OrganizationRead
from app.models.user import User
from app.repositories.org_repo import OrganizationRepository
router = APIRouter(prefix="/organizations", tags=["organizations"]) router = APIRouter(prefix="/organizations", tags=["organizations"])
def _stub(endpoint: str) -> dict[str, str]: @router.get("/me", response_model=list[OrganizationRead])
return {"detail": f"{endpoint} is not implemented yet"} async def list_user_organizations(
current_user: User = Depends(get_current_user),
repo: OrganizationRepository = Depends(get_organization_repository),
) -> list[OrganizationRead]:
"""Return organizations the authenticated user belongs to."""
organizations = await repo.list_for_user(current_user.id)
@router.get("/me", status_code=status.HTTP_501_NOT_IMPLEMENTED) return [OrganizationRead.model_validate(org) for org in organizations]
async def list_user_organizations() -> dict[str, str]:
"""Placeholder for returning organizations linked to the current user."""
return _stub("GET /organizations/me")

View File

@ -0,0 +1,49 @@
"""Organization repository for database operations."""
from __future__ import annotations
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.organization import Organization, OrganizationCreate
from app.models.organization_member import OrganizationMember
class OrganizationRepository:
"""Provides CRUD helpers for Organization model."""
def __init__(self, session: AsyncSession) -> None:
self._session = session
@property
def session(self) -> AsyncSession:
return self._session
async def list(self) -> Sequence[Organization]:
result = await self._session.scalars(select(Organization))
return result.all()
async def get_by_id(self, organization_id: int) -> Organization | None:
return await self._session.get(Organization, organization_id)
async def get_by_name(self, name: str) -> Organization | None:
stmt = select(Organization).where(Organization.name == name)
result = await self._session.scalars(stmt)
return result.first()
async def list_for_user(self, user_id: int) -> Sequence[Organization]:
stmt = (
select(Organization)
.join(OrganizationMember, OrganizationMember.organization_id == Organization.id)
.where(OrganizationMember.user_id == user_id)
.order_by(Organization.id)
)
result = await self._session.scalars(stmt)
return result.unique().all()
async def create(self, data: OrganizationCreate) -> Organization:
organization = Organization(name=data.name)
self._session.add(organization)
await self._session.flush()
return organization

View File

@ -21,4 +21,5 @@ dev = [
"ruff>=0.14.6", "ruff>=0.14.6",
"pytest>=8.3.3", "pytest>=8.3.3",
"pytest-asyncio>=0.25.0", "pytest-asyncio>=0.25.0",
"aiosqlite>=0.20.0",
] ]

View File

@ -0,0 +1,103 @@
"""API tests for organization endpoints."""
from __future__ import annotations
from datetime import timedelta
from typing import AsyncGenerator, Sequence, cast
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.schema import Table
from app.api.deps import get_db_session
from app.core.security import jwt_service
from app.main import create_app
from app.models import Base
from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole
from app.models.user import User
@pytest_asyncio.fixture()
async def session_factory() -> AsyncGenerator[async_sessionmaker[AsyncSession], None]:
engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True)
async with engine.begin() as conn:
tables: Sequence[Table] = cast(
Sequence[Table],
(User.__table__, Organization.__table__, OrganizationMember.__table__),
)
await conn.run_sync(Base.metadata.create_all, tables=tables)
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
yield SessionLocal
await engine.dispose()
@pytest_asyncio.fixture()
async def client(
session_factory: async_sessionmaker[AsyncSession],
) -> AsyncGenerator[AsyncClient, None]:
app = create_app()
async def _get_session_override() -> AsyncGenerator[AsyncSession, None]:
async with session_factory() as session:
yield session
app.dependency_overrides[get_db_session] = _get_session_override
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://testserver") as test_client:
yield test_client
@pytest.mark.asyncio
async def test_list_user_organizations_returns_memberships(
session_factory: async_sessionmaker[AsyncSession], client: AsyncClient
) -> None:
async with session_factory() as session:
user = User(email="owner@example.com", hashed_password="hashed", name="Owner", is_active=True)
session.add(user)
await session.flush()
org_1 = Organization(name="Alpha LLC")
org_2 = Organization(name="Beta LLC")
session.add_all([org_1, org_2])
await session.flush()
membership = OrganizationMember(
organization_id=org_1.id,
user_id=user.id,
role=OrganizationRole.OWNER,
)
other_member = OrganizationMember(
organization_id=org_2.id,
user_id=user.id + 1,
role=OrganizationRole.MEMBER,
)
session.add_all([membership, other_member])
await session.commit()
token = jwt_service.create_access_token(
subject=str(user.id),
expires_delta=timedelta(minutes=30),
claims={"email": user.email},
)
response = await client.get(
"/api/v1/organizations/me",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
payload = response.json()
assert len(payload) == 1
assert payload[0]["id"] == org_1.id
assert payload[0]["name"] == org_1.name
@pytest.mark.asyncio
async def test_list_user_organizations_requires_token(client: AsyncClient) -> None:
response = await client.get("/api/v1/organizations/me")
assert response.status_code == 401

14
uv.lock
View File

@ -2,6 +2,18 @@ version = 1
revision = 3 revision = 3
requires-python = ">=3.14" requires-python = ">=3.14"
[[package]]
name = "aiosqlite"
version = "0.21.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3", size = 13454, upload-time = "2025-02-03T07:30:16.235Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" },
]
[[package]] [[package]]
name = "alembic" name = "alembic"
version = "1.17.2" version = "1.17.2"
@ -843,6 +855,7 @@ dependencies = [
[package.dev-dependencies] [package.dev-dependencies]
dev = [ dev = [
{ name = "aiosqlite" },
{ name = "isort" }, { name = "isort" },
{ name = "mypy" }, { name = "mypy" },
{ name = "pytest" }, { name = "pytest" },
@ -863,6 +876,7 @@ requires-dist = [
[package.metadata.requires-dev] [package.metadata.requires-dev]
dev = [ dev = [
{ name = "aiosqlite", specifier = ">=0.20.0" },
{ name = "isort", specifier = ">=7.0.0" }, { name = "isort", specifier = ">=7.0.0" },
{ name = "mypy", specifier = ">=1.18.2" }, { name = "mypy", specifier = ">=1.18.2" },
{ name = "pytest", specifier = ">=8.3.3" }, { name = "pytest", specifier = ">=8.3.3" },