From 6db1e865f6b5d75517a9b08d3c08d674427f5ae8 Mon Sep 17 00:00:00 2001 From: k1nq Date: Fri, 28 Nov 2025 13:56:04 +0500 Subject: [PATCH] feat: implement refresh token functionality; update authentication and token models; add tests for refresh endpoint --- app/api/deps.py | 3 ++ app/api/v1/auth.py | 21 ++++++++--- app/core/config.py | 1 + app/models/token.py | 6 ++++ app/services/auth_service.py | 55 +++++++++++++++++++++++++---- tests/api/v1/test_auth.py | 46 ++++++++++++++++++++++++ tests/services/test_auth_service.py | 49 ++++++++++++++++++++++--- 7 files changed, 165 insertions(+), 16 deletions(-) diff --git a/app/api/deps.py b/app/api/deps.py index 7efc4f0..921d605 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -118,6 +118,9 @@ async def get_current_user( sub = payload.get("sub") if sub is None: raise credentials_exception + scope = payload.get("scope", "access") + if scope != "access": + raise credentials_exception user_id = int(sub) except (jwt.PyJWTError, TypeError, ValueError): raise credentials_exception from None diff --git a/app/api/v1/auth.py b/app/api/v1/auth.py index a424036..a44b86d 100644 --- a/app/api/v1/auth.py +++ b/app/api/v1/auth.py @@ -9,10 +9,10 @@ from app.api.deps import get_auth_service, get_user_repository from app.core.security import password_hasher from app.models.organization import Organization from app.models.organization_member import OrganizationMember, OrganizationRole -from app.models.token import LoginRequest, TokenResponse +from app.models.token import LoginRequest, RefreshRequest, TokenResponse from app.models.user import UserCreate from app.repositories.user_repo import UserRepository -from app.services.auth_service import AuthService, InvalidCredentialsError +from app.services.auth_service import AuthService, InvalidCredentialsError, InvalidRefreshTokenError class RegisterRequest(BaseModel): @@ -61,7 +61,7 @@ async def register_user( ) from exc await repo.session.refresh(user) - return auth_service.create_access_token(user) + return auth_service.issue_tokens(user) @router.post("/login", response_model=TokenResponse) @@ -74,7 +74,7 @@ async def login( user = await service.authenticate(credentials.email, credentials.password) except InvalidCredentialsError as exc: # pragma: no cover - thin API raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc - return service.create_access_token(user) + return service.issue_tokens(user) @router.post("/token", response_model=TokenResponse) @@ -86,4 +86,15 @@ async def login_for_access_token( user = await service.authenticate(credentials.email, credentials.password) except InvalidCredentialsError as exc: # pragma: no cover - thin API raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc - return service.create_access_token(user) + return service.issue_tokens(user) + + +@router.post("/refresh", response_model=TokenResponse) +async def refresh_tokens( + payload: RefreshRequest, + service: AuthService = Depends(get_auth_service), +) -> TokenResponse: + try: + return await service.refresh_tokens(payload.refresh_token) + except InvalidRefreshTokenError as exc: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc diff --git a/app/core/config.py b/app/core/config.py index 43211c8..e838665 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -19,6 +19,7 @@ class Settings(BaseSettings): jwt_secret_key: SecretStr = Field(default=SecretStr("change-me")) jwt_algorithm: str = "HS256" access_token_expire_minutes: int = 30 + refresh_token_expire_days: int = 7 settings = Settings() diff --git a/app/models/token.py b/app/models/token.py index c67baa0..526381a 100644 --- a/app/models/token.py +++ b/app/models/token.py @@ -14,10 +14,16 @@ class TokenPayload(BaseModel): class TokenResponse(BaseModel): access_token: str + refresh_token: str token_type: str = "bearer" expires_in: int + refresh_expires_in: int class LoginRequest(BaseModel): email: EmailStr password: str + + +class RefreshRequest(BaseModel): + refresh_token: str diff --git a/app/services/auth_service.py b/app/services/auth_service.py index 7effa73..8c1e934 100644 --- a/app/services/auth_service.py +++ b/app/services/auth_service.py @@ -2,6 +2,9 @@ from __future__ import annotations from datetime import timedelta +from typing import Any + +import jwt from app.core.config import settings from app.core.security import JWTService, PasswordHasher @@ -14,6 +17,10 @@ class InvalidCredentialsError(Exception): """Raised when user authentication fails.""" +class InvalidRefreshTokenError(Exception): + """Raised when refresh token validation fails.""" + + class AuthService: """Handles authentication flows and token issuance.""" @@ -33,11 +40,47 @@ class AuthService: raise InvalidCredentialsError("Invalid email or password") return user - def create_access_token(self, user: User) -> TokenResponse: - expires_delta = timedelta(minutes=settings.access_token_expire_minutes) - token = self._jwt_service.create_access_token( + def issue_tokens(self, user: User) -> TokenResponse: + access_expires = timedelta(minutes=settings.access_token_expire_minutes) + refresh_expires = timedelta(days=settings.refresh_token_expire_days) + access_token = self._jwt_service.create_access_token( subject=str(user.id), - expires_delta=expires_delta, - claims={"email": user.email}, + expires_delta=access_expires, + claims={"email": user.email, "scope": "access"}, ) - return TokenResponse(access_token=token, expires_in=int(expires_delta.total_seconds())) + refresh_token = self._jwt_service.create_access_token( + subject=str(user.id), + expires_delta=refresh_expires, + claims={"scope": "refresh"}, + ) + return TokenResponse( + access_token=access_token, + refresh_token=refresh_token, + expires_in=int(access_expires.total_seconds()), + refresh_expires_in=int(refresh_expires.total_seconds()), + ) + + async def refresh_tokens(self, refresh_token: str) -> TokenResponse: + payload = self._decode_refresh_token(refresh_token) + sub = payload.get("sub") + if sub is None: + raise InvalidRefreshTokenError("Invalid refresh token") + + try: + user_id = int(sub) + except (TypeError, ValueError) as exc: # pragma: no cover - defensive + raise InvalidRefreshTokenError("Invalid refresh token") from exc + + user = await self._user_repository.get_by_id(user_id) + if user is None: + raise InvalidRefreshTokenError("Invalid refresh token") + return self.issue_tokens(user) + + def _decode_refresh_token(self, token: str) -> dict[str, Any]: + try: + payload = self._jwt_service.decode(token) + except jwt.PyJWTError as exc: + raise InvalidRefreshTokenError("Invalid refresh token") from exc + if payload.get("scope") != "refresh": + raise InvalidRefreshTokenError("Invalid refresh token") + return payload diff --git a/tests/api/v1/test_auth.py b/tests/api/v1/test_auth.py index b81c66d..37eb090 100644 --- a/tests/api/v1/test_auth.py +++ b/tests/api/v1/test_auth.py @@ -30,6 +30,7 @@ async def test_register_user_creates_organization_membership( body = response.json() assert body["token_type"] == "bearer" assert "access_token" in body + assert "refresh_token" in body async with session_factory() as session: user = await session.scalar(select(User).where(User.email == payload["email"])) @@ -74,6 +75,7 @@ async def test_login_endpoint_returns_token_for_valid_credentials( body = response.json() assert body["token_type"] == "bearer" assert "access_token" in body + assert "refresh_token" in body @pytest.mark.asyncio @@ -98,3 +100,47 @@ async def test_token_endpoint_rejects_invalid_credentials( assert response.status_code == 401 assert response.json()["detail"] == "Invalid email or password" + + +@pytest.mark.asyncio +async def test_refresh_endpoint_returns_new_tokens( + session_factory: async_sessionmaker[AsyncSession], + client: AsyncClient, +) -> None: + async with session_factory() as session: + user = User( + email="refresh-user@example.com", + hashed_password=password_hasher.hash("StrongPass123"), + name="Refresh User", + is_active=True, + ) + session.add(user) + await session.commit() + + login_response = await client.post( + "/api/v1/auth/login", + json={"email": "refresh-user@example.com", "password": "StrongPass123"}, + ) + assert login_response.status_code == 200 + refresh_token = login_response.json()["refresh_token"] + + response = await client.post( + "/api/v1/auth/refresh", + json={"refresh_token": refresh_token}, + ) + + assert response.status_code == 200 + body = response.json() + assert "access_token" in body + assert "refresh_token" in body + + +@pytest.mark.asyncio +async def test_refresh_endpoint_rejects_invalid_token(client: AsyncClient) -> None: + response = await client.post( + "/api/v1/auth/refresh", + json={"refresh_token": "not-a-jwt"}, + ) + + assert response.status_code == 401 + assert response.json()["detail"] == "Invalid refresh token" diff --git a/tests/services/test_auth_service.py b/tests/services/test_auth_service.py index b31d359..40cfd42 100644 --- a/tests/services/test_auth_service.py +++ b/tests/services/test_auth_service.py @@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core.security import JWTService, PasswordHasher from app.models.user import User from app.repositories.user_repo import UserRepository -from app.services.auth_service import AuthService, InvalidCredentialsError +from app.services.auth_service import AuthService, InvalidCredentialsError, InvalidRefreshTokenError class StubUserRepository(UserRepository): @@ -25,6 +25,11 @@ class StubUserRepository(UserRepository): return self._user return None + async def get_by_id(self, user_id: int) -> User | None: # pragma: no cover - helper + if self._user and self._user.id == user_id: + return self._user + return None + @pytest.fixture() def password_hasher() -> PasswordHasher: @@ -71,7 +76,7 @@ async def test_authenticate_invalid_credentials( await service.authenticate("user@example.com", "wrong-pass") -def test_create_access_token_contains_user_claims( +def test_issue_tokens_contains_user_claims( password_hasher: PasswordHasher, jwt_service: JWTService, ) -> None: @@ -79,9 +84,43 @@ def test_create_access_token_contains_user_claims( user.id = 42 service = AuthService(StubUserRepository(user), password_hasher, jwt_service) - token = service.create_access_token(user) - payload = jwt_service.decode(token.access_token) + token_pair = service.issue_tokens(user) + payload = jwt_service.decode(token_pair.access_token) assert payload["sub"] == str(user.id) assert payload["email"] == user.email - assert token.expires_in > 0 + assert payload["scope"] == "access" + assert token_pair.refresh_token + assert token_pair.expires_in > 0 + assert token_pair.refresh_expires_in > token_pair.expires_in + + +@pytest.mark.asyncio +async def test_refresh_tokens_returns_new_pair( + password_hasher: PasswordHasher, + jwt_service: JWTService, +) -> None: + user = User(email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True) + user.id = 7 + service = AuthService(StubUserRepository(user), password_hasher, jwt_service) + + initial = service.issue_tokens(user) + refreshed = await service.refresh_tokens(initial.refresh_token) + + assert refreshed.access_token + assert refreshed.refresh_token + + +@pytest.mark.asyncio +async def test_refresh_tokens_rejects_access_token( + password_hasher: PasswordHasher, + jwt_service: JWTService, +) -> None: + user = User(email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True) + user.id = 9 + service = AuthService(StubUserRepository(user), password_hasher, jwt_service) + + pair = service.issue_tokens(user) + + with pytest.raises(InvalidRefreshTokenError): + await service.refresh_tokens(pair.access_token)