"""Asynchronous base classes for AsyncAPIModel, AsyncQuerySet, and AsyncManager."""
from collections.abc import AsyncIterator
from typing import Any, Generic, TypeVar
from pydantic import BaseModel
from .async_client import AsyncServiceClient
from .exceptions import DoesNotExist, MultipleObjectsReturned
T = TypeVar("T", bound="AsyncAPIModel")
[docs]
class AsyncQuerySet(Generic[T]):
    """Django-like AsyncQuerySet for filtering and retrieving API resources.

    Provides lazy evaluation with result caching and chainable filter methods.
    Chaining methods (``filter``, ``exclude``, ``order_by``, slicing) are
    synchronous and return a new queryset; all methods that fetch data are
    async.

    Example:
        >>> queryset = Policy.objects.filter(status='active')
        >>> queryset = queryset.order_by('-created_at')
        >>> async for policy in queryset:  # Executes query here
        ...     print(policy.policy_number)
    """

    def __init__(self, model_class: type[T], manager: "AsyncManager[T]") -> None:
        """Initialize AsyncQuerySet.

        Args:
            model_class: The model class this QuerySet represents
            manager: The manager that created this QuerySet
        """
        self.model_class = model_class
        self.manager = manager
        # Query parameters
        self._filters: dict[str, Any] = {}
        self._excludes: dict[str, Any] = {}
        self._order_by_fields: list[str] = []
        self._limit: int | None = None
        self._offset: int | None = None
        # Result caching
        self._result_cache: list[T] | None = None
        self._fetched = False

    def _clone(self) -> "AsyncQuerySet[T]":
        """Create a copy of this QuerySet for chaining.

        The query parameters are copied; the result cache is not, so the
        clone always starts unfetched.

        Returns:
            A new AsyncQuerySet with the same parameters
        """
        qs = AsyncQuerySet(self.model_class, self.manager)
        qs._filters = self._filters.copy()
        qs._excludes = self._excludes.copy()
        qs._order_by_fields = self._order_by_fields.copy()
        qs._limit = self._limit
        qs._offset = self._offset
        return qs

    def _build_params(self) -> dict[str, Any]:
        """Build query parameters for the API request.

        Filters are passed through under their own names, excludes are
        prefixed with ``exclude_``, ordering is comma-joined under
        ``ordering`` and pagination uses ``limit`` / ``offset``.

        Returns:
            Dictionary of query parameters
        """
        params: dict[str, Any] = dict(self._filters)
        for key, value in self._excludes.items():
            params[f"exclude_{key}"] = value
        if self._order_by_fields:
            params["ordering"] = ",".join(self._order_by_fields)
        if self._limit is not None:
            params["limit"] = self._limit
        if self._offset is not None:
            params["offset"] = self._offset
        return params

    def __getitem__(self, key: int | slice) -> "AsyncQuerySet[T]":
        """Support slicing (indexing not supported for async).

        Args:
            key: Slice object (int indexing not supported in async context)

        Returns:
            AsyncQuerySet with limits applied

        Raises:
            TypeError: If ``key`` is an int or any other non-slice object
            ValueError: If the slice has a negative bound or a step other
                than 1; neither can be expressed as limit/offset

        Example:
            >>> subset = Policy.objects.all()[10:20]  # Get slice
            >>> async for policy in subset:
            ...     print(policy.policy_number)

        Note:
            Integer indexing like `queryset[0]` is not supported for async querysets.
            Use `await queryset.first()` instead.
        """
        if isinstance(key, int):
            raise TypeError(
                "Integer indexing is not supported for async querysets. "
                "Use 'await queryset.first()' or 'await queryset.last()' instead."
            )
        if not isinstance(key, slice):
            raise TypeError("AsyncQuerySet indices must be slices")
        if key.step is not None and key.step != 1:
            raise ValueError("AsyncQuerySet slices do not support a step")
        if (key.start is not None and key.start < 0) or (
            key.stop is not None and key.stop < 0
        ):
            raise ValueError("Negative indexing is not supported")

        qs = self._clone()
        if key.start is not None:
            qs._offset = key.start
        if key.stop is not None:
            # max() keeps an inverted slice such as [20:10] empty instead of
            # producing a negative limit.
            qs._limit = max(key.stop - (key.start or 0), 0)
        return qs

    async def _fetch(self) -> None:
        """Execute the query and cache results.

        Does nothing if this queryset has already been fetched. The response
        payload may be a dict with a ``results`` key, a plain list, or a
        single object; each item is converted with ``model_class.from_api``.
        """
        if self._fetched:
            return

        if self._limit == 0:
            # An empty slice can never match anything, so skip the request.
            self._result_cache = []
            self._fetched = True
            return

        params = self._build_params()
        response = await self.manager.client.get(self.manager.get_endpoint(), params=params)
        data = response.data

        if data is None:
            # No payload at all: treat as "no results" rather than wrapping
            # None in a list and failing inside from_api.
            results = []
        elif isinstance(data, dict) and "results" in data:
            # Paginated response
            results = data["results"] or []
        elif isinstance(data, list):
            results = data
        else:
            results = [data]

        self._result_cache = [
            self.model_class.from_api(item, client=self.manager.client)  # type: ignore[misc]
            for item in results
        ]
        self._fetched = True

    # Filtering methods (sync - return new QuerySet)

    def filter(self, **kwargs: Any) -> "AsyncQuerySet[T]":
        """Filter QuerySet by given parameters.

        Args:
            **kwargs: Field lookups

        Returns:
            New AsyncQuerySet with filters applied

        Example:
            >>> Policy.objects.filter(status='active', premium_amount__gte=1000)
        """
        qs = self._clone()
        qs._filters.update(kwargs)
        return qs

    def exclude(self, **kwargs: Any) -> "AsyncQuerySet[T]":
        """Exclude results matching given parameters.

        Args:
            **kwargs: Field lookups to exclude

        Returns:
            New AsyncQuerySet with exclusions applied

        Example:
            >>> Policy.objects.exclude(status='cancelled')
        """
        qs = self._clone()
        qs._excludes.update(kwargs)
        return qs

    def all(self) -> "AsyncQuerySet[T]":
        """Return a copy of this QuerySet.

        Returns:
            New AsyncQuerySet (clone)
        """
        return self._clone()

    def order_by(self, *fields: str) -> "AsyncQuerySet[T]":
        """Order results by given fields, replacing any previous ordering.

        Args:
            *fields: Field names (prefix with '-' for descending)

        Returns:
            New AsyncQuerySet with ordering applied

        Example:
            >>> Policy.objects.order_by('-created_at', 'policy_number')
        """
        qs = self._clone()
        qs._order_by_fields = list(fields)
        return qs

    # Retrieval methods (async)

    async def first(self) -> T | None:
        """Get the first result or None.

        Runs on a clone limited to one row, so this queryset's own cache is
        left untouched.

        Returns:
            First model instance or None

        Example:
            >>> policy = await Policy.objects.filter(status='active').first()
        """
        qs = self._clone()
        qs._limit = 1
        await qs._fetch()
        return qs._result_cache[0] if qs._result_cache else None

    async def last(self) -> T | None:
        """Get the last result or None.

        With an explicit ordering the ordering is reversed and a single row
        is requested. Without one, the query is fetched as-is and the last
        returned item is used.

        Returns:
            Last model instance or None
        """
        qs = self._clone()
        if qs._order_by_fields:
            qs._order_by_fields = [
                f[1:] if f.startswith("-") else f"-{f}" for f in qs._order_by_fields
            ]
            qs._limit = 1
        await qs._fetch()
        return qs._result_cache[-1] if qs._result_cache else None

    async def get(self, **kwargs: Any) -> T:
        """Get a single object matching the criteria.

        Args:
            **kwargs: Field lookups

        Returns:
            Single model instance

        Raises:
            DoesNotExist: If no results found
            MultipleObjectsReturned: If multiple results found

        Example:
            >>> policy = await Policy.objects.get(id=123)
        """
        qs = self.filter(**kwargs)
        qs._limit = 2  # Fetch 2 to detect multiple
        await qs._fetch()
        matches = qs._result_cache
        if not matches:
            raise DoesNotExist(f"{self.model_class.__name__} matching query does not exist")
        if len(matches) > 1:
            raise MultipleObjectsReturned(
                f"get() returned more than one {self.model_class.__name__}"
            )
        return matches[0]

    async def exists(self) -> bool:
        """Check if any results exist.

        Returns:
            True if results exist, False otherwise

        Example:
            >>> if await Policy.objects.filter(status='active').exists():
            ...     print("Active policies found")
        """
        qs = self._clone()
        qs._limit = 1
        await qs._fetch()
        return bool(qs._result_cache)

    async def count(self) -> int:
        """Get count of results.

        First asks the API with ``count_only=true`` and reads ``count`` or
        ``total`` from a dict response. If that request fails or yields no
        usable number, falls back to fetching this queryset and counting the
        cached results.

        Returns:
            Number of results

        Example:
            >>> count = await Policy.objects.filter(status='active').count()
        """
        params = self._build_params()
        params["count_only"] = "true"
        try:
            response = await self.manager.client.get(self.manager.get_endpoint(), params=params)
        except Exception:
            # Deliberate best-effort: the count shortcut is optional, and the
            # fallback fetch below will surface any persistent client error.
            response = None

        payload = getattr(response, "data", None)
        if isinstance(payload, dict):
            for key in ("count", "total"):
                if key in payload:
                    try:
                        return int(payload[key])
                    except (TypeError, ValueError):
                        # Unusable value: fall through to the fetch-based count.
                        break

        # Fallback: fetch and count
        await self._fetch()
        return len(self._result_cache or [])

    # Async iteration support

    async def __aiter__(self) -> AsyncIterator[T]:
        """Make QuerySet async iterable.

        Fetches on first use, then yields the cached model instances.

        Returns:
            Async iterator over model instances

        Example:
            >>> async for policy in Policy.objects.filter(status='active'):
            ...     print(policy.policy_number)
        """
        await self._fetch()
        for item in self._result_cache or []:
            yield item

    async def alen(self) -> int:
        """Get length of results (async version).

        Returns:
            Number of cached results after fetching
        """
        await self._fetch()
        return len(self._result_cache or [])

    # Value extraction methods (async)

    async def values(self, *fields: str) -> list[dict[str, Any]]:
        """Return list of dictionaries instead of model instances.

        Args:
            *fields: Field names to include (all if not specified)

        Returns:
            List of dictionaries

        Example:
            >>> policies = await Policy.objects.values('id', 'policy_number')
            >>> # [{'id': 1, 'policy_number': 'POL-001'}, ...]
        """
        await self._fetch()
        wanted = set(fields)
        rows: list[dict[str, Any]] = []
        for obj in self._result_cache or []:
            data = obj.to_dict()
            if wanted:
                data = {k: v for k, v in data.items() if k in wanted}
            rows.append(data)
        return rows

    async def values_list(self, *fields: str, flat: bool = False) -> list[Any]:
        """Return list of tuples instead of model instances.

        Args:
            *fields: Field names to include (all fields of each object, in
                ``to_dict()`` order, if not specified)
            flat: If True and one field, return flat list

        Returns:
            List of tuples (or flat list if flat=True)

        Raises:
            ValueError: If ``flat`` is True and not exactly one field is given

        Example:
            >>> ids = await Policy.objects.values_list('id', flat=True)
            >>> # [1, 2, 3, ...]
        """
        # Validate before any I/O so misuse fails regardless of what the
        # query happens to return.
        if flat and len(fields) != 1:
            raise ValueError("'flat' is only valid when one field is specified")

        await self._fetch()
        rows: list[Any] = []
        for obj in self._result_cache or []:
            data = obj.to_dict()
            if flat:
                rows.append(data.get(fields[0]))
            else:
                names = fields or tuple(data)
                rows.append(tuple(data.get(name) for name in names))
        return rows

    def __repr__(self) -> str:
        """String representation of QuerySet."""
        if self._result_cache is not None:
            return f"<AsyncQuerySet {list(self._result_cache)}>"
        return f"<AsyncQuerySet for {self.model_class.__name__}>"
