feat: implement refresh token functionality; update authentication and token models; add tests for refresh endpoint
Test / test (push) Successful in 13s
Details
Test / test (push) Successful in 13s
Details
This commit is contained in:
parent
a8bdf18e38
commit
6db1e865f6
|
|
@ -118,6 +118,9 @@ async def get_current_user(
|
|||
sub = payload.get("sub")
|
||||
if sub is None:
|
||||
raise credentials_exception
|
||||
scope = payload.get("scope", "access")
|
||||
if scope != "access":
|
||||
raise credentials_exception
|
||||
user_id = int(sub)
|
||||
except (jwt.PyJWTError, TypeError, ValueError):
|
||||
raise credentials_exception from None
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ from app.api.deps import get_auth_service, get_user_repository
|
|||
from app.core.security import password_hasher
|
||||
from app.models.organization import Organization
|
||||
from app.models.organization_member import OrganizationMember, OrganizationRole
|
||||
from app.models.token import LoginRequest, TokenResponse
|
||||
from app.models.token import LoginRequest, RefreshRequest, TokenResponse
|
||||
from app.models.user import UserCreate
|
||||
from app.repositories.user_repo import UserRepository
|
||||
from app.services.auth_service import AuthService, InvalidCredentialsError
|
||||
from app.services.auth_service import AuthService, InvalidCredentialsError, InvalidRefreshTokenError
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
|
|
@ -61,7 +61,7 @@ async def register_user(
|
|||
) from exc
|
||||
|
||||
await repo.session.refresh(user)
|
||||
return auth_service.create_access_token(user)
|
||||
return auth_service.issue_tokens(user)
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
|
|
@ -74,7 +74,7 @@ async def login(
|
|||
user = await service.authenticate(credentials.email, credentials.password)
|
||||
except InvalidCredentialsError as exc: # pragma: no cover - thin API
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
|
||||
return service.create_access_token(user)
|
||||
return service.issue_tokens(user)
|
||||
|
||||
|
||||
@router.post("/token", response_model=TokenResponse)
|
||||
|
|
@ -86,4 +86,15 @@ async def login_for_access_token(
|
|||
user = await service.authenticate(credentials.email, credentials.password)
|
||||
except InvalidCredentialsError as exc: # pragma: no cover - thin API
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
|
||||
return service.create_access_token(user)
|
||||
return service.issue_tokens(user)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh_tokens(
|
||||
payload: RefreshRequest,
|
||||
service: AuthService = Depends(get_auth_service),
|
||||
) -> TokenResponse:
|
||||
try:
|
||||
return await service.refresh_tokens(payload.refresh_token)
|
||||
except InvalidRefreshTokenError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ class Settings(BaseSettings):
|
|||
jwt_secret_key: SecretStr = Field(default=SecretStr("change-me"))
|
||||
jwt_algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 30
|
||||
refresh_token_expire_days: int = 7
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
|
|
|||
|
|
@ -14,10 +14,16 @@ class TokenPayload(BaseModel):
|
|||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
refresh_expires_in: int
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.security import JWTService, PasswordHasher
|
||||
|
|
@ -14,6 +17,10 @@ class InvalidCredentialsError(Exception):
|
|||
"""Raised when user authentication fails."""
|
||||
|
||||
|
||||
class InvalidRefreshTokenError(Exception):
|
||||
"""Raised when refresh token validation fails."""
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Handles authentication flows and token issuance."""
|
||||
|
||||
|
|
@ -33,11 +40,47 @@ class AuthService:
|
|||
raise InvalidCredentialsError("Invalid email or password")
|
||||
return user
|
||||
|
||||
def create_access_token(self, user: User) -> TokenResponse:
|
||||
expires_delta = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
token = self._jwt_service.create_access_token(
|
||||
def issue_tokens(self, user: User) -> TokenResponse:
|
||||
access_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||
refresh_expires = timedelta(days=settings.refresh_token_expire_days)
|
||||
access_token = self._jwt_service.create_access_token(
|
||||
subject=str(user.id),
|
||||
expires_delta=expires_delta,
|
||||
claims={"email": user.email},
|
||||
expires_delta=access_expires,
|
||||
claims={"email": user.email, "scope": "access"},
|
||||
)
|
||||
return TokenResponse(access_token=token, expires_in=int(expires_delta.total_seconds()))
|
||||
refresh_token = self._jwt_service.create_access_token(
|
||||
subject=str(user.id),
|
||||
expires_delta=refresh_expires,
|
||||
claims={"scope": "refresh"},
|
||||
)
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=int(access_expires.total_seconds()),
|
||||
refresh_expires_in=int(refresh_expires.total_seconds()),
|
||||
)
|
||||
|
||||
async def refresh_tokens(self, refresh_token: str) -> TokenResponse:
|
||||
payload = self._decode_refresh_token(refresh_token)
|
||||
sub = payload.get("sub")
|
||||
if sub is None:
|
||||
raise InvalidRefreshTokenError("Invalid refresh token")
|
||||
|
||||
try:
|
||||
user_id = int(sub)
|
||||
except (TypeError, ValueError) as exc: # pragma: no cover - defensive
|
||||
raise InvalidRefreshTokenError("Invalid refresh token") from exc
|
||||
|
||||
user = await self._user_repository.get_by_id(user_id)
|
||||
if user is None:
|
||||
raise InvalidRefreshTokenError("Invalid refresh token")
|
||||
return self.issue_tokens(user)
|
||||
|
||||
def _decode_refresh_token(self, token: str) -> dict[str, Any]:
|
||||
try:
|
||||
payload = self._jwt_service.decode(token)
|
||||
except jwt.PyJWTError as exc:
|
||||
raise InvalidRefreshTokenError("Invalid refresh token") from exc
|
||||
if payload.get("scope") != "refresh":
|
||||
raise InvalidRefreshTokenError("Invalid refresh token")
|
||||
return payload
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ async def test_register_user_creates_organization_membership(
|
|||
body = response.json()
|
||||
assert body["token_type"] == "bearer"
|
||||
assert "access_token" in body
|
||||
assert "refresh_token" in body
|
||||
|
||||
async with session_factory() as session:
|
||||
user = await session.scalar(select(User).where(User.email == payload["email"]))
|
||||
|
|
@ -74,6 +75,7 @@ async def test_login_endpoint_returns_token_for_valid_credentials(
|
|||
body = response.json()
|
||||
assert body["token_type"] == "bearer"
|
||||
assert "access_token" in body
|
||||
assert "refresh_token" in body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -98,3 +100,47 @@ async def test_token_endpoint_rejects_invalid_credentials(
|
|||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Invalid email or password"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_endpoint_returns_new_tokens(
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
client: AsyncClient,
|
||||
) -> None:
|
||||
async with session_factory() as session:
|
||||
user = User(
|
||||
email="refresh-user@example.com",
|
||||
hashed_password=password_hasher.hash("StrongPass123"),
|
||||
name="Refresh User",
|
||||
is_active=True,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
login_response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "refresh-user@example.com", "password": "StrongPass123"},
|
||||
)
|
||||
assert login_response.status_code == 200
|
||||
refresh_token = login_response.json()["refresh_token"]
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "access_token" in body
|
||||
assert "refresh_token" in body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_endpoint_rejects_invalid_token(client: AsyncClient) -> None:
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "not-a-jwt"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Invalid refresh token"
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from app.core.security import JWTService, PasswordHasher
|
||||
from app.models.user import User
|
||||
from app.repositories.user_repo import UserRepository
|
||||
from app.services.auth_service import AuthService, InvalidCredentialsError
|
||||
from app.services.auth_service import AuthService, InvalidCredentialsError, InvalidRefreshTokenError
|
||||
|
||||
|
||||
class StubUserRepository(UserRepository):
|
||||
|
|
@ -25,6 +25,11 @@ class StubUserRepository(UserRepository):
|
|||
return self._user
|
||||
return None
|
||||
|
||||
async def get_by_id(self, user_id: int) -> User | None: # pragma: no cover - helper
|
||||
if self._user and self._user.id == user_id:
|
||||
return self._user
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def password_hasher() -> PasswordHasher:
|
||||
|
|
@ -71,7 +76,7 @@ async def test_authenticate_invalid_credentials(
|
|||
await service.authenticate("user@example.com", "wrong-pass")
|
||||
|
||||
|
||||
def test_create_access_token_contains_user_claims(
|
||||
def test_issue_tokens_contains_user_claims(
|
||||
password_hasher: PasswordHasher,
|
||||
jwt_service: JWTService,
|
||||
) -> None:
|
||||
|
|
@ -79,9 +84,43 @@ def test_create_access_token_contains_user_claims(
|
|||
user.id = 42
|
||||
service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
|
||||
|
||||
token = service.create_access_token(user)
|
||||
payload = jwt_service.decode(token.access_token)
|
||||
token_pair = service.issue_tokens(user)
|
||||
payload = jwt_service.decode(token_pair.access_token)
|
||||
|
||||
assert payload["sub"] == str(user.id)
|
||||
assert payload["email"] == user.email
|
||||
assert token.expires_in > 0
|
||||
assert payload["scope"] == "access"
|
||||
assert token_pair.refresh_token
|
||||
assert token_pair.expires_in > 0
|
||||
assert token_pair.refresh_expires_in > token_pair.expires_in
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_returns_new_pair(
|
||||
password_hasher: PasswordHasher,
|
||||
jwt_service: JWTService,
|
||||
) -> None:
|
||||
user = User(email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True)
|
||||
user.id = 7
|
||||
service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
|
||||
|
||||
initial = service.issue_tokens(user)
|
||||
refreshed = await service.refresh_tokens(initial.refresh_token)
|
||||
|
||||
assert refreshed.access_token
|
||||
assert refreshed.refresh_token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_rejects_access_token(
|
||||
password_hasher: PasswordHasher,
|
||||
jwt_service: JWTService,
|
||||
) -> None:
|
||||
user = User(email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True)
|
||||
user.id = 9
|
||||
service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
|
||||
|
||||
pair = service.issue_tokens(user)
|
||||
|
||||
with pytest.raises(InvalidRefreshTokenError):
|
||||
await service.refresh_tokens(pair.access_token)
|
||||
|
|
|
|||
Loading…
Reference in New Issue