# File size: 3,550 bytes (revision 9c6594c)
from __future__ import annotations
from abc import abstractmethod
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Iterator,
Mapping,
Protocol,
Sized,
TypeVar,
overload,
)
if TYPE_CHECKING:
    # Only needed for the QUERY annotation; skipped at runtime.
    from wandb_graphql.language.ast import Document

# Type of the objects a paginator yields.
T = TypeVar("T")
# Structural (duck-typed) description of the client object a paginator uses.
class _Client(Protocol):
    """Anything that exposes an ``execute`` method; no inheritance required."""

    def execute(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Run a request; annotated as returning a string-keyed dict."""
class Paginator(Iterator[T]):
    """An iterator for paginated objects from GraphQL requests.

    Subclasses set ``QUERY`` and implement ``more``, ``cursor`` and
    ``convert_objects``. Every fetched object is cached in ``self.objects``,
    so re-iteration and indexing reuse pages that were already loaded.
    """

    # Query passed to ``client.execute`` for every page fetch.
    QUERY: ClassVar[Document | None] = None

    def __init__(
        self,
        client: _Client,
        variables: Mapping[str, Any],
        per_page: int = 50,  # We don't allow unbounded paging
    ):
        """Create a paginator.

        Args:
            client: Object whose ``execute`` method runs ``QUERY``.
            variables: Query variables. They are copied, so the caller's
                mapping is not mutated by later page updates.
            per_page: Page size, sent as the ``perPage`` query variable.
        """
        self.client: _Client = client
        # shallow copy partly guards against mutating the original input
        self.variables: dict[str, Any] = dict(variables)
        self.per_page: int = per_page
        # Every object fetched so far, in page order.
        self.objects: list[T] = []
        # Position of the last object returned by __next__; -1 before the first.
        self.index: int = -1
        # Raw result of the most recent ``client.execute`` call, if any.
        self.last_response: object | None = None

    def __iter__(self) -> Iterator[T]:
        """Rewind to the first object. Already-fetched objects are reused."""
        self.index = -1
        return self

    @property
    @abstractmethod
    def more(self) -> bool:
        """Whether there are more pages to be fetched."""
        raise NotImplementedError

    @property
    @abstractmethod
    def cursor(self) -> str | None:
        """The start cursor to use for the next fetched page."""
        raise NotImplementedError

    @abstractmethod
    def convert_objects(self) -> list[T]:
        """Convert the last fetched response data into the iterated objects."""
        raise NotImplementedError

    def update_variables(self) -> None:
        """Update the query variables for the next page fetch."""
        self.variables.update({"perPage": self.per_page, "cursor": self.cursor})

    def _update_response(self) -> None:
        """Fetch and store the response data for the next page."""
        self.last_response = self.client.execute(
            self.QUERY, variable_values=self.variables
        )

    def _load_page(self) -> bool:
        """Fetch the next page, if any.

        Returns:
            True if a page was fetched and its converted objects were appended
            to ``self.objects``. False if ``more`` reported that no pages remain.
        """
        if not self.more:
            return False
        self.update_variables()
        self._update_response()
        self.objects.extend(self.convert_objects())
        return True

    def _load_all(self) -> None:
        """Fetch pages until ``more`` reports that none are left."""
        while self._load_page():
            pass

    @staticmethod
    def _required_count(index: int | slice) -> int | None:
        """Return how many cached objects are needed to resolve ``index``.

        Returns None when every page must be loaded first. That is the case for
        negative positions, open-ended slices and negative steps, because their
        meaning depends on the total number of objects.
        """
        if isinstance(index, slice):
            start, stop, step = index.start, index.stop, index.step
            if step is not None and step < 0:
                return None
            if stop is None or stop < 0 or (start is not None and start < 0):
                return None
            # The stop bound is exclusive, so exactly ``stop`` objects suffice.
            return stop
        if index < 0:
            return None
        return index + 1

    @overload
    def __getitem__(self, index: int) -> T: ...

    @overload
    def __getitem__(self, index: slice) -> list[T]: ...

    def __getitem__(self, index: int | slice) -> T | list[T]:
        """Return the object (or list of objects) at ``index``, fetching on demand.

        Non-negative indices and bounded slices fetch only as many pages as
        needed. Negative indices and open-ended slices fetch every page first.

        Raises:
            IndexError: If an integer index is still out of range after the
                required pages have been loaded.
        """
        needed = self._required_count(index)
        if needed is None:
            self._load_all()
        else:
            while len(self.objects) < needed and self._load_page():
                pass
        return self.objects[index]

    def __next__(self) -> T:
        """Return the next object, fetching at most one further page.

        Raises:
            StopIteration: If no page remains, or if the page just fetched did
                not extend ``self.objects`` far enough.
        """
        self.index += 1
        if len(self.objects) <= self.index:
            if not self._load_page():
                raise StopIteration
            if len(self.objects) <= self.index:
                raise StopIteration
        return self.objects[self.index]

    # Alias so callers can also advance with ``.next()``.
    next = __next__
class SizedPaginator(Paginator[T], Sized):
    """A Paginator for objects with a known total count.

    Subclasses report the total through ``length``. When that value is not yet
    available, ``len()`` fetches one page before asking again.
    """

    def __len__(self) -> int:
        """Return the total object count reported by ``length``.

        Raises:
            ValueError: If ``length`` is still None after one page fetch.
        """
        total = self.length
        if total is None:
            # Presumably subclasses derive the count from a response; fetch one first.
            self._load_page()
            total = self.length
        if total is None:
            raise ValueError("Object doesn't provide length")
        return total

    @property
    @abstractmethod
    def length(self) -> int | None:
        """Total number of objects, or None when it is not available."""
        raise NotImplementedError
|