[docs]
class AsyncManager(Generic[T]):
    """Django-like AsyncManager for model querying.

    Handles creation and retrieval of model instances using async methods.
    A manager is bound to one model class and one client.

    Example:
        >>> class Policy(AsyncAPIModel):
        ...     _schema_class = PolicySchema
        ...     _endpoint = "/api/v1/policies/"
        >>> Policy.objects = AsyncManager(Policy, async_client)
        >>> total = await Policy.objects.filter(status='active').alen()
    """

    def __init__(self, model_class: type[T], client: AsyncServiceClient) -> None:
        """Initialize AsyncManager.

        Args:
            model_class: The model class this manager handles
            client: Async HTTP client for API requests
        """
        self.model_class = model_class
        self.client = client

    def get_endpoint(self) -> str:
        """Get the API endpoint for this model.

        Returns:
            API endpoint path, as reported by the model class
        """
        return self.model_class.get_endpoint()

    @staticmethod
    def _build_create_data(
        lookup: dict[str, Any], defaults: dict[str, Any] | None
    ) -> dict[str, Any]:
        """Merge lookup kwargs and defaults into a creation payload.

        Keys containing ``__`` are lookup expressions (e.g. ``amount__gte``),
        not field names, so they are left out. ``defaults`` win on conflict.
        """
        data = {key: value for key, value in lookup.items() if "__" not in key}
        if defaults:
            data.update(defaults)
        return data

    # Query methods (return AsyncQuerySet)

    def all(self) -> AsyncQuerySet[T]:
        """Get all objects.

        Returns:
            A fresh, unfiltered AsyncQuerySet

        Example:
            >>> all_policies = Policy.objects.all()
            >>> async for policy in all_policies:
            ...     print(policy)
        """
        return AsyncQuerySet(self.model_class, self)

    def filter(self, **kwargs: Any) -> AsyncQuerySet[T]:
        """Filter objects by criteria.

        Args:
            **kwargs: Filter parameters

        Returns:
            Filtered AsyncQuerySet

        Example:
            >>> active_policies = Policy.objects.filter(status='active')
        """
        return self.all().filter(**kwargs)

    def exclude(self, **kwargs: Any) -> AsyncQuerySet[T]:
        """Exclude objects matching criteria.

        Args:
            **kwargs: Exclusion parameters

        Returns:
            Filtered AsyncQuerySet
        """
        return self.all().exclude(**kwargs)

    def order_by(self, *fields: str) -> AsyncQuerySet[T]:
        """Order results by given fields.

        Args:
            *fields: Field names to order by (prefix with '-' for descending)

        Returns:
            Ordered AsyncQuerySet

        Example:
            >>> policies = Policy.objects.order_by('-created_at')
        """
        return self.all().order_by(*fields)

    # Async retrieval methods

    async def get(self, **kwargs: Any) -> T:
        """Get a single object.

        Args:
            **kwargs: Lookup parameters

        Returns:
            Single model instance

        Raises:
            DoesNotExist: If not found
            MultipleObjectsReturned: If multiple found

        Example:
            >>> policy = await Policy.objects.get(id=123)
        """
        return await self.all().get(**kwargs)

    async def first(self) -> T | None:
        """Get the first object or None.

        Returns:
            First model instance or None
        """
        return await self.all().first()

    async def last(self) -> T | None:
        """Get the last object or None.

        Returns:
            Last model instance or None
        """
        return await self.all().last()

    async def exists(self) -> bool:
        """Check if any objects exist.

        Returns:
            True if results exist, False otherwise
        """
        return await self.all().exists()

    async def count(self) -> int:
        """Count the number of objects.

        Returns:
            Number of objects
        """
        return await self.all().count()

    async def values(self, *fields: str) -> list[dict[str, Any]]:
        """Return list of dictionaries instead of model instances.

        Args:
            *fields: Field names to include (all if not specified)

        Returns:
            List of dictionaries

        Example:
            >>> policies = await Policy.objects.values('id', 'policy_number')
        """
        return await self.all().values(*fields)

    async def values_list(self, *fields: str, flat: bool = False) -> list[Any]:
        """Return list of tuples instead of model instances.

        Args:
            *fields: Field names to include
            flat: If True and one field, return flat list

        Returns:
            List of tuples (or flat list if flat=True)

        Example:
            >>> ids = await Policy.objects.values_list('id', flat=True)
        """
        return await self.all().values_list(*fields, flat=flat)

    # Async creation methods

    async def create(self, **kwargs: Any) -> T:
        """Create a new object in the API.

        The kwargs are validated through the model's Pydantic schema, and
        only explicitly provided fields are POSTed to the model endpoint.

        Args:
            **kwargs: Field values

        Returns:
            Created model instance built from the API response

        Example:
            >>> policy = await Policy.objects.create(
            ...     policy_number='POL-001',
            ...     premium_amount=1500.00
            ... )
        """
        schema_class = self.model_class.get_schema_class()
        validated = schema_class(**kwargs)
        payload = validated.model_dump(mode="json", exclude_unset=True)
        response = await self.client.post(self.get_endpoint(), data=payload)
        return self.model_class.from_api(response.data, client=self.client)  # type: ignore[return-value]

    async def get_or_create(
        self, defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> tuple[T, bool]:
        """Get an existing object or create a new one.

        Args:
            defaults: Values to use when creating
            **kwargs: Lookup parameters; plain field lookups are also used
                as field values when creating

        Returns:
            Tuple of (object, created) where created is a boolean

        Raises:
            MultipleObjectsReturned: If the lookup matches several objects

        Example:
            >>> policy, created = await Policy.objects.get_or_create(
            ...     policy_number='POL-001',
            ...     defaults={'premium_amount': 1500.00}
            ... )
        """
        try:
            return await self.get(**kwargs), False
        except DoesNotExist:
            obj = await self.create(**self._build_create_data(kwargs, defaults))
            return obj, True

    async def update_or_create(
        self, defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> tuple[T, bool]:
        """Update an existing object or create a new one.

        Args:
            defaults: Values to update/create with
            **kwargs: Lookup parameters; plain field lookups are also used
                as field values when creating

        Returns:
            Tuple of (object, created) where created is a boolean

        Raises:
            MultipleObjectsReturned: If the lookup matches several objects

        Example:
            >>> policy, created = await Policy.objects.update_or_create(
            ...     policy_number='POL-001',
            ...     defaults={'premium_amount': 2000.00}
            ... )
        """
        # Only the lookup sits inside the try: a DoesNotExist raised while
        # saving must not be mistaken for "not found" and trigger a create.
        try:
            obj = await self.get(**kwargs)
        except DoesNotExist:
            obj = await self.create(**self._build_create_data(kwargs, defaults))
            return obj, True

        if defaults:
            for key, value in defaults.items():
                setattr(obj, key, value)
            await obj.save(update_fields=list(defaults))
        return obj, False

    async def bulk_create(self, objs: list[dict[str, Any]]) -> list[T]:
        """Create multiple objects in bulk.

        Objects are created one at a time, in order, through ``create``;
        an error on one item stops the remaining ones.

        Args:
            objs: List of dictionaries with object data

        Returns:
            List of created model instances

        Example:
            >>> policies = await Policy.objects.bulk_create([
            ...     {'policy_number': 'POL-001', 'premium_amount': 1500},
            ...     {'policy_number': 'POL-002', 'premium_amount': 2000},
            ... ])
        """
        return [await self.create(**obj_data) for obj_data in objs]
[docs]
class AsyncAPIModel:
    """Base class for async API models.

    Provides Django ORM-like interface for working with API resources asynchronously.
    Must be subclassed with _schema_class and _endpoint defined.

    Each instance wraps a validated Pydantic instance and mirrors its fields
    as plain attributes; assigning to a field re-validates the wrapped data.

    Example:
        >>> class PolicySchema(BaseModel):
        ...     id: int | None = None
        ...     policy_number: str
        ...     premium_amount: float
        >>>
        >>> class Policy(AsyncAPIModel):
        ...     _schema_class = PolicySchema
        ...     _endpoint = "/api/v1/policies/"
        >>>
        >>> # Register with async client
        >>> Policy.objects = AsyncManager(Policy, async_client)
    """

    # Class attributes (to be overridden in subclasses)
    _schema_class: type[BaseModel]
    _endpoint: str
    objects: AsyncManager["AsyncAPIModel"]

    def __init__(self, **kwargs: Any) -> None:
        """Initialize model instance.

        Args:
            **kwargs: Field values, or ``_pydantic_instance`` with an already
                validated schema instance. ``_client`` may be supplied in
                either case and is never passed to the schema.
        """
        # Pull the non-field entries out first so schema validation only sees
        # real field values (a schema with extra='forbid' would reject
        # '_client').
        client = kwargs.pop("_client", None)
        instance = kwargs.pop("_pydantic_instance", None)
        if instance is None:
            instance = self.get_schema_class()(**kwargs)
        self._client: AsyncServiceClient | None = client
        self._sync_from_pydantic(instance)

    def _sync_from_pydantic(self, instance: BaseModel) -> None:
        """Adopt ``instance`` and mirror its fields as attributes.

        Uses ``object.__setattr__`` so the values are stored as-is
        (preserving nested Pydantic models) without re-validating per field.
        """
        object.__setattr__(self, "_pydantic_instance", instance)
        # model_fields is read from the class; instance access is deprecated
        # in recent Pydantic releases.
        for key in type(instance).model_fields:
            object.__setattr__(self, key, getattr(instance, key))

    def __setattr__(self, name: str, value: Any) -> None:
        """Set attribute and update internal pydantic instance.

        For schema fields the new value is validated first, by rebuilding the
        Pydantic instance, and only then stored. A validation error therefore
        leaves the object unchanged. Only explicitly-set fields are carried
        over, so fields that were never set stay "unset" for a later
        ``to_dict(exclude_unset=True)``.

        Args:
            name: Attribute name
            value: Attribute value
        """
        if name not in ("_pydantic_instance", "_client"):
            current = getattr(self, "_pydantic_instance", None)
            if current is not None and name in type(current).model_fields:
                data = current.model_dump(exclude_unset=True)
                data[name] = value
                validated = self.get_schema_class()(**data)
                object.__setattr__(self, "_pydantic_instance", validated)
        object.__setattr__(self, name, value)

    @classmethod
    def get_endpoint(cls) -> str:
        """Get the API endpoint for this model.

        Returns:
            API endpoint path
        """
        return cls._endpoint

    @classmethod
    def get_schema_class(cls) -> type[BaseModel]:
        """Get the Pydantic schema class for this model.

        Returns:
            Pydantic BaseModel class
        """
        return cls._schema_class

    @classmethod
    def from_api(
        cls, data: dict[str, Any], client: AsyncServiceClient | None = None
    ) -> "AsyncAPIModel":
        """Create model instance from API response data.

        Args:
            data: API response data, validated through the schema class
            client: Async HTTP client

        Returns:
            Model instance
        """
        pydantic_instance = cls.get_schema_class()(**data)
        return cls(_pydantic_instance=pydantic_instance, _client=client)

    def to_dict(self, exclude_unset: bool = False, exclude_none: bool = False) -> dict[str, Any]:
        """Convert model to dictionary.

        Args:
            exclude_unset: Exclude fields that weren't explicitly set
            exclude_none: Exclude fields with None values

        Returns:
            Dictionary representation, dumped in JSON mode
        """
        return self._pydantic_instance.model_dump(
            mode="json", exclude_unset=exclude_unset, exclude_none=exclude_none
        )

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        """Dump the wrapped Pydantic instance, for Pydantic compatibility.

        Unlike ``to_dict``, no dump mode is forced; kwargs are passed through.

        Args:
            **kwargs: Arguments to pass to model_dump

        Returns:
            Dictionary representation
        """
        return self._pydantic_instance.model_dump(**kwargs)

    async def save(self, update_fields: list[str] | None = None) -> None:
        """Save this instance to the API (async).

        Creates (POST) or updates (PATCH on ``<endpoint><id>/``) based on
        whether the object has a truthy ``id``. Only explicitly-set fields
        are sent. The instance is then refreshed from the response.

        Args:
            update_fields: List of fields to update (None = all fields)

        Raises:
            APIException: If client is not set

        Example:
            >>> policy = Policy(policy_number='POL-001', premium_amount=1500)
            >>> await policy.save()  # Creates new
            >>> policy.premium_amount = 2000
            >>> await policy.save(update_fields=['premium_amount'])  # Updates
        """
        from .exceptions import APIException

        if not self._client:
            raise APIException("Cannot save: no client configured for this instance")

        data = self.to_dict(exclude_unset=True)
        schema_class = self.get_schema_class()
        if update_fields:
            # Partial update: skip local validation, the API will validate.
            wanted = set(update_fields)
            data = {k: v for k, v in data.items() if k in wanted}
        else:
            # Validate full data
            schema_class(**data)

        object_id = getattr(self, "id", None)
        if object_id:
            endpoint = f"{self.get_endpoint()}{object_id}/"
            response = await self._client.patch(endpoint, data=data)
        else:
            response = await self._client.post(self.get_endpoint(), data=data)

        self._sync_from_pydantic(schema_class(**response.data))

    async def delete(self) -> None:
        """Delete this instance from the API (async).

        Raises:
            APIException: If client is not set or object has no ID

        Example:
            >>> policy = await Policy.objects.get(id=123)
            >>> await policy.delete()
        """
        from .exceptions import APIException

        if not self._client:
            raise APIException("Cannot delete: no client configured")
        object_id = getattr(self, "id", None)
        if not object_id:
            raise APIException("Cannot delete: object has no id")
        await self._client.delete(f"{self.get_endpoint()}{object_id}/")

    async def refresh_from_api(self) -> None:
        """Refresh this instance's data from the API (async).

        Raises:
            APIException: If client is not set or object has no ID

        Example:
            >>> policy = await Policy.objects.get(id=123)
            >>> # ... time passes, data may have changed ...
            >>> await policy.refresh_from_api()  # Reload from API
        """
        from .exceptions import APIException

        if not self._client:
            raise APIException("Cannot refresh: no client configured")
        object_id = getattr(self, "id", None)
        if not object_id:
            raise APIException("Cannot refresh: object has no id")
        response = await self._client.get(f"{self.get_endpoint()}{object_id}/")
        self._sync_from_pydantic(self.get_schema_class()(**response.data))

    def __repr__(self) -> str:
        """String representation."""
        return f"<{self.__class__.__name__}: {self._pydantic_instance}>"

    def __str__(self) -> str:
        """Human-readable string."""
        return repr(self)
[docs]
def register_async_models(client: AsyncServiceClient, *model_classes: type[AsyncAPIModel]) -> None:
    """Bind each given model class to ``client``.

    Every model class gets its own new AsyncManager, built from that class
    and the shared client, stored on the class's ``objects`` attribute.

    Args:
        client: AsyncServiceClient instance
        *model_classes: Model classes to register

    Example:
        >>> async_client = AsyncServiceClient(base_url="https://api.example.com")
        >>> register_async_models(async_client, Policy, Claim, Broker)
        >>> policy = await Policy.objects.get(id=123)
    """
    for registered in model_classes:
        manager = AsyncManager(registered, client)
        setattr(registered, "objects", manager)