|
from __future__ import annotations |
|
|
|
from abc import abstractmethod |
|
from typing import ( |
|
TYPE_CHECKING, |
|
Any, |
|
ClassVar, |
|
Iterator, |
|
Mapping, |
|
Protocol, |
|
Sized, |
|
TypeVar, |
|
overload, |
|
) |
|
|
|
if TYPE_CHECKING: |
|
from wandb_graphql.language.ast import Document |
|
|
|
# Type of the objects a Paginator yields (the element type of Paginator.objects).
T = TypeVar("T")
|
|
|
|
|
|
|
class _Client(Protocol): |
|
def execute(self, *args: Any, **kwargs: Any) -> dict[str, Any]: ... |
|
|
|
|
|
class Paginator(Iterator[T]):
    """An iterator for paginated objects from GraphQL requests.

    Subclasses set ``QUERY`` and implement ``more``, ``cursor`` and
    ``convert_objects``. Pages are fetched lazily, either while iterating or
    when indexing past the objects loaded so far. Every converted object is
    cached in ``self.objects``.
    """

    # GraphQL document executed for every page fetch; expected to be set by subclasses.
    QUERY: ClassVar[Document | None] = None

    def __init__(
        self,
        client: _Client,
        variables: Mapping[str, Any],
        per_page: int = 50,
    ):
        """Create a paginator.

        Args:
            client: Object whose ``execute`` method runs ``QUERY``.
            variables: Initial query variables. They are copied, so later
                page updates never mutate the caller's mapping.
            per_page: Page size, sent as the ``perPage`` query variable.
        """
        self.client: _Client = client
        self.variables: dict[str, Any] = dict(variables)
        self.per_page: int = per_page
        # All objects converted so far, across every fetched page.
        self.objects: list[T] = []
        # Position of the most recently returned object; -1 means "before the first".
        self.index: int = -1
        # Result of the latest client.execute() call; None until a page is fetched.
        self.last_response: object | None = None

    def __iter__(self) -> Iterator[T]:
        """Rewind to the first cached object and return self.

        Already-loaded objects are reused, not refetched.
        """
        self.index = -1
        return self

    @property
    @abstractmethod
    def more(self) -> bool:
        """Whether there are more pages to be fetched."""
        raise NotImplementedError

    @property
    @abstractmethod
    def cursor(self) -> str | None:
        """The start cursor to use for the next fetched page."""
        raise NotImplementedError

    @abstractmethod
    def convert_objects(self) -> list[T]:
        """Convert the last fetched response data into the iterated objects."""
        raise NotImplementedError

    def update_variables(self) -> None:
        """Update the query variables for the next page fetch."""
        self.variables.update({"perPage": self.per_page, "cursor": self.cursor})

    def _update_response(self) -> None:
        """Fetch and store the response data for the next page."""
        self.last_response = self.client.execute(
            self.QUERY, variable_values=self.variables
        )

    def _load_page(self) -> bool:
        """Fetch the next page, if any, returning True and storing the response if there was one."""
        if not self.more:
            return False
        self.update_variables()
        self._update_response()
        self.objects.extend(self.convert_objects())
        return True

    @staticmethod
    def _required_length(index: int | slice) -> int | None:
        """Return how many loaded objects are needed to resolve ``index``.

        Returns None when the answer depends on the total object count:
        negative positions, or a slice whose far end is open. In that case
        every remaining page must be loaded.
        """
        if isinstance(index, slice):
            step = 1 if index.step is None else index.step
            if step == 0:
                # Invalid slice; let list indexing raise ValueError without fetching.
                return 0
            if step > 0:
                start = 0 if index.start is None else index.start
                stop = index.stop
                if stop is None or start < 0 or stop < 0:
                    return None
                return stop
            # Negative step: the highest position touched is ``start``.
            start, stop = index.start, index.stop
            if start is None or start < 0 or (stop is not None and stop < 0):
                return None
            return start + 1
        if index < 0:
            return None
        return index + 1

    @overload
    def __getitem__(self, index: int) -> T: ...
    @overload
    def __getitem__(self, index: slice) -> list[T]: ...

    def __getitem__(self, index: int | slice) -> T | list[T]:
        """Return the object(s) at ``index``, fetching pages as needed.

        Supports ints and slices. Negative positions and open-ended slices
        (e.g. ``p[5:]``) load all remaining pages first, so they resolve
        against the full result set rather than only the loaded part.

        Raises:
            IndexError: if an int index is out of range after loading.
        """
        needed = self._required_length(index)
        while needed is None or len(self.objects) < needed:
            if not self._load_page():
                break
        return self.objects[index]

    def __next__(self) -> T:
        self.index += 1
        if len(self.objects) <= self.index:
            if not self._load_page():
                raise StopIteration
            # A fetched page may convert to zero objects; stop rather than index past the end.
            if len(self.objects) <= self.index:
                raise StopIteration
        return self.objects[self.index]

    # Legacy alias kept so existing callers of ``.next()`` keep working.
    next = __next__
|
|
|
|
|
class SizedPaginator(Paginator[T], Sized):
    """A Paginator whose total object count can be reported via ``length``.

    ``len()`` fetches one page on demand if the count is not known yet.
    """

    def __len__(self) -> int:
        total = self.length
        if total is None:
            # Fetching a page may make the count available (subclasses
            # presumably derive it from the last response).
            self._load_page()
            total = self.length
        if total is None:
            raise ValueError("Object doesn't provide length")
        return total

    @property
    @abstractmethod
    def length(self) -> int | None:
        """Total number of objects, or None if it is not (yet) known."""
        raise NotImplementedError
|
|