133 lines
4.3 KiB
Python
133 lines
4.3 KiB
Python
"""Unit tests for AuthService."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import cast
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest # type: ignore[import-not-found]
|
|
from app.core.security import JWTService, PasswordHasher
|
|
from app.models.user import User
|
|
from app.repositories.user_repo import UserRepository
|
|
from app.services.auth_service import AuthService, InvalidCredentialsError, InvalidRefreshTokenError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
class StubUserRepository(UserRepository):
    """In-memory stand-in for UserRepository.

    Holds at most one ``User`` and answers lookups against it instead of
    querying a database. The parent constructor still receives a session,
    so a ``MagicMock`` restricted to the ``AsyncSession`` interface is passed
    to satisfy it without opening any connection.
    """

    def __init__(self, user: User | None) -> None:
        """Create the stub around *user*; pass ``None`` for an empty repository."""
        super().__init__(session=MagicMock(spec=AsyncSession))
        # The only user this stub knows about; ``None`` means "no users".
        self._user = user

    async def get_by_email(self, email: str) -> User | None:  # pragma: no cover - helper
        """Return the stored user if its ``email`` equals *email*, else ``None``."""
        # Compare against ``None`` explicitly instead of relying on the
        # truthiness of an ORM model instance.
        if self._user is not None and self._user.email == email:
            return self._user
        return None

    async def get_by_id(self, user_id: int) -> User | None:  # pragma: no cover - helper
        """Return the stored user if its ``id`` equals *user_id*, else ``None``."""
        if self._user is not None and self._user.id == user_id:
            return self._user
        return None
|
|
|
|
|
|
@pytest.fixture()
def password_hasher() -> PasswordHasher:
    """Provide a deterministic, non-cryptographic password hasher for tests.

    ``hash`` prefixes the plain password with ``"hashed::"``; ``verify``
    recomputes that string and compares it with the stored value. Tests can
    therefore build stored hashes without a real hashing backend.
    """

    class DummyPasswordHasher:
        def hash(self, password: str) -> str:  # pragma: no cover - trivial
            digest = f"hashed::{password}"
            return digest

        def verify(self, password: str, hashed_password: str) -> bool:  # pragma: no cover - trivial
            expected = self.hash(password)
            return hashed_password == expected

    hasher = DummyPasswordHasher()
    # Duck-typed stand-in: tell the type checker to treat it as PasswordHasher.
    return cast(PasswordHasher, hasher)
|
|
|
|
|
|
@pytest.fixture()
def jwt_service() -> JWTService:
    """Provide a JWTService that signs with a fixed HS256 test secret."""
    service = JWTService(secret_key="unit-test-secret", algorithm="HS256")
    return service
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_authenticate_success(
    password_hasher: PasswordHasher,
    jwt_service: JWTService,
) -> None:
    """Correct email and password return the very same stored user object."""
    plain_password = "StrongPass123"
    stored_user = User(
        email="user@example.com",
        hashed_password=password_hasher.hash(plain_password),
        name="Alice",
        is_active=True,
    )
    stored_user.id = 1
    auth = AuthService(StubUserRepository(stored_user), password_hasher, jwt_service)

    result = await auth.authenticate("user@example.com", plain_password)

    assert result is stored_user
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_authenticate_invalid_credentials(
    password_hasher: PasswordHasher,
    jwt_service: JWTService,
) -> None:
    """A wrong password for an existing email raises InvalidCredentialsError."""
    stored_user = User(
        email="user@example.com",
        hashed_password=password_hasher.hash("StrongPass123"),
        name="Alice",
        is_active=True,
    )
    stored_user.id = 1
    auth = AuthService(StubUserRepository(stored_user), password_hasher, jwt_service)

    with pytest.raises(InvalidCredentialsError):
        await auth.authenticate("user@example.com", "wrong-pass")
|
|
|
|
|
|
def test_issue_tokens_contains_user_claims(
    password_hasher: PasswordHasher,
    jwt_service: JWTService,
) -> None:
    """Issued access token carries the user's id, email and the access scope."""
    account = User(
        email="user@example.com",
        hashed_password="hashed",
        name="Alice",
        is_active=True,
    )
    account.id = 42
    auth = AuthService(StubUserRepository(account), password_hasher, jwt_service)

    pair = auth.issue_tokens(account)
    claims = jwt_service.decode(pair.access_token)

    # Claims embedded in the access token.
    assert claims["sub"] == str(account.id)
    assert claims["email"] == account.email
    assert claims["scope"] == "access"
    # Pair metadata: a refresh token exists and outlives the access token.
    assert pair.refresh_token
    assert pair.expires_in > 0
    assert pair.refresh_expires_in > pair.expires_in
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_tokens_returns_new_pair(
    password_hasher: PasswordHasher,
    jwt_service: JWTService,
) -> None:
    """A valid refresh token yields a new pair whose access token is for the same user.

    Only checking that the tokens are non-empty would also accept a pair
    issued for the wrong subject, or one whose ``access_token`` is not an
    access-scoped token. The refreshed access token is therefore decoded and
    its claims checked, mirroring ``test_issue_tokens_contains_user_claims``.
    """
    user = User(
        email="refresh@example.com",
        hashed_password="hashed",
        name="Refresh",
        is_active=True,
    )
    user.id = 7
    service = AuthService(StubUserRepository(user), password_hasher, jwt_service)

    initial = service.issue_tokens(user)
    refreshed = await service.refresh_tokens(initial.refresh_token)

    assert refreshed.access_token
    assert refreshed.refresh_token

    payload = jwt_service.decode(refreshed.access_token)
    assert payload["sub"] == str(user.id)
    assert payload["scope"] == "access"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_tokens_rejects_access_token(
    password_hasher: PasswordHasher,
    jwt_service: JWTService,
) -> None:
    """Passing an access token where a refresh token is expected is rejected."""
    account = User(
        email="refresh@example.com",
        hashed_password="hashed",
        name="Refresh",
        is_active=True,
    )
    account.id = 9
    auth = AuthService(StubUserRepository(account), password_hasher, jwt_service)

    tokens = auth.issue_tokens(account)
    wrong_kind_token = tokens.access_token

    with pytest.raises(InvalidRefreshTokenError):
        await auth.refresh_tokens(wrong_kind_token)
|