dev #11

Merged
k1nq merged 76 commits from dev into master 2025-11-30 04:48:35 +00:00
7 changed files with 165 additions and 16 deletions
Showing only changes of commit 6db1e865f6 - Show all commits

View File

@ -118,6 +118,9 @@ async def get_current_user(
sub = payload.get("sub")
if sub is None:
raise credentials_exception
scope = payload.get("scope", "access")
if scope != "access":
raise credentials_exception
user_id = int(sub)
except (jwt.PyJWTError, TypeError, ValueError):
raise credentials_exception from None

View File

@ -9,10 +9,10 @@ from app.api.deps import get_auth_service, get_user_repository
from app.core.security import password_hasher
from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole
from app.models.token import LoginRequest, TokenResponse
from app.models.token import LoginRequest, RefreshRequest, TokenResponse
from app.models.user import UserCreate
from app.repositories.user_repo import UserRepository
from app.services.auth_service import AuthService, InvalidCredentialsError
from app.services.auth_service import AuthService, InvalidCredentialsError, InvalidRefreshTokenError
class RegisterRequest(BaseModel):
@ -61,7 +61,7 @@ async def register_user(
) from exc
await repo.session.refresh(user)
return auth_service.create_access_token(user)
return auth_service.issue_tokens(user)
@router.post("/login", response_model=TokenResponse)
@ -74,7 +74,7 @@ async def login(
user = await service.authenticate(credentials.email, credentials.password)
except InvalidCredentialsError as exc: # pragma: no cover - thin API
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
return service.create_access_token(user)
return service.issue_tokens(user)
@router.post("/token", response_model=TokenResponse)
@ -86,4 +86,15 @@ async def login_for_access_token(
user = await service.authenticate(credentials.email, credentials.password)
except InvalidCredentialsError as exc: # pragma: no cover - thin API
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
return service.create_access_token(user)
return service.issue_tokens(user)
@router.post("/refresh", response_model=TokenResponse)
async def refresh_tokens(
payload: RefreshRequest,
service: AuthService = Depends(get_auth_service),
) -> TokenResponse:
try:
return await service.refresh_tokens(payload.refresh_token)
except InvalidRefreshTokenError as exc:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc

View File

@ -19,6 +19,7 @@ class Settings(BaseSettings):
jwt_secret_key: SecretStr = Field(default=SecretStr("change-me"))
jwt_algorithm: str = "HS256"
access_token_expire_minutes: int = 30
refresh_token_expire_days: int = 7
settings = Settings()

View File

@ -14,10 +14,16 @@ class TokenPayload(BaseModel):
class TokenResponse(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"
expires_in: int
refresh_expires_in: int
class LoginRequest(BaseModel):
email: EmailStr
password: str
class RefreshRequest(BaseModel):
refresh_token: str

View File

@ -2,6 +2,9 @@
from __future__ import annotations
from datetime import timedelta
from typing import Any
import jwt
from app.core.config import settings
from app.core.security import JWTService, PasswordHasher
@ -14,6 +17,10 @@ class InvalidCredentialsError(Exception):
"""Raised when user authentication fails."""
class InvalidRefreshTokenError(Exception):
"""Raised when refresh token validation fails."""
class AuthService:
"""Handles authentication flows and token issuance."""
@ -33,11 +40,47 @@ class AuthService:
raise InvalidCredentialsError("Invalid email or password")
return user
def create_access_token(self, user: User) -> TokenResponse:
expires_delta = timedelta(minutes=settings.access_token_expire_minutes)
token = self._jwt_service.create_access_token(
def issue_tokens(self, user: User) -> TokenResponse:
access_expires = timedelta(minutes=settings.access_token_expire_minutes)
refresh_expires = timedelta(days=settings.refresh_token_expire_days)
access_token = self._jwt_service.create_access_token(
subject=str(user.id),
expires_delta=expires_delta,
claims={"email": user.email},
expires_delta=access_expires,
claims={"email": user.email, "scope": "access"},
)
return TokenResponse(access_token=token, expires_in=int(expires_delta.total_seconds()))
refresh_token = self._jwt_service.create_access_token(
subject=str(user.id),
expires_delta=refresh_expires,
claims={"scope": "refresh"},
)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
expires_in=int(access_expires.total_seconds()),
refresh_expires_in=int(refresh_expires.total_seconds()),
)
async def refresh_tokens(self, refresh_token: str) -> TokenResponse:
payload = self._decode_refresh_token(refresh_token)
sub = payload.get("sub")
if sub is None:
raise InvalidRefreshTokenError("Invalid refresh token")
try:
user_id = int(sub)
except (TypeError, ValueError) as exc: # pragma: no cover - defensive
raise InvalidRefreshTokenError("Invalid refresh token") from exc
user = await self._user_repository.get_by_id(user_id)
if user is None:
raise InvalidRefreshTokenError("Invalid refresh token")
return self.issue_tokens(user)
def _decode_refresh_token(self, token: str) -> dict[str, Any]:
try:
payload = self._jwt_service.decode(token)
except jwt.PyJWTError as exc:
raise InvalidRefreshTokenError("Invalid refresh token") from exc
if payload.get("scope") != "refresh":
raise InvalidRefreshTokenError("Invalid refresh token")
return payload

View File

@ -30,6 +30,7 @@ async def test_register_user_creates_organization_membership(
body = response.json()
assert body["token_type"] == "bearer"
assert "access_token" in body
assert "refresh_token" in body
async with session_factory() as session:
user = await session.scalar(select(User).where(User.email == payload["email"]))
@ -74,6 +75,7 @@ async def test_login_endpoint_returns_token_for_valid_credentials(
body = response.json()
assert body["token_type"] == "bearer"
assert "access_token" in body
assert "refresh_token" in body
@pytest.mark.asyncio
@ -98,3 +100,47 @@ async def test_token_endpoint_rejects_invalid_credentials(
assert response.status_code == 401
assert response.json()["detail"] == "Invalid email or password"
@pytest.mark.asyncio
async def test_refresh_endpoint_returns_new_tokens(
session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None:
async with session_factory() as session:
user = User(
email="refresh-user@example.com",
hashed_password=password_hasher.hash("StrongPass123"),
name="Refresh User",
is_active=True,
)
session.add(user)
await session.commit()
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "refresh-user@example.com", "password": "StrongPass123"},
)
assert login_response.status_code == 200
refresh_token = login_response.json()["refresh_token"]
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token},
)
assert response.status_code == 200
body = response.json()
assert "access_token" in body
assert "refresh_token" in body
@pytest.mark.asyncio
async def test_refresh_endpoint_rejects_invalid_token(client: AsyncClient) -> None:
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "not-a-jwt"},
)
assert response.status_code == 401
assert response.json()["detail"] == "Invalid refresh token"

View File

@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import JWTService, PasswordHasher
from app.models.user import User
from app.repositories.user_repo import UserRepository
from app.services.auth_service import AuthService, InvalidCredentialsError
from app.services.auth_service import AuthService, InvalidCredentialsError, InvalidRefreshTokenError
class StubUserRepository(UserRepository):
@ -25,6 +25,11 @@ class StubUserRepository(UserRepository):
return self._user
return None
async def get_by_id(self, user_id: int) -> User | None: # pragma: no cover - helper
if self._user and self._user.id == user_id:
return self._user
return None
@pytest.fixture()
def password_hasher() -> PasswordHasher:
@ -71,7 +76,7 @@ async def test_authenticate_invalid_credentials(
await service.authenticate("user@example.com", "wrong-pass")
def test_create_access_token_contains_user_claims(
def test_issue_tokens_contains_user_claims(
password_hasher: PasswordHasher,
jwt_service: JWTService,
) -> None:
@ -79,9 +84,43 @@ def test_create_access_token_contains_user_claims(
user.id = 42
service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
token = service.create_access_token(user)
payload = jwt_service.decode(token.access_token)
token_pair = service.issue_tokens(user)
payload = jwt_service.decode(token_pair.access_token)
assert payload["sub"] == str(user.id)
assert payload["email"] == user.email
assert token.expires_in > 0
assert payload["scope"] == "access"
assert token_pair.refresh_token
assert token_pair.expires_in > 0
assert token_pair.refresh_expires_in > token_pair.expires_in
@pytest.mark.asyncio
async def test_refresh_tokens_returns_new_pair(
password_hasher: PasswordHasher,
jwt_service: JWTService,
) -> None:
user = User(email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True)
user.id = 7
service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
initial = service.issue_tokens(user)
refreshed = await service.refresh_tokens(initial.refresh_token)
assert refreshed.access_token
assert refreshed.refresh_token
@pytest.mark.asyncio
async def test_refresh_tokens_rejects_access_token(
password_hasher: PasswordHasher,
jwt_service: JWTService,
) -> None:
user = User(email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True)
user.id = 9
service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
pair = service.issue_tokens(user)
with pytest.raises(InvalidRefreshTokenError):
await service.refresh_tokens(pair.access_token)