"""API tests for authentication endpoints."""

from __future__ import annotations

import pytest
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.core.security import password_hasher
from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole
from app.models.user import User


async def _create_user(
    session_factory: async_sessionmaker[AsyncSession],
    *,
    email: str,
    password: str,
    name: str,
) -> None:
    """Insert an active user whose password is hashed with ``password_hasher``.

    Shared seeding step for tests that need an existing account before they
    call an auth endpoint. The row is committed, and the session is closed,
    before this coroutine returns, so the HTTP request made afterwards runs
    against committed data.

    Args:
        session_factory: Factory used to open a short-lived session.
        email: Email address stored on the user.
        password: Plain-text password; only its hash is stored.
        name: Display name stored on the user.
    """
    async with session_factory() as session:
        session.add(
            User(
                email=email,
                hashed_password=password_hasher.hash(password),
                name=name,
                is_active=True,
            )
        )
        await session.commit()


@pytest.mark.asyncio
async def test_register_user_creates_organization_membership(
    session_factory: async_sessionmaker[AsyncSession],
    client: AsyncClient,
) -> None:
    """Registering creates the user, the organization and an OWNER membership."""
    payload = {
        "email": "new-owner@example.com",
        "password": "StrongPass123!",
        "name": "Alice Owner",
        "organization_name": "Rocket LLC",
    }

    response = await client.post("/api/v1/auth/register", json=payload)

    assert response.status_code == 201
    body = response.json()
    assert body["token_type"] == "bearer"
    assert "access_token" in body
    assert "refresh_token" in body

    # Verify the persisted side effects directly in the database.
    async with session_factory() as session:
        user = await session.scalar(select(User).where(User.email == payload["email"]))
        assert user is not None

        organization = await session.scalar(
            select(Organization).where(Organization.name == payload["organization_name"])
        )
        assert organization is not None

        membership = await session.scalar(
            select(OrganizationMember).where(
                OrganizationMember.organization_id == organization.id,
                OrganizationMember.user_id == user.id,
            )
        )
        assert membership is not None
        assert membership.role == OrganizationRole.OWNER


@pytest.mark.asyncio
async def test_login_endpoint_returns_token_for_valid_credentials(
    session_factory: async_sessionmaker[AsyncSession],
    client: AsyncClient,
) -> None:
    """Logging in with correct credentials returns a bearer access/refresh pair."""
    await _create_user(
        session_factory,
        email="login-user@example.com",
        password="Secret123!",
        name="Login User",
    )

    response = await client.post(
        "/api/v1/auth/login",
        json={"email": "login-user@example.com", "password": "Secret123!"},
    )

    assert response.status_code == 200
    body = response.json()
    assert body["token_type"] == "bearer"
    assert "access_token" in body
    assert "refresh_token" in body


@pytest.mark.asyncio
async def test_token_endpoint_rejects_invalid_credentials(
    session_factory: async_sessionmaker[AsyncSession],
    client: AsyncClient,
) -> None:
    """The token endpoint answers 401 when the password does not match."""
    await _create_user(
        session_factory,
        email="token-user@example.com",
        password="SuperSecret123",
        name="Token User",
    )

    response = await client.post(
        "/api/v1/auth/token",
        json={"email": "token-user@example.com", "password": "wrong-pass"},
    )

    assert response.status_code == 401
    assert response.json()["detail"] == "Invalid email or password"


@pytest.mark.asyncio
async def test_refresh_endpoint_returns_new_tokens(
    session_factory: async_sessionmaker[AsyncSession],
    client: AsyncClient,
) -> None:
    """A refresh token obtained from login can be exchanged for a new token pair."""
    await _create_user(
        session_factory,
        email="refresh-user@example.com",
        password="StrongPass123",
        name="Refresh User",
    )

    login_response = await client.post(
        "/api/v1/auth/login",
        json={"email": "refresh-user@example.com", "password": "StrongPass123"},
    )
    assert login_response.status_code == 200
    refresh_token = login_response.json()["refresh_token"]

    response = await client.post(
        "/api/v1/auth/refresh",
        json={"refresh_token": refresh_token},
    )

    assert response.status_code == 200
    body = response.json()
    assert "access_token" in body
    assert "refresh_token" in body


@pytest.mark.asyncio
async def test_refresh_endpoint_rejects_invalid_token(client: AsyncClient) -> None:
    """A refresh token that is not a valid JWT is rejected with 401."""
    response = await client.post(
        "/api/v1/auth/refresh",
        json={"refresh_token": "not-a-jwt"},
    )

    assert response.status_code == 401
    assert response.json()["detail"] == "Invalid refresh token"