refactor: enhance type hinting and casting for improved type safety across multiple files
This commit is contained in:
parent
f234e60e65
commit
688ade0452
|
|
@ -6,7 +6,7 @@ import asyncio
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import redis.asyncio as redis
|
||||
from app.core.config import settings
|
||||
|
|
@ -45,10 +45,13 @@ class RedisCacheManager:
|
|||
async with self._lock:
|
||||
if self._client is not None:
|
||||
return
|
||||
self._client = redis.from_url(
|
||||
settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=False,
|
||||
self._client = cast(
|
||||
Redis,
|
||||
redis.from_url( # type: ignore[no-untyped-call]
|
||||
settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=False,
|
||||
),
|
||||
)
|
||||
await self._refresh_availability()
|
||||
|
||||
|
|
@ -64,10 +67,13 @@ class RedisCacheManager:
|
|||
return
|
||||
async with self._lock:
|
||||
if self._client is None:
|
||||
self._client = redis.from_url(
|
||||
settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=False,
|
||||
self._client = cast(
|
||||
Redis,
|
||||
redis.from_url( # type: ignore[no-untyped-call]
|
||||
settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=False,
|
||||
),
|
||||
)
|
||||
await self._refresh_availability()
|
||||
|
||||
|
|
@ -76,7 +82,7 @@ class RedisCacheManager:
|
|||
self._available = False
|
||||
return
|
||||
try:
|
||||
await self._client.ping()
|
||||
await cast(Awaitable[Any], self._client.ping())
|
||||
except RedisError as exc: # pragma: no cover - logging only
|
||||
self._available = False
|
||||
logger.warning("Redis ping failed: %s", exc)
|
||||
|
|
@ -140,8 +146,8 @@ async def write_json(
|
|||
"""Serialize data to JSON and store it with TTL using retry/backoff."""
|
||||
payload = json.dumps(value, separators=(",", ":"), ensure_ascii=True).encode("utf-8")
|
||||
|
||||
async def _operation() -> Any:
|
||||
return await client.set(name=key, value=payload, ex=ttl_seconds)
|
||||
async def _operation() -> None:
|
||||
await client.set(name=key, value=payload, ex=ttl_seconds)
|
||||
|
||||
await _run_with_retry(_operation, backoff_ms)
|
||||
|
||||
|
|
@ -151,8 +157,8 @@ async def delete_keys(client: Redis, keys: list[str], backoff_ms: int) -> None:
|
|||
if not keys:
|
||||
return
|
||||
|
||||
async def _operation() -> Any:
|
||||
return await client.delete(*keys)
|
||||
async def _operation() -> None:
|
||||
await client.delete(*keys)
|
||||
|
||||
await _run_with_retry(_operation, backoff_ms)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import jwt
|
||||
from app.core.config import settings
|
||||
|
|
@ -18,10 +18,10 @@ class PasswordHasher:
|
|||
self._context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
|
||||
|
||||
def hash(self, password: str) -> str:
|
||||
return self._context.hash(password)
|
||||
return cast(str, self._context.hash(password))
|
||||
|
||||
def verify(self, password: str, hashed_password: str) -> bool:
|
||||
return self._context.verify(password, hashed_password)
|
||||
return bool(self._context.verify(password, hashed_password))
|
||||
|
||||
|
||||
class JWTService:
|
||||
|
|
@ -45,10 +45,10 @@ class JWTService:
|
|||
}
|
||||
if claims:
|
||||
payload.update(claims)
|
||||
return jwt.encode(payload, self._secret_key, algorithm=self._algorithm)
|
||||
return cast(str, jwt.encode(payload, self._secret_key, algorithm=self._algorithm))
|
||||
|
||||
def decode(self, token: str) -> dict[str, Any]:
|
||||
return jwt.decode(token, self._secret_key, algorithms=[self._algorithm])
|
||||
return cast(dict[str, Any], jwt.decode(token, self._secret_key, algorithms=[self._algorithm]))
|
||||
|
||||
|
||||
password_hasher = PasswordHasher()
|
||||
|
|
|
|||
|
|
@ -10,9 +10,11 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||
from sqlalchemy import DateTime, ForeignKey, Integer, func, text
|
||||
from sqlalchemy import Enum as SqlEnum
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.engine import Dialect
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.types import JSON as SA_JSON
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
|
||||
from app.models.base import Base, enum_values
|
||||
|
||||
|
|
@ -31,7 +33,7 @@ class JSONBCompat(TypeDecorator):
|
|||
impl = JSONB
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect): # type: ignore[override]
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||
if dialect.name == "sqlite":
|
||||
from sqlalchemy.dialects.sqlite import JSON as SQLITE_JSON # local import
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class Base(DeclarativeBase):
|
|||
"""Base class that configures naming conventions."""
|
||||
|
||||
@declared_attr.directive
|
||||
def __tablename__(cls) -> str: # type: ignore[misc] # noqa: N805 - SQLAlchemy expects cls
|
||||
def __tablename__(cls) -> str: # noqa: N805 - SQLAlchemy expects cls
|
||||
return cls.__name__.lower()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import Select, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -123,6 +123,9 @@ class TaskRepository:
|
|||
|
||||
async def _resolve_task_owner(self, task: Task) -> int | None:
|
||||
if task.deal is not None:
|
||||
return task.deal.owner_id
|
||||
return int(task.deal.owner_id)
|
||||
stmt = select(Deal.owner_id).where(Deal.id == task.deal_id)
|
||||
return await self._session.scalar(stmt)
|
||||
owner_id_raw: Any = await self._session.scalar(stmt)
|
||||
if owner_id_raw is None:
|
||||
return None
|
||||
return cast(int, owner_id_raw)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
|
|
@ -151,11 +152,11 @@ class ContactService:
|
|||
def _build_update_mapping(self, updates: ContactUpdateData) -> dict[str, str | None]:
|
||||
payload: dict[str, str | None] = {}
|
||||
if updates.name is not UNSET:
|
||||
payload["name"] = updates.name
|
||||
payload["name"] = cast(str | None, updates.name)
|
||||
if updates.email is not UNSET:
|
||||
payload["email"] = updates.email
|
||||
payload["email"] = cast(str | None, updates.email)
|
||||
if updates.phone is not UNSET:
|
||||
payload["phone"] = updates.phone
|
||||
payload["phone"] = cast(str | None, updates.phone)
|
||||
return payload
|
||||
|
||||
async def _ensure_no_related_deals(self, contact_id: int) -> None:
|
||||
|
|
|
|||
|
|
@ -3,10 +3,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import cast
|
||||
|
||||
from app.models.activity import Activity, ActivityType
|
||||
from app.models.deal import Deal, DealStage, DealStatus
|
||||
from app.models.organization_member import OrganizationMember, OrganizationRole
|
||||
from sqlalchemy import Enum as SqlEnum
|
||||
|
||||
|
||||
def _values(enum_cls: type[StrEnum]) -> list[str]:
|
||||
|
|
@ -14,20 +16,20 @@ def _values(enum_cls: type[StrEnum]) -> list[str]:
|
|||
|
||||
|
||||
def test_organization_role_column_uses_value_strings() -> None:
|
||||
role_type = OrganizationMember.__table__.c.role.type # noqa: SLF001 - runtime inspection
|
||||
role_type = cast(SqlEnum, OrganizationMember.__table__.c.role.type) # noqa: SLF001
|
||||
assert role_type.enums == _values(OrganizationRole)
|
||||
|
||||
|
||||
def test_deal_status_column_uses_value_strings() -> None:
|
||||
status_type = Deal.__table__.c.status.type # noqa: SLF001 - runtime inspection
|
||||
status_type = cast(SqlEnum, Deal.__table__.c.status.type) # noqa: SLF001
|
||||
assert status_type.enums == _values(DealStatus)
|
||||
|
||||
|
||||
def test_deal_stage_column_uses_value_strings() -> None:
|
||||
stage_type = Deal.__table__.c.stage.type # noqa: SLF001 - runtime inspection
|
||||
stage_type = cast(SqlEnum, Deal.__table__.c.stage.type) # noqa: SLF001
|
||||
assert stage_type.enums == _values(DealStage)
|
||||
|
||||
|
||||
def test_activity_type_column_uses_value_strings() -> None:
|
||||
activity_type = Activity.__table__.c.type.type # noqa: SLF001 - runtime inspection
|
||||
activity_type = cast(SqlEnum, Activity.__table__.c.type.type) # noqa: SLF001
|
||||
assert activity_type.enums == _values(ActivityType)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from decimal import Decimal
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
|
@ -14,8 +15,13 @@ from app.models.deal import Deal, DealStage, DealStatus
|
|||
from app.models.organization import Organization
|
||||
from app.models.organization_member import OrganizationMember, OrganizationRole
|
||||
from app.models.user import User
|
||||
from app.repositories.analytics_repo import AnalyticsRepository
|
||||
from app.repositories.analytics_repo import (
|
||||
AnalyticsRepository,
|
||||
StageStatusRollup,
|
||||
StatusRollup,
|
||||
)
|
||||
from app.services.analytics_service import AnalyticsService, invalidate_analytics_cache
|
||||
from redis.asyncio.client import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from tests.utils.fake_redis import InMemoryRedis
|
||||
|
|
@ -170,20 +176,20 @@ async def test_funnel_breakdown_contains_stage_conversions(session: AsyncSession
|
|||
|
||||
|
||||
class _ExplodingRepository(AnalyticsRepository):
|
||||
async def fetch_status_rollup(self, organization_id: int): # type: ignore[override]
|
||||
async def fetch_status_rollup(self, organization_id: int) -> list[StatusRollup]:
|
||||
raise AssertionError("cache not used for status rollup")
|
||||
|
||||
async def count_new_deals_since(self, organization_id: int, threshold): # type: ignore[override]
|
||||
async def count_new_deals_since(self, organization_id: int, threshold: datetime) -> int:
|
||||
raise AssertionError("cache not used for new deal count")
|
||||
|
||||
async def fetch_stage_status_rollup(self, organization_id: int): # type: ignore[override]
|
||||
async def fetch_stage_status_rollup(self, organization_id: int) -> list[StageStatusRollup]:
|
||||
raise AssertionError("cache not used for funnel rollup")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summary_reads_from_cache_when_available(session: AsyncSession) -> None:
|
||||
org_id, _, _ = await _seed_data(session)
|
||||
cache = InMemoryRedis()
|
||||
cache = cast(Redis, InMemoryRedis())
|
||||
service = AnalyticsService(
|
||||
repository=AnalyticsRepository(session),
|
||||
cache=cache,
|
||||
|
|
@ -201,7 +207,7 @@ async def test_summary_reads_from_cache_when_available(session: AsyncSession) ->
|
|||
@pytest.mark.asyncio
|
||||
async def test_invalidation_refreshes_cached_summary(session: AsyncSession) -> None:
|
||||
org_id, _, contact_id = await _seed_data(session)
|
||||
cache = InMemoryRedis()
|
||||
cache = cast(Redis, InMemoryRedis())
|
||||
service = AnalyticsService(
|
||||
repository=AnalyticsRepository(session),
|
||||
cache=cache,
|
||||
|
|
@ -235,7 +241,7 @@ async def test_invalidation_refreshes_cached_summary(session: AsyncSession) -> N
|
|||
@pytest.mark.asyncio
|
||||
async def test_funnel_reads_from_cache_when_available(session: AsyncSession) -> None:
|
||||
org_id, _, _ = await _seed_data(session)
|
||||
cache = InMemoryRedis()
|
||||
cache = cast(Redis, InMemoryRedis())
|
||||
service = AnalyticsService(
|
||||
repository=AnalyticsRepository(session),
|
||||
cache=cache,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # type: ignore[import-not-found]
|
||||
import pytest
|
||||
from app.core.security import JWTService, PasswordHasher
|
||||
from app.models.user import User
|
||||
from app.repositories.user_repo import UserRepository
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import uuid
|
|||
from collections.abc import AsyncGenerator
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest # type: ignore[import-not-found]
|
||||
import pytest_asyncio # type: ignore[import-not-found]
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from app.models.activity import Activity, ActivityType
|
||||
from app.models.base import Base
|
||||
from app.models.contact import Contact
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # type: ignore[import-not-found]
|
||||
import pytest
|
||||
from app.models.organization import Organization
|
||||
from app.models.organization_member import OrganizationMember, OrganizationRole
|
||||
from app.repositories.org_repo import OrganizationRepository
|
||||
|
|
|
|||
Loading…
Reference in New Issue