File size: 2,026 Bytes
88c8414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from collections import deque

from gradio_client import Client

from trackio.utils import generate_readable_name


class Run:
    def __init__(
        self,
        url: str,
        project: str,
        client: Client,
        name: str | None = None,
        config: dict | None = None,
        dataset_id: str | None = None,
    ):
        self.url = url
        self.project = project
        self.client = client
        self.name = name or generate_readable_name()
        self.config = config or {}
        self.dataset_id = dataset_id
        self.queued_logs = deque()

    def log(self, metrics: dict):
        if self.client is None:
            # lazily try to initialize the client
            try:
                self.client = Client(self.url, verbose=False)
            except BaseException as e:
                print(
                    f"Unable to instantiate log client; error was {e}. Will queue log item and try again on next log() call."
                )
        if self.client is None:
            # client can still be None for a Space while the Space is still initializing.
            # queue up log items for when the client is not None.
            self.queued_logs.append(
                dict(
                    api_name="/log",
                    project=self.project,
                    run=self.name,
                    metrics=metrics,
                    dataset_id=self.dataset_id,
                )
            )
        else:
            # flush the queued log items, if there are any
            if len(self.queued_logs) > 0:
                for queued_log in self.queued_logs:
                    self.client.predict(**queued_log)
                self.queued_logs.clear()
            # write the current log item
            self.client.predict(
                api_name="/log",
                project=self.project,
                run=self.name,
                metrics=metrics,
                dataset_id=self.dataset_id,
            )

    def finish(self):
        pass