"""Public API: history.""" import json import requests from wandb_gql import gql from wandb_gql.client import RetryError from wandb import util from wandb.apis.normalize import normalize_exceptions from wandb.sdk.lib import retry class HistoryScan: QUERY = gql( """ query HistoryPage($entity: String!, $project: String!, $run: String!, $minStep: Int64!, $maxStep: Int64!, $pageSize: Int!) { project(name: $project, entityName: $entity) { run(name: $run) { history(minStep: $minStep, maxStep: $maxStep, samples: $pageSize) } } } """ ) def __init__(self, client, run, min_step, max_step, page_size=1000): self.client = client self.run = run self.page_size = page_size self.min_step = min_step self.max_step = max_step self.page_offset = min_step # minStep for next page self.scan_offset = 0 # index within current page of rows self.rows = [] # current page of rows def __iter__(self): self.page_offset = self.min_step self.scan_offset = 0 self.rows = [] return self def __next__(self): while True: if self.scan_offset < len(self.rows): row = self.rows[self.scan_offset] self.scan_offset += 1 return row if self.page_offset >= self.max_step: raise StopIteration() self._load_next() next = __next__ @normalize_exceptions @retry.retriable( check_retry_fn=util.no_retry_auth, retryable_exceptions=(RetryError, requests.RequestException), ) def _load_next(self): max_step = self.page_offset + self.page_size if max_step > self.max_step: max_step = self.max_step variables = { "entity": self.run.entity, "project": self.run.project, "run": self.run.id, "minStep": int(self.page_offset), "maxStep": int(max_step), "pageSize": int(self.page_size), } res = self.client.execute(self.QUERY, variable_values=variables) res = res["project"]["run"]["history"] self.rows = [json.loads(row) for row in res] self.page_offset += self.page_size self.scan_offset = 0 class SampledHistoryScan: QUERY = gql( """ query SampledHistoryPage($entity: String!, $project: String!, $run: String!, $spec: JSONString!) { project(name: $project, entityName: $entity) { run(name: $run) { sampledHistory(specs: [$spec]) } } } """ ) def __init__(self, client, run, keys, min_step, max_step, page_size=1000): self.client = client self.run = run self.keys = keys self.page_size = page_size self.min_step = min_step self.max_step = max_step self.page_offset = min_step # minStep for next page self.scan_offset = 0 # index within current page of rows self.rows = [] # current page of rows def __iter__(self): self.page_offset = self.min_step self.scan_offset = 0 self.rows = [] return self def __next__(self): while True: if self.scan_offset < len(self.rows): row = self.rows[self.scan_offset] self.scan_offset += 1 return row if self.page_offset >= self.max_step: raise StopIteration() self._load_next() next = __next__ @normalize_exceptions @retry.retriable( check_retry_fn=util.no_retry_auth, retryable_exceptions=(RetryError, requests.RequestException), ) def _load_next(self): max_step = self.page_offset + self.page_size if max_step > self.max_step: max_step = self.max_step variables = { "entity": self.run.entity, "project": self.run.project, "run": self.run.id, "spec": json.dumps( { "keys": self.keys, "minStep": int(self.page_offset), "maxStep": int(max_step), "samples": int(self.page_size), } ), } res = self.client.execute(self.QUERY, variable_values=variables) res = res["project"]["run"]["sampledHistory"] self.rows = res[0] self.page_offset += self.page_size self.scan_offset = 0