feat: implement refresh token functionality; update authentication and token models; add tests for refresh endpoint
Test / test (push) Successful in 13s
Details
Test / test (push) Successful in 13s
Details
This commit is contained in:
parent
a8bdf18e38
commit
6db1e865f6
|
|
@ -118,6 +118,9 @@ async def get_current_user(
|
||||||
sub = payload.get("sub")
|
sub = payload.get("sub")
|
||||||
if sub is None:
|
if sub is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
scope = payload.get("scope", "access")
|
||||||
|
if scope != "access":
|
||||||
|
raise credentials_exception
|
||||||
user_id = int(sub)
|
user_id = int(sub)
|
||||||
except (jwt.PyJWTError, TypeError, ValueError):
|
except (jwt.PyJWTError, TypeError, ValueError):
|
||||||
raise credentials_exception from None
|
raise credentials_exception from None
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,10 @@ from app.api.deps import get_auth_service, get_user_repository
|
||||||
from app.core.security import password_hasher
|
from app.core.security import password_hasher
|
||||||
from app.models.organization import Organization
|
from app.models.organization import Organization
|
||||||
from app.models.organization_member import OrganizationMember, OrganizationRole
|
from app.models.organization_member import OrganizationMember, OrganizationRole
|
||||||
from app.models.token import LoginRequest, TokenResponse
|
from app.models.token import LoginRequest, RefreshRequest, TokenResponse
|
||||||
from app.models.user import UserCreate
|
from app.models.user import UserCreate
|
||||||
from app.repositories.user_repo import UserRepository
|
from app.repositories.user_repo import UserRepository
|
||||||
from app.services.auth_service import AuthService, InvalidCredentialsError
|
from app.services.auth_service import AuthService, InvalidCredentialsError, InvalidRefreshTokenError
|
||||||
|
|
||||||
|
|
||||||
class RegisterRequest(BaseModel):
|
class RegisterRequest(BaseModel):
|
||||||
|
|
@ -61,7 +61,7 @@ async def register_user(
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
await repo.session.refresh(user)
|
await repo.session.refresh(user)
|
||||||
return auth_service.create_access_token(user)
|
return auth_service.issue_tokens(user)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=TokenResponse)
|
@router.post("/login", response_model=TokenResponse)
|
||||||
|
|
@ -74,7 +74,7 @@ async def login(
|
||||||
user = await service.authenticate(credentials.email, credentials.password)
|
user = await service.authenticate(credentials.email, credentials.password)
|
||||||
except InvalidCredentialsError as exc: # pragma: no cover - thin API
|
except InvalidCredentialsError as exc: # pragma: no cover - thin API
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
|
||||||
return service.create_access_token(user)
|
return service.issue_tokens(user)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/token", response_model=TokenResponse)
|
@router.post("/token", response_model=TokenResponse)
|
||||||
|
|
@ -86,4 +86,15 @@ async def login_for_access_token(
|
||||||
user = await service.authenticate(credentials.email, credentials.password)
|
user = await service.authenticate(credentials.email, credentials.password)
|
||||||
except InvalidCredentialsError as exc: # pragma: no cover - thin API
|
except InvalidCredentialsError as exc: # pragma: no cover - thin API
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
|
||||||
return service.create_access_token(user)
|
return service.issue_tokens(user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=TokenResponse)
|
||||||
|
async def refresh_tokens(
|
||||||
|
payload: RefreshRequest,
|
||||||
|
service: AuthService = Depends(get_auth_service),
|
||||||
|
) -> TokenResponse:
|
||||||
|
try:
|
||||||
|
return await service.refresh_tokens(payload.refresh_token)
|
||||||
|
except InvalidRefreshTokenError as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) from exc
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ class Settings(BaseSettings):
|
||||||
jwt_secret_key: SecretStr = Field(default=SecretStr("change-me"))
|
jwt_secret_key: SecretStr = Field(default=SecretStr("change-me"))
|
||||||
jwt_algorithm: str = "HS256"
|
jwt_algorithm: str = "HS256"
|
||||||
access_token_expire_minutes: int = 30
|
access_token_expire_minutes: int = 30
|
||||||
|
refresh_token_expire_days: int = 7
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
|
||||||
|
|
@ -14,10 +14,16 @@ class TokenPayload(BaseModel):
|
||||||
|
|
||||||
class TokenResponse(BaseModel):
|
class TokenResponse(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
token_type: str = "bearer"
|
token_type: str = "bearer"
|
||||||
expires_in: int
|
expires_in: int
|
||||||
|
refresh_expires_in: int
|
||||||
|
|
||||||
|
|
||||||
class LoginRequest(BaseModel):
|
class LoginRequest(BaseModel):
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
password: str
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshRequest(BaseModel):
|
||||||
|
refresh_token: str
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,9 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.security import JWTService, PasswordHasher
|
from app.core.security import JWTService, PasswordHasher
|
||||||
|
|
@ -14,6 +17,10 @@ class InvalidCredentialsError(Exception):
|
||||||
"""Raised when user authentication fails."""
|
"""Raised when user authentication fails."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidRefreshTokenError(Exception):
|
||||||
|
"""Raised when refresh token validation fails."""
|
||||||
|
|
||||||
|
|
||||||
class AuthService:
|
class AuthService:
|
||||||
"""Handles authentication flows and token issuance."""
|
"""Handles authentication flows and token issuance."""
|
||||||
|
|
||||||
|
|
@ -33,11 +40,47 @@ class AuthService:
|
||||||
raise InvalidCredentialsError("Invalid email or password")
|
raise InvalidCredentialsError("Invalid email or password")
|
||||||
return user
|
return user
|
||||||
|
|
||||||
def create_access_token(self, user: User) -> TokenResponse:
|
def issue_tokens(self, user: User) -> TokenResponse:
|
||||||
expires_delta = timedelta(minutes=settings.access_token_expire_minutes)
|
access_expires = timedelta(minutes=settings.access_token_expire_minutes)
|
||||||
token = self._jwt_service.create_access_token(
|
refresh_expires = timedelta(days=settings.refresh_token_expire_days)
|
||||||
|
access_token = self._jwt_service.create_access_token(
|
||||||
subject=str(user.id),
|
subject=str(user.id),
|
||||||
expires_delta=expires_delta,
|
expires_delta=access_expires,
|
||||||
claims={"email": user.email},
|
claims={"email": user.email, "scope": "access"},
|
||||||
)
|
)
|
||||||
return TokenResponse(access_token=token, expires_in=int(expires_delta.total_seconds()))
|
refresh_token = self._jwt_service.create_access_token(
|
||||||
|
subject=str(user.id),
|
||||||
|
expires_delta=refresh_expires,
|
||||||
|
claims={"scope": "refresh"},
|
||||||
|
)
|
||||||
|
return TokenResponse(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
expires_in=int(access_expires.total_seconds()),
|
||||||
|
refresh_expires_in=int(refresh_expires.total_seconds()),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def refresh_tokens(self, refresh_token: str) -> TokenResponse:
|
||||||
|
payload = self._decode_refresh_token(refresh_token)
|
||||||
|
sub = payload.get("sub")
|
||||||
|
if sub is None:
|
||||||
|
raise InvalidRefreshTokenError("Invalid refresh token")
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_id = int(sub)
|
||||||
|
except (TypeError, ValueError) as exc: # pragma: no cover - defensive
|
||||||
|
raise InvalidRefreshTokenError("Invalid refresh token") from exc
|
||||||
|
|
||||||
|
user = await self._user_repository.get_by_id(user_id)
|
||||||
|
if user is None:
|
||||||
|
raise InvalidRefreshTokenError("Invalid refresh token")
|
||||||
|
return self.issue_tokens(user)
|
||||||
|
|
||||||
|
def _decode_refresh_token(self, token: str) -> dict[str, Any]:
|
||||||
|
try:
|
||||||
|
payload = self._jwt_service.decode(token)
|
||||||
|
except jwt.PyJWTError as exc:
|
||||||
|
raise InvalidRefreshTokenError("Invalid refresh token") from exc
|
||||||
|
if payload.get("scope") != "refresh":
|
||||||
|
raise InvalidRefreshTokenError("Invalid refresh token")
|
||||||
|
return payload
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ async def test_register_user_creates_organization_membership(
|
||||||
body = response.json()
|
body = response.json()
|
||||||
assert body["token_type"] == "bearer"
|
assert body["token_type"] == "bearer"
|
||||||
assert "access_token" in body
|
assert "access_token" in body
|
||||||
|
assert "refresh_token" in body
|
||||||
|
|
||||||
async with session_factory() as session:
|
async with session_factory() as session:
|
||||||
user = await session.scalar(select(User).where(User.email == payload["email"]))
|
user = await session.scalar(select(User).where(User.email == payload["email"]))
|
||||||
|
|
@ -74,6 +75,7 @@ async def test_login_endpoint_returns_token_for_valid_credentials(
|
||||||
body = response.json()
|
body = response.json()
|
||||||
assert body["token_type"] == "bearer"
|
assert body["token_type"] == "bearer"
|
||||||
assert "access_token" in body
|
assert "access_token" in body
|
||||||
|
assert "refresh_token" in body
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -98,3 +100,47 @@ async def test_token_endpoint_rejects_invalid_credentials(
|
||||||
|
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
assert response.json()["detail"] == "Invalid email or password"
|
assert response.json()["detail"] == "Invalid email or password"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_endpoint_returns_new_tokens(
|
||||||
|
session_factory: async_sessionmaker[AsyncSession],
|
||||||
|
client: AsyncClient,
|
||||||
|
) -> None:
|
||||||
|
async with session_factory() as session:
|
||||||
|
user = User(
|
||||||
|
email="refresh-user@example.com",
|
||||||
|
hashed_password=password_hasher.hash("StrongPass123"),
|
||||||
|
name="Refresh User",
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
login_response = await client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": "refresh-user@example.com", "password": "StrongPass123"},
|
||||||
|
)
|
||||||
|
assert login_response.status_code == 200
|
||||||
|
refresh_token = login_response.json()["refresh_token"]
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
json={"refresh_token": refresh_token},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert "access_token" in body
|
||||||
|
assert "refresh_token" in body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_endpoint_rejects_invalid_token(client: AsyncClient) -> None:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
json={"refresh_token": "not-a-jwt"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json()["detail"] == "Invalid refresh token"
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from app.core.security import JWTService, PasswordHasher
|
from app.core.security import JWTService, PasswordHasher
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.repositories.user_repo import UserRepository
|
from app.repositories.user_repo import UserRepository
|
||||||
from app.services.auth_service import AuthService, InvalidCredentialsError
|
from app.services.auth_service import AuthService, InvalidCredentialsError, InvalidRefreshTokenError
|
||||||
|
|
||||||
|
|
||||||
class StubUserRepository(UserRepository):
|
class StubUserRepository(UserRepository):
|
||||||
|
|
@ -25,6 +25,11 @@ class StubUserRepository(UserRepository):
|
||||||
return self._user
|
return self._user
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_by_id(self, user_id: int) -> User | None: # pragma: no cover - helper
|
||||||
|
if self._user and self._user.id == user_id:
|
||||||
|
return self._user
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def password_hasher() -> PasswordHasher:
|
def password_hasher() -> PasswordHasher:
|
||||||
|
|
@ -71,7 +76,7 @@ async def test_authenticate_invalid_credentials(
|
||||||
await service.authenticate("user@example.com", "wrong-pass")
|
await service.authenticate("user@example.com", "wrong-pass")
|
||||||
|
|
||||||
|
|
||||||
def test_create_access_token_contains_user_claims(
|
def test_issue_tokens_contains_user_claims(
|
||||||
password_hasher: PasswordHasher,
|
password_hasher: PasswordHasher,
|
||||||
jwt_service: JWTService,
|
jwt_service: JWTService,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -79,9 +84,43 @@ def test_create_access_token_contains_user_claims(
|
||||||
user.id = 42
|
user.id = 42
|
||||||
service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
|
service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
|
||||||
|
|
||||||
token = service.create_access_token(user)
|
token_pair = service.issue_tokens(user)
|
||||||
payload = jwt_service.decode(token.access_token)
|
payload = jwt_service.decode(token_pair.access_token)
|
||||||
|
|
||||||
assert payload["sub"] == str(user.id)
|
assert payload["sub"] == str(user.id)
|
||||||
assert payload["email"] == user.email
|
assert payload["email"] == user.email
|
||||||
assert token.expires_in > 0
|
assert payload["scope"] == "access"
|
||||||
|
assert token_pair.refresh_token
|
||||||
|
assert token_pair.expires_in > 0
|
||||||
|
assert token_pair.refresh_expires_in > token_pair.expires_in
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_tokens_returns_new_pair(
|
||||||
|
password_hasher: PasswordHasher,
|
||||||
|
jwt_service: JWTService,
|
||||||
|
) -> None:
|
||||||
|
user = User(email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True)
|
||||||
|
user.id = 7
|
||||||
|
service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
|
||||||
|
|
||||||
|
initial = service.issue_tokens(user)
|
||||||
|
refreshed = await service.refresh_tokens(initial.refresh_token)
|
||||||
|
|
||||||
|
assert refreshed.access_token
|
||||||
|
assert refreshed.refresh_token
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_tokens_rejects_access_token(
|
||||||
|
password_hasher: PasswordHasher,
|
||||||
|
jwt_service: JWTService,
|
||||||
|
) -> None:
|
||||||
|
user = User(email="refresh@example.com", hashed_password="hashed", name="Refresh", is_active=True)
|
||||||
|
user.id = 9
|
||||||
|
service = AuthService(StubUserRepository(user), password_hasher, jwt_service)
|
||||||
|
|
||||||
|
pair = service.issue_tokens(user)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidRefreshTokenError):
|
||||||
|
await service.refresh_tokens(pair.access_token)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue