feat: implement refresh token functionality; update authentication and token models; add tests for refresh endpoint
Test / test (push) Successful in 13s Details

This commit is contained in:
k1nq 2025-11-28 13:56:04 +05:00
parent a8bdf18e38
commit 6db1e865f6
7 changed files with 165 additions and 16 deletions

View File

@ -118,6 +118,9 @@ async def get_current_user(
sub = payload.get("sub") sub = payload.get("sub")
if sub is None: if sub is None:
raise credentials_exception raise credentials_exception
scope = payload.get("scope", "access")
if scope != "access":
raise credentials_exception
user_id = int(sub) user_id = int(sub)
except (jwt.PyJWTError, TypeError, ValueError): except (jwt.PyJWTError, TypeError, ValueError):
raise credentials_exception from None raise credentials_exception from None

View File

@ -9,10 +9,10 @@ from app.api.deps import get_auth_service, get_user_repository
from app.core.security import password_hasher from app.core.security import password_hasher
from app.models.organization import Organization from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole from app.models.organization_member import OrganizationMember, OrganizationRole
from app.models.token import LoginRequest, TokenResponse from app.models.token import LoginRequest, RefreshRequest, TokenResponse
from app.models.user import UserCreate from app.models.user import UserCreate
from app.repositories.user_repo import UserRepository from app.repositories.user_repo import UserRepository
from app.services.auth_service import AuthService, InvalidCredentialsError from app.services.auth_service import AuthService, InvalidCredentialsError, InvalidRefreshTokenError
class RegisterRequest(BaseModel): class RegisterRequest(BaseModel):
@ -61,7 +61,7 @@ async def register_user(
) from exc ) from exc
await repo.session.refresh(user) await repo.session.refresh(user)
return auth_service.create_access_token(user) return auth_service.issue_tokens(user)
@router.post("/login", response_model=TokenResponse) @router.post("/login", response_model=TokenResponse)
@ -74,7 +74,7 @@ async def login(
user = await service.authenticate(credentials.email, credentials.password) user = await service.authenticate(credentials.email, credentials.password)
except InvalidCredentialsError as exc: # pragma: no cover - thin API except InvalidCredentialsError as exc: # pragma: no cover - thin API
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
return service.create_access_token(user) return service.issue_tokens(user)
@router.post("/token", response_model=TokenResponse) @router.post("/token", response_model=TokenResponse)
@ -86,4 +86,15 @@ async def login_for_access_token(
user = await service.authenticate(credentials.email, credentials.password) user = await service.authenticate(credentials.email, credentials.password)
except InvalidCredentialsError as exc: # pragma: no cover - thin API except InvalidCredentialsError as exc: # pragma: no cover - thin API
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
return service.create_access_token(user) return service.issue_tokens(user)
@router.post("/refresh", response_model=TokenResponse)
async def refresh_tokens(
payload: RefreshRequest,
service: AuthService = Depends(get_auth_service),
) -> TokenResponse:
try:
return await service.refresh_tokens(payload.refresh_token)
except InvalidRefreshTokenError as exc:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc

View File

@ -19,6 +19,7 @@ class Settings(BaseSettings):
jwt_secret_key: SecretStr = Field(default=SecretStr("change-me")) jwt_secret_key: SecretStr = Field(default=SecretStr("change-me"))
jwt_algorithm: str = "HS256" jwt_algorithm: str = "HS256"
access_token_expire_minutes: int = 30 access_token_expire_minutes: int = 30
refresh_token_expire_days: int = 7
settings = Settings() settings = Settings()

View File

@ -14,10 +14,16 @@ class TokenPayload(BaseModel):
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
access_token: str access_token: str
refresh_token: str
token_type: str = "bearer" token_type: str = "bearer"
expires_in: int expires_in: int
refresh_expires_in: int
class LoginRequest(BaseModel): class LoginRequest(BaseModel):
email: EmailStr email: EmailStr
password: str password: str
class RefreshRequest(BaseModel):
refresh_token: str

View File

@ -2,6 +2,9 @@
from __future__ import annotations from __future__ import annotations
from datetime import timedelta from datetime import timedelta
from typing import Any
import jwt
from app.core.config import settings from app.core.config import settings
from app.core.security import JWTService, PasswordHasher from app.core.security import JWTService, PasswordHasher
@ -14,6 +17,10 @@ class InvalidCredentialsError(Exception):
"""Raised when user authentication fails.""" """Raised when user authentication fails."""
class InvalidRefreshTokenError(Exception):
"""Raised when refresh token validation fails."""
class AuthService: class AuthService:
"""Handles authentication flows and token issuance.""" """Handles authentication flows and token issuance."""
@ -33,11 +40,47 @@ class AuthService:
raise InvalidCredentialsError("Invalid email or password") raise InvalidCredentialsError("Invalid email or password")
return user return user
def create_access_token(self, user: User) -> TokenResponse: def issue_tokens(self, user: User) -> TokenResponse:
expires_delta = timedelta(minutes=settings.access_token_expire_minutes) access_expires = timedelta(minutes=settings.access_token_expire_minutes)
token = self._jwt_service.create_access_token( refresh_expires = timedelta(days=settings.refresh_token_expire_days)
access_token = self._jwt_service.create_access_token(
subject=str(user.id), subject=str(user.id),
expires_delta=expires_delta, expires_delta=access_expires,
claims={"email": user.email}, claims={"email": user.email, "scope": "access"},
) )
return TokenResponse(access_token=token, expires_in=int(expires_delta.total_seconds())) refresh_token = self._jwt_service.create_access_token(
subject=str(user.id),
expires_delta=refresh_expires,
claims={"scope": "refresh"},
)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
expires_in=int(access_expires.total_seconds()),
refresh_expires_in=int(refresh_expires.total_seconds()),
)
async def refresh_tokens(self, refresh_token: str) -> TokenResponse:
payload = self._decode_refresh_token(refresh_token)
sub = payload.get("sub")
if sub is None:
raise InvalidRefreshTokenError("Invalid refresh token")
try:
user_id = int(sub)
except (TypeError, ValueError) as exc: # pragma: no cover - defensive
raise InvalidRefreshTokenError("Invalid refresh token") from exc
user = await self._user_repository.get_by_id(user_id)
if user is None:
raise InvalidRefreshTokenError("Invalid refresh token")
return self.issue_tokens(user)
def _decode_refresh_token(self, token: str) -> dict[str, Any]:
try:
payload = self._jwt_service.decode(token)
except jwt.PyJWTError as exc:
raise InvalidRefreshTokenError("Invalid refresh token") from exc
if payload.get("scope") != "refresh":
raise InvalidRefreshTokenError("Invalid refresh token")
return payload

View File

@ -30,6 +30,7 @@ async def test_register_user_creates_organization_membership(
body = response.json() body = response.json()
assert body["token_type"] == "bearer" assert body["token_type"] == "bearer"
assert "access_token" in body assert "access_token" in body
assert "refresh_token" in body
async with session_factory() as session: async with session_factory() as session:
user = await session.scalar(select(User).where(User.email == payload["email"])) user = await session.scalar(select(User).where(User.email == payload["email"]))
@ -74,6 +75,7 @@ async def test_login_endpoint_returns_token_for_valid_credentials(
body = response.json() body = response.json()
assert body["token_type"] == "bearer" assert body["token_type"] == "bearer"
assert "access_token" in body assert "access_token" in body
assert "refresh_token" in body
@pytest.mark.asyncio @pytest.mark.asyncio
@ -98,3 +100,47 @@ async def test_token_endpoint_rejects_invalid_credentials(
assert response.status_code == 401 assert response.status_code == 401
assert response.json()["detail"] == "Invalid email or password" assert response.json()["detail"] == "Invalid email or password"
@pytest.mark.asyncio
async def test_refresh_endpoint_returns_new_tokens(
session_factory: async_sessionmaker[AsyncSession],
client: AsyncClient,
) -> None:
async with session_factory() as session:
user = User(
email="refresh-user@example.com",
hashed_password=password_hasher.hash("StrongPass123"),
name="Refresh User",
is_active=True,
)
session.add(user)
await session.commit()
login_response = await client.post(
"/api/v1/auth/login",
json={"email": "refresh-user@example.com", "password": "StrongPass123"},
)
assert login_response.status_code == 200
refresh_token = login_response.json()["refresh_token"]
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token},
)
assert response.status_code == 200
body = response.json()
assert "access_token" in body
assert "refresh_token" in body
@pytest.mark.asyncio
async def test_refresh_endpoint_rejects_invalid_token(client: AsyncClient) -> None:
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "not-a-jwt"},
)
assert response.status_code == 401
assert response.json()["detail"] == "Invalid refresh token"

