refactor: enhance type hinting and casting for improved type safety across multiple files

This commit is contained in:
Artem Kashaev 2025-12-01 16:44:14 +05:00
parent f234e60e65
commit 688ade0452
14 changed files with 62 additions and 42 deletions

View File

@@ -6,7 +6,7 @@ import asyncio
import json
import logging
from collections.abc import Awaitable, Callable
from typing import Any
from typing import Any, cast
import redis.asyncio as redis
from app.core.config import settings
@@ -45,10 +45,13 @@ class RedisCacheManager:
async with self._lock:
if self._client is not None:
return
self._client = redis.from_url(
settings.redis_url,
encoding="utf-8",
decode_responses=False,
self._client = cast(
Redis,
redis.from_url( # type: ignore[no-untyped-call]
settings.redis_url,
encoding="utf-8",
decode_responses=False,
),
)
await self._refresh_availability()
@@ -64,10 +67,13 @@ class RedisCacheManager:
return
async with self._lock:
if self._client is None:
self._client = redis.from_url(
settings.redis_url,
encoding="utf-8",
decode_responses=False,
self._client = cast(
Redis,
redis.from_url( # type: ignore[no-untyped-call]
settings.redis_url,
encoding="utf-8",
decode_responses=False,
),
)
await self._refresh_availability()
@@ -76,7 +82,7 @@ class RedisCacheManager:
self._available = False
return
try:
await self._client.ping()
await cast(Awaitable[Any], self._client.ping())
except RedisError as exc: # pragma: no cover - logging only
self._available = False
logger.warning("Redis ping failed: %s", exc)
@@ -140,8 +146,8 @@ async def write_json(
"""Serialize data to JSON and store it with TTL using retry/backoff."""
payload = json.dumps(value, separators=(",", ":"), ensure_ascii=True).encode("utf-8")
async def _operation() -> Any:
return await client.set(name=key, value=payload, ex=ttl_seconds)
async def _operation() -> None:
await client.set(name=key, value=payload, ex=ttl_seconds)
await _run_with_retry(_operation, backoff_ms)
@@ -151,8 +157,8 @@ async def delete_keys(client: Redis, keys: list[str], backoff_ms: int) -> None:
if not keys:
return
async def _operation() -> Any:
return await client.delete(*keys)
async def _operation() -> None:
await client.delete(*keys)
await _run_with_retry(_operation, backoff_ms)

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime, timedelta, timezone
from typing import Any
from typing import Any, cast
import jwt
from app.core.config import settings
@@ -18,10 +18,10 @@ class PasswordHasher:
self._context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
def hash(self, password: str) -> str:
return self._context.hash(password)
return cast(str, self._context.hash(password))
def verify(self, password: str, hashed_password: str) -> bool:
return self._context.verify(password, hashed_password)
return bool(self._context.verify(password, hashed_password))
class JWTService:
@@ -45,10 +45,10 @@ class JWTService:
}
if claims:
payload.update(claims)
return jwt.encode(payload, self._secret_key, algorithm=self._algorithm)
return cast(str, jwt.encode(payload, self._secret_key, algorithm=self._algorithm))
def decode(self, token: str) -> dict[str, Any]:
return jwt.decode(token, self._secret_key, algorithms=[self._algorithm])
return cast(dict[str, Any], jwt.decode(token, self._secret_key, algorithms=[self._algorithm]))
password_hasher = PasswordHasher()

View File

@@ -10,9 +10,11 @@ from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import DateTime, ForeignKey, Integer, func, text
from sqlalchemy import Enum as SqlEnum
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.engine import Dialect
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.types import JSON as SA_JSON
from sqlalchemy.types import TypeDecorator
from sqlalchemy.sql.type_api import TypeEngine
from app.models.base import Base, enum_values
@@ -31,7 +33,7 @@ class JSONBCompat(TypeDecorator):
impl = JSONB
cache_ok = True
def load_dialect_impl(self, dialect): # type: ignore[override]
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "sqlite":
from sqlalchemy.dialects.sqlite import JSON as SQLITE_JSON # local import

View File

@@ -14,7 +14,7 @@ class Base(DeclarativeBase):
"""Base class that configures naming conventions."""
@declared_attr.directive
def __tablename__(cls) -> str: # type: ignore[misc] # noqa: N805 - SQLAlchemy expects cls
def __tablename__(cls) -> str: # noqa: N805 - SQLAlchemy expects cls
return cls.__name__.lower()

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from typing import Any, cast
from sqlalchemy import Select, select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -123,6 +123,9 @@ class TaskRepository:
async def _resolve_task_owner(self, task: Task) -> int | None:
if task.deal is not None:
return task.deal.owner_id
return int(task.deal.owner_id)
stmt = select(Deal.owner_id).where(Deal.id == task.deal_id)
return await self._session.scalar(stmt)
owner_id_raw: Any = await self._session.scalar(stmt)
if owner_id_raw is None:
return None
return cast(int, owner_id_raw)

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import cast
from sqlalchemy import select
@@ -151,11 +152,11 @@ class ContactService:
def _build_update_mapping(self, updates: ContactUpdateData) -> dict[str, str | None]:
payload: dict[str, str | None] = {}
if updates.name is not UNSET:
payload["name"] = updates.name
payload["name"] = cast(str | None, updates.name)
if updates.email is not UNSET:
payload["email"] = updates.email
payload["email"] = cast(str | None, updates.email)
if updates.phone is not UNSET:
payload["phone"] = updates.phone
payload["phone"] = cast(str | None, updates.phone)
return payload
async def _ensure_no_related_deals(self, contact_id: int) -> None:

0
tests/__init__.py Normal file
View File

0
tests/api/__init__.py Normal file
View File

0
tests/api/v1/__init__.py Normal file
View File

View File

@@ -3,10 +3,12 @@
from __future__ import annotations
from enum import StrEnum
from typing import cast
from app.models.activity import Activity, ActivityType
from app.models.deal import Deal, DealStage, DealStatus
from app.models.organization_member import OrganizationMember, OrganizationRole
from sqlalchemy import Enum as SqlEnum
def _values(enum_cls: type[StrEnum]) -> list[str]:
@@ -14,20 +16,20 @@ def _values(enum_cls: type[StrEnum]) -> list[str]:
def test_organization_role_column_uses_value_strings() -> None:
role_type = OrganizationMember.__table__.c.role.type # noqa: SLF001 - runtime inspection
role_type = cast(SqlEnum, OrganizationMember.__table__.c.role.type) # noqa: SLF001
assert role_type.enums == _values(OrganizationRole)
def test_deal_status_column_uses_value_strings() -> None:
status_type = Deal.__table__.c.status.type # noqa: SLF001 - runtime inspection
status_type = cast(SqlEnum, Deal.__table__.c.status.type) # noqa: SLF001
assert status_type.enums == _values(DealStatus)
def test_deal_stage_column_uses_value_strings() -> None:
stage_type = Deal.__table__.c.stage.type # noqa: SLF001 - runtime inspection
stage_type = cast(SqlEnum, Deal.__table__.c.stage.type) # noqa: SLF001
assert stage_type.enums == _values(DealStage)
def test_activity_type_column_uses_value_strings() -> None:
activity_type = Activity.__table__.c.type.type # noqa: SLF001 - runtime inspection
activity_type = cast(SqlEnum, Activity.__table__.c.type.type) # noqa: SLF001
assert activity_type.enums == _values(ActivityType)

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
from collections.abc import AsyncGenerator
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from typing import cast
import pytest
import pytest_asyncio
@@ -14,8 +15,13 @@ from app.models.deal import Deal, DealStage, DealStatus
from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole
from app.models.user import User
from app.repositories.analytics_repo import AnalyticsRepository
from app.repositories.analytics_repo import (
AnalyticsRepository,
StageStatusRollup,
StatusRollup,
)
from app.services.analytics_service import AnalyticsService, invalidate_analytics_cache
from redis.asyncio.client import Redis
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from tests.utils.fake_redis import InMemoryRedis
@@ -170,20 +176,20 @@ async def test_funnel_breakdown_contains_stage_conversions(session: AsyncSession
class _ExplodingRepository(AnalyticsRepository):
async def fetch_status_rollup(self, organization_id: int): # type: ignore[override]
async def fetch_status_rollup(self, organization_id: int) -> list[StatusRollup]:
raise AssertionError("cache not used for status rollup")
async def count_new_deals_since(self, organization_id: int, threshold): # type: ignore[override]
async def count_new_deals_since(self, organization_id: int, threshold: datetime) -> int:
raise AssertionError("cache not used for new deal count")
async def fetch_stage_status_rollup(self, organization_id: int): # type: ignore[override]
async def fetch_stage_status_rollup(self, organization_id: int) -> list[StageStatusRollup]:
raise AssertionError("cache not used for funnel rollup")
@pytest.mark.asyncio
async def test_summary_reads_from_cache_when_available(session: AsyncSession) -> None:
org_id, _, _ = await _seed_data(session)
cache = InMemoryRedis()
cache = cast(Redis, InMemoryRedis())
service = AnalyticsService(
repository=AnalyticsRepository(session),
cache=cache,
@@ -201,7 +207,7 @@ async def test_summary_reads_from_cache_when_available(session: AsyncSession) ->
@pytest.mark.asyncio
async def test_invalidation_refreshes_cached_summary(session: AsyncSession) -> None:
org_id, _, contact_id = await _seed_data(session)
cache = InMemoryRedis()
cache = cast(Redis, InMemoryRedis())
service = AnalyticsService(
repository=AnalyticsRepository(session),
cache=cache,
@@ -235,7 +241,7 @@ async def test_invalidation_refreshes_cached_summary(session: AsyncSession) -> N
@pytest.mark.asyncio
async def test_funnel_reads_from_cache_when_available(session: AsyncSession) -> None:
org_id, _, _ = await _seed_data(session)
cache = InMemoryRedis()
cache = cast(Redis, InMemoryRedis())
service = AnalyticsService(
repository=AnalyticsRepository(session),
cache=cache,

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from typing import cast
from unittest.mock import MagicMock
import pytest # type: ignore[import-not-found]
import pytest
from app.core.security import JWTService, PasswordHasher
from app.models.user import User
from app.repositories.user_repo import UserRepository

View File

@@ -6,8 +6,8 @@ import uuid
from collections.abc import AsyncGenerator
from decimal import Decimal
import pytest # type: ignore[import-not-found]
import pytest_asyncio # type: ignore[import-not-found]
import pytest
import pytest_asyncio
from app.models.activity import Activity, ActivityType
from app.models.base import Base
from app.models.contact import Contact

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from typing import cast
from unittest.mock import MagicMock
import pytest # type: ignore[import-not-found]
import pytest
from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole
from app.repositories.org_repo import OrganizationRepository