From 688ade04520abb8a9fa71983359e6c72fc889dc2 Mon Sep 17 00:00:00 2001 From: Artem Kashaev Date: Mon, 1 Dec 2025 16:44:14 +0500 Subject: [PATCH] refactor: enhance type hinting and casting for improved type safety across multiple files --- app/core/cache.py | 34 ++++++++++++--------- app/core/security.py | 10 +++--- app/models/activity.py | 4 ++- app/models/base.py | 2 +- app/repositories/task_repo.py | 9 ++++-- app/services/contact_service.py | 7 +++-- tests/__init__.py | 0 tests/api/__init__.py | 0 tests/api/v1/__init__.py | 0 tests/models/test_enums.py | 10 +++--- tests/services/test_analytics_service.py | 20 +++++++----- tests/services/test_auth_service.py | 2 +- tests/services/test_deal_service.py | 4 +-- tests/services/test_organization_service.py | 2 +- 14 files changed, 62 insertions(+), 42 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/api/__init__.py create mode 100644 tests/api/v1/__init__.py diff --git a/app/core/cache.py b/app/core/cache.py index 8dd004b..b9f1f06 100644 --- a/app/core/cache.py +++ b/app/core/cache.py @@ -6,7 +6,7 @@ import asyncio import json import logging from collections.abc import Awaitable, Callable -from typing import Any +from typing import Any, cast import redis.asyncio as redis from app.core.config import settings @@ -45,10 +45,13 @@ class RedisCacheManager: async with self._lock: if self._client is not None: return - self._client = redis.from_url( - settings.redis_url, - encoding="utf-8", - decode_responses=False, + self._client = cast( + Redis, + redis.from_url( # type: ignore[no-untyped-call] + settings.redis_url, + encoding="utf-8", + decode_responses=False, + ), ) await self._refresh_availability() @@ -64,10 +67,13 @@ class RedisCacheManager: return async with self._lock: if self._client is None: - self._client = redis.from_url( - settings.redis_url, - encoding="utf-8", - decode_responses=False, + self._client = cast( + Redis, + redis.from_url( # type: ignore[no-untyped-call] + settings.redis_url, + encoding="utf-8", + decode_responses=False, + ), ) await self._refresh_availability() @@ -76,7 +82,7 @@ class RedisCacheManager: self._available = False return try: - await self._client.ping() + await cast(Awaitable[Any], self._client.ping()) except RedisError as exc: # pragma: no cover - logging only self._available = False logger.warning("Redis ping failed: %s", exc) @@ -140,8 +146,8 @@ async def write_json( """Serialize data to JSON and store it with TTL using retry/backoff.""" payload = json.dumps(value, separators=(",", ":"), ensure_ascii=True).encode("utf-8") - async def _operation() -> Any: - return await client.set(name=key, value=payload, ex=ttl_seconds) + async def _operation() -> None: + await client.set(name=key, value=payload, ex=ttl_seconds) await _run_with_retry(_operation, backoff_ms) @@ -151,8 +157,8 @@ async def delete_keys(client: Redis, keys: list[str], backoff_ms: int) -> None: if not keys: return - async def _operation() -> Any: - return await client.delete(*keys) + async def _operation() -> None: + await client.delete(*keys) await _run_with_retry(_operation, backoff_ms) diff --git a/app/core/security.py b/app/core/security.py index e2daf66..91bb5cd 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -4,7 +4,7 @@ from __future__ import annotations from collections.abc import Mapping from datetime import datetime, timedelta, timezone -from typing import Any +from typing import Any, cast import jwt from app.core.config import settings @@ -18,10 +18,10 @@ class PasswordHasher: self._context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto") def hash(self, password: str) -> str: - return self._context.hash(password) + return cast(str, self._context.hash(password)) def verify(self, password: str, hashed_password: str) -> bool: - return self._context.verify(password, hashed_password) + return bool(self._context.verify(password, hashed_password)) class JWTService: @@ -45,10 +45,10 @@ class JWTService: } if claims: payload.update(claims) - return jwt.encode(payload, self._secret_key, algorithm=self._algorithm) + return cast(str, jwt.encode(payload, self._secret_key, algorithm=self._algorithm)) def decode(self, token: str) -> dict[str, Any]: - return jwt.decode(token, self._secret_key, algorithms=[self._algorithm]) + return cast(dict[str, Any], jwt.decode(token, self._secret_key, algorithms=[self._algorithm])) password_hasher = PasswordHasher() diff --git a/app/models/activity.py b/app/models/activity.py index f344c70..b907ed2 100644 --- a/app/models/activity.py +++ b/app/models/activity.py @@ -10,9 +10,11 @@ from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import DateTime, ForeignKey, Integer, func, text from sqlalchemy import Enum as SqlEnum from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.engine import Dialect from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.types import JSON as SA_JSON from sqlalchemy.types import TypeDecorator +from sqlalchemy.sql.type_api import TypeEngine from app.models.base import Base, enum_values @@ -31,7 +33,7 @@ class JSONBCompat(TypeDecorator): impl = JSONB cache_ok = True - def load_dialect_impl(self, dialect): # type: ignore[override] + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name == "sqlite": from sqlalchemy.dialects.sqlite import JSON as SQLITE_JSON # local import diff --git a/app/models/base.py b/app/models/base.py index 9717469..4cda063 100644 --- a/app/models/base.py +++ b/app/models/base.py @@ -14,7 +14,7 @@ class Base(DeclarativeBase): """Base class that configures naming conventions.""" @declared_attr.directive - def __tablename__(cls) -> str: # type: ignore[misc] # noqa: N805 - SQLAlchemy expects cls + def __tablename__(cls) -> str: # noqa: N805 - SQLAlchemy expects cls return cls.__name__.lower() diff --git a/app/repositories/task_repo.py b/app/repositories/task_repo.py index af26968..784b217 100644 --- a/app/repositories/task_repo.py +++ b/app/repositories/task_repo.py @@ -5,7 +5,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import datetime -from typing import Any +from typing import Any, cast from sqlalchemy import Select, select from sqlalchemy.ext.asyncio import AsyncSession @@ -123,6 +123,9 @@ class TaskRepository: async def _resolve_task_owner(self, task: Task) -> int | None: if task.deal is not None: - return task.deal.owner_id + return int(task.deal.owner_id) stmt = select(Deal.owner_id).where(Deal.id == task.deal_id) - return await self._session.scalar(stmt) + owner_id_raw: Any = await self._session.scalar(stmt) + if owner_id_raw is None: + return None + return cast(int, owner_id_raw) diff --git a/app/services/contact_service.py b/app/services/contact_service.py index 79f7554..6b8a0a4 100644 --- a/app/services/contact_service.py +++ b/app/services/contact_service.py @@ -4,6 +4,7 @@ from __future__ import annotations from collections.abc import Sequence from dataclasses import dataclass +from typing import cast from sqlalchemy import select @@ -151,11 +152,11 @@ class ContactService: def _build_update_mapping(self, updates: ContactUpdateData) -> dict[str, str | None]: payload: dict[str, str | None] = {} if updates.name is not UNSET: - payload["name"] = updates.name + payload["name"] = cast(str | None, updates.name) if updates.email is not UNSET: - payload["email"] = updates.email + payload["email"] = cast(str | None, updates.email) if updates.phone is not UNSET: - payload["phone"] = updates.phone + payload["phone"] = cast(str | None, updates.phone) return payload async def _ensure_no_related_deals(self, contact_id: int) -> None: diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/api/v1/__init__.py b/tests/api/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/test_enums.py b/tests/models/test_enums.py index 3fbe4f0..7c5f1fd 100644 --- a/tests/models/test_enums.py +++ b/tests/models/test_enums.py @@ -3,10 +3,12 @@ from __future__ import annotations from enum import StrEnum +from typing import cast from app.models.activity import Activity, ActivityType from app.models.deal import Deal, DealStage, DealStatus from app.models.organization_member import OrganizationMember, OrganizationRole +from sqlalchemy import Enum as SqlEnum def _values(enum_cls: type[StrEnum]) -> list[str]: @@ -14,20 +16,20 @@ def _values(enum_cls: type[StrEnum]) -> list[str]: def test_organization_role_column_uses_value_strings() -> None: - role_type = OrganizationMember.__table__.c.role.type # noqa: SLF001 - runtime inspection + role_type = cast(SqlEnum, OrganizationMember.__table__.c.role.type) # noqa: SLF001 assert role_type.enums == _values(OrganizationRole) def test_deal_status_column_uses_value_strings() -> None: - status_type = Deal.__table__.c.status.type # noqa: SLF001 - runtime inspection + status_type = cast(SqlEnum, Deal.__table__.c.status.type) # noqa: SLF001 assert status_type.enums == _values(DealStatus) def test_deal_stage_column_uses_value_strings() -> None: - stage_type = Deal.__table__.c.stage.type # noqa: SLF001 - runtime inspection + stage_type = cast(SqlEnum, Deal.__table__.c.stage.type) # noqa: SLF001 assert stage_type.enums == _values(DealStage) def test_activity_type_column_uses_value_strings() -> None: - activity_type = Activity.__table__.c.type.type # noqa: SLF001 - runtime inspection + activity_type = cast(SqlEnum, Activity.__table__.c.type.type) # noqa: SLF001 assert activity_type.enums == _values(ActivityType) diff --git a/tests/services/test_analytics_service.py b/tests/services/test_analytics_service.py index 93512ba..a839907 100644 --- a/tests/services/test_analytics_service.py +++ b/tests/services/test_analytics_service.py @@ -5,6 +5,7 @@ from __future__ import annotations from collections.abc import AsyncGenerator from datetime import datetime, timedelta, timezone from decimal import Decimal +from typing import cast import pytest import pytest_asyncio @@ -14,8 +15,13 @@ from app.models.deal import Deal, DealStage, DealStatus from app.models.organization import Organization from app.models.organization_member import OrganizationMember, OrganizationRole from app.models.user import User -from app.repositories.analytics_repo import AnalyticsRepository +from app.repositories.analytics_repo import ( + AnalyticsRepository, + StageStatusRollup, + StatusRollup, +) from app.services.analytics_service import AnalyticsService, invalidate_analytics_cache +from redis.asyncio.client import Redis from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.pool import StaticPool from tests.utils.fake_redis import InMemoryRedis @@ -170,20 +176,20 @@ async def test_funnel_breakdown_contains_stage_conversions(session: AsyncSession class _ExplodingRepository(AnalyticsRepository): - async def fetch_status_rollup(self, organization_id: int): # type: ignore[override] + async def fetch_status_rollup(self, organization_id: int) -> list[StatusRollup]: raise AssertionError("cache not used for status rollup") - async def count_new_deals_since(self, organization_id: int, threshold): # type: ignore[override] + async def count_new_deals_since(self, organization_id: int, threshold: datetime) -> int: raise AssertionError("cache not used for new deal count") - async def fetch_stage_status_rollup(self, organization_id: int): # type: ignore[override] + async def fetch_stage_status_rollup(self, organization_id: int) -> list[StageStatusRollup]: raise AssertionError("cache not used for funnel rollup") @pytest.mark.asyncio async def test_summary_reads_from_cache_when_available(session: AsyncSession) -> None: org_id, _, _ = await _seed_data(session) - cache = InMemoryRedis() + cache = cast(Redis, InMemoryRedis()) service = AnalyticsService( repository=AnalyticsRepository(session), cache=cache, @@ -201,7 +207,7 @@ async def test_summary_reads_from_cache_when_available(session: AsyncSession) -> @pytest.mark.asyncio async def test_invalidation_refreshes_cached_summary(session: AsyncSession) -> None: org_id, _, contact_id = await _seed_data(session) - cache = InMemoryRedis() + cache = cast(Redis, InMemoryRedis()) service = AnalyticsService( repository=AnalyticsRepository(session), cache=cache, @@ -235,7 +241,7 @@ async def test_invalidation_refreshes_cached_summary(session: AsyncSession) -> N @pytest.mark.asyncio async def test_funnel_reads_from_cache_when_available(session: AsyncSession) -> None: org_id, _, _ = await _seed_data(session) - cache = InMemoryRedis() + cache = cast(Redis, InMemoryRedis()) service = AnalyticsService( repository=AnalyticsRepository(session), cache=cache, diff --git a/tests/services/test_auth_service.py b/tests/services/test_auth_service.py index 5a23a58..14bd5d7 100644 --- a/tests/services/test_auth_service.py +++ b/tests/services/test_auth_service.py @@ -5,7 +5,7 @@ from __future__ import annotations from typing import cast from unittest.mock import MagicMock -import pytest # type: ignore[import-not-found] +import pytest from app.core.security import JWTService, PasswordHasher from app.models.user import User from app.repositories.user_repo import UserRepository diff --git a/tests/services/test_deal_service.py b/tests/services/test_deal_service.py index 11d525d..aba33fd 100644 --- a/tests/services/test_deal_service.py +++ b/tests/services/test_deal_service.py @@ -6,8 +6,8 @@ import uuid from collections.abc import AsyncGenerator from decimal import Decimal -import pytest # type: ignore[import-not-found] -import pytest_asyncio # type: ignore[import-not-found] +import pytest +import pytest_asyncio from app.models.activity import Activity, ActivityType from app.models.base import Base from app.models.contact import Contact diff --git a/tests/services/test_organization_service.py b/tests/services/test_organization_service.py index 73bf21c..22218b8 100644 --- a/tests/services/test_organization_service.py +++ b/tests/services/test_organization_service.py @@ -5,7 +5,7 @@ from __future__ import annotations from typing import cast from unittest.mock import MagicMock -import pytest # type: ignore[import-not-found] +import pytest from app.models.organization import Organization from app.models.organization_member import OrganizationMember, OrganizationRole from app.repositories.org_repo import OrganizationRepository