"""API tests for organization endpoints.""" from __future__ import annotations from datetime import timedelta from typing import AsyncGenerator, Sequence, cast import pytest import pytest_asyncio from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.schema import Table from app.api.deps import get_db_session from app.core.security import jwt_service from app.main import create_app from app.models import Base from app.models.organization import Organization from app.models.organization_member import OrganizationMember, OrganizationRole from app.models.user import User @pytest_asyncio.fixture() async def session_factory() -> AsyncGenerator[async_sessionmaker[AsyncSession], None]: engine = create_async_engine("sqlite+aiosqlite:///:memory:", future=True) async with engine.begin() as conn: tables: Sequence[Table] = cast( Sequence[Table], (User.__table__, Organization.__table__, OrganizationMember.__table__), ) await conn.run_sync(Base.metadata.create_all, tables=tables) SessionLocal = async_sessionmaker(engine, expire_on_commit=False) yield SessionLocal await engine.dispose() @pytest_asyncio.fixture() async def client( session_factory: async_sessionmaker[AsyncSession], ) -> AsyncGenerator[AsyncClient, None]: app = create_app() async def _get_session_override() -> AsyncGenerator[AsyncSession, None]: async with session_factory() as session: yield session app.dependency_overrides[get_db_session] = _get_session_override transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://testserver") as test_client: yield test_client @pytest.mark.asyncio async def test_list_user_organizations_returns_memberships( session_factory: async_sessionmaker[AsyncSession], client: AsyncClient ) -> None: async with session_factory() as session: user = User(email="owner@example.com", hashed_password="hashed", name="Owner", is_active=True) session.add(user) await session.flush() org_1 = Organization(name="Alpha LLC") org_2 = Organization(name="Beta LLC") session.add_all([org_1, org_2]) await session.flush() membership = OrganizationMember( organization_id=org_1.id, user_id=user.id, role=OrganizationRole.OWNER, ) other_member = OrganizationMember( organization_id=org_2.id, user_id=user.id + 1, role=OrganizationRole.MEMBER, ) session.add_all([membership, other_member]) await session.commit() token = jwt_service.create_access_token( subject=str(user.id), expires_delta=timedelta(minutes=30), claims={"email": user.email}, ) response = await client.get( "/api/v1/organizations/me", headers={"Authorization": f"Bearer {token}"}, ) assert response.status_code == 200 payload = response.json() assert len(payload) == 1 assert payload[0]["id"] == org_1.id assert payload[0]["name"] == org_1.name @pytest.mark.asyncio async def test_list_user_organizations_requires_token(client: AsyncClient) -> None: response = await client.get("/api/v1/organizations/me") assert response.status_code == 401