View File

@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.security import JWTService, PasswordHasher from app.core.security import JWTService, PasswordHasher
from app.models.user import User from app.models.user import User
from app.repositories.user_repo import UserRepository from app.repositories.user_repo import UserRepository
from app.services.auth_service import AuthService, InvalidCredentialsError from app.services.auth_service import AuthService, InvalidCredentialsError, InvalidRefreshTokenError
class StubUserRepository(UserRepository): class StubUserRepository(UserRepository):
@ -25,6 +25,11 @@ class StubUserRepository(UserRepository):
return self._user return self._user
return None return None
async def get_by_id(self, user_id: int) -> User | None: # pragma: no cover - helper
if self._user and self._user.id == user_id:
return self._user
return None
@pytest.fixture() @pytest.fixture()
def password_hasher() -> PasswordHasher: def password_hasher() -> PasswordHasher:
@ -71,7 +76,7 @@ async def test_authenticate_invalid_credentials(
await service.authenticate("user@example.com", "wrong-pass") await service.authenticate("user@example.com", "wrong-pass")
def test_create_access_token_contains_user_claims( def test_issue_tokens_contains_user_claims(
password_hasher: PasswordHasher, password_hasher: PasswordHasher,
jwt_service: JWTService, jwt_service: JWTService,
) -> None: ) -> None:
@ -79,9 +84,43 @@ def test_create_access_token_contains_user_claims(
user.id = 42 user.id = 42
service = AuthService(StubUserRepository(user), password_hasher, jwt_service) service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
token = service.create_access_token(user) token_pair = service.issue_tokens(user)
payload = jwt_service.decode(token.access_token) payload = jwt_service.decode(token_pair.access_token)
assert payload["sub"] == str(user.id) assert payload["sub"] == str(user.id)
assert payload["email"] == user.email assert payload["email"] == user.email
assert token.expires_in > 0 assert payload["scope"] == "access"
assert token_pair.refresh_token
assert token_pair.expires_in > 0
assert token_pair.refresh_expires_in > token_pair.expires_in
@pytest.mark.asyncio
async def test_refresh_tokens_returns_new_pair(
password_hasher: PasswordHasher,
jwt_service: JWTService,
) -> None:
user = User(email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True)
user.id = 7
service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
initial = service.issue_tokens(user)
refreshed = await service.refresh_tokens(initial.refresh_token)
assert refreshed.access_token
assert refreshed.refresh_token
@pytest.mark.asyncio
async def test_refresh_tokens_rejects_access_token(
password_hasher: PasswordHasher,
jwt_service: JWTService,
) -> None:
user = User(email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True)
user.id = 9
service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
pair = service.issue_tokens(user)
with pytest.raises(InvalidRefreshTokenError):
await service.refresh_tokens(pair.access_token)