abidlabs HF Staff commited on
Commit
90c2333
·
verified ·
1 Parent(s): 03b83b9

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ trackio_logo.png filter=lfs diff=lfs merge=lfs -text
__init__.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextvars
2
+ import time
3
+ import webbrowser
4
+ from pathlib import Path
5
+
6
+ import huggingface_hub
7
+ from gradio_client import Client
8
+ from httpx import ReadTimeout
9
+ from huggingface_hub.errors import RepositoryNotFoundError
10
+
11
+ from trackio.deploy import deploy_as_space
12
+ from trackio.run import Run
13
+ from trackio.ui import demo
14
+ from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_PATH, block_except_in_notebook
15
+
16
+ __version__ = Path(__file__).parent.joinpath("version.txt").read_text().strip()
17
+
18
+
19
+ current_run: contextvars.ContextVar[Run | None] = contextvars.ContextVar(
20
+ "current_run", default=None
21
+ )
22
+ current_project: contextvars.ContextVar[str | None] = contextvars.ContextVar(
23
+ "current_project", default=None
24
+ )
25
+ current_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
26
+ "current_server", default=None
27
+ )
28
+
29
+ config = {}
30
+ SPACE_URL = "https://huggingface.co/spaces/{space_id}"
31
+
32
+
33
+ def init(
34
+ project: str,
35
+ name: str | None = None,
36
+ space_id: str | None = None,
37
+ persistent_dataset: str | None = None,
38
+ persistent_dataset_dir: str | None = None,
39
+ config: dict | None = None,
40
+ ) -> Run:
41
+ """
42
+ Creates a new Trackio project and returns a Run object.
43
+
44
+ Args:
45
+ project: The name of the project (can be an existing project to continue tracking or a new project to start tracking from scratch).
46
+ name: The name of the run (if not provided, a default name will be generated).
47
+ space_id: If provided, the project will be logged to a Hugging Face Space instead of a local directory. Should be a complete Space name like "username/reponame". If the Space does not exist, it will be created. If the Space already exists, the project will be logged to it.
48
+ config: A dictionary of configuration options. Provided for compatibility with wandb.init()
49
+ """
50
+ if not current_server.get() and space_id is None:
51
+ _, url, _ = demo.launch(
52
+ show_api=False, inline=False, quiet=True, prevent_thread_lock=True
53
+ )
54
+ current_server.set(url)
55
+ else:
56
+ url = current_server.get()
57
+
58
+ if current_project.get() is None or current_project.get() != project:
59
+ print(f"* Trackio project initialized: {project}")
60
+
61
+ if space_id is None:
62
+ print(f"* Trackio metrics logged to: {TRACKIO_DIR}")
63
+ print(
64
+ f'\n* View dashboard by running in your terminal: trackio show --project "{project}"'
65
+ )
66
+ print(f'* or by running in Python: trackio.show(project="{project}")')
67
+ else:
68
+ create_space_if_not_exists(
69
+ space_id, persistent_dataset, persistent_dataset_dir
70
+ )
71
+ print(
72
+ f"* View dashboard by going to: {SPACE_URL.format(space_id=space_id)}"
73
+ )
74
+ current_project.set(project)
75
+
76
+ space_or_url = space_id if space_id else url
77
+ client = Client(space_or_url, verbose=False)
78
+ run = Run(project=project, client=client, name=name, config=config)
79
+ current_run.set(run)
80
+ globals()["config"] = run.config
81
+ return run
82
+
83
+
84
+ def create_space_if_not_exists(
85
+ space_id: str,
86
+ persistent_dataset: str | None = None,
87
+ persistent_dataset_dir: str | None = None,
88
+ ) -> None:
89
+ """
90
+ Creates a new Hugging Face Space if it does not exist.
91
+
92
+ Args:
93
+ space_id: The ID of the Space to create.
94
+ """
95
+ if "/" not in space_id:
96
+ raise ValueError(
97
+ f"Invalid space ID: {space_id}. Must be in the format: username/reponame."
98
+ )
99
+ try:
100
+ huggingface_hub.repo_info(space_id, repo_type="space")
101
+ print(f"* Found existing space: {SPACE_URL.format(space_id=space_id)}")
102
+ return
103
+ except RepositoryNotFoundError:
104
+ pass
105
+
106
+ print(f"* Creating new space: {SPACE_URL.format(space_id=space_id)}")
107
+ deploy_as_space(space_id, persistent_dataset, persistent_dataset_dir)
108
+
109
+ client = None
110
+ for _ in range(30):
111
+ try:
112
+ client = Client(space_id, verbose=False)
113
+ if client:
114
+ break
115
+ except ReadTimeout:
116
+ print("* Space is not yet ready. Waiting 5 seconds...")
117
+ time.sleep(5)
118
+ except ValueError as e:
119
+ print(f"* Space gave error {e}. Trying again in 5 seconds...")
120
+ time.sleep(5)
121
+
122
+
123
+ def log(metrics: dict) -> None:
124
+ """
125
+ Logs metrics to the current run.
126
+
127
+ Args:
128
+ metrics: A dictionary of metrics to log.
129
+ """
130
+ if current_run.get() is None:
131
+ raise RuntimeError("Call trackio.init() before log().")
132
+ current_run.get().log(metrics)
133
+
134
+
135
+ def finish():
136
+ """
137
+ Finishes the current run.
138
+ """
139
+ if current_run.get() is None:
140
+ raise RuntimeError("Call trackio.init() before finish().")
141
+ current_run.get().finish()
142
+
143
+
144
+ def show(project: str | None = None):
145
+ """
146
+ Launches the Trackio dashboard.
147
+
148
+ Args:
149
+ project: The name of the project whose runs to show. If not provided, all projects will be shown and the user can select one.
150
+ """
151
+ _, url, share_url = demo.launch(
152
+ show_api=False,
153
+ quiet=True,
154
+ inline=False,
155
+ prevent_thread_lock=True,
156
+ favicon_path=TRACKIO_LOGO_PATH,
157
+ allowed_paths=[TRACKIO_LOGO_PATH],
158
+ )
159
+ base_url = share_url + "/" if share_url else url
160
+ dashboard_url = base_url + f"?project={project}" if project else base_url
161
+ print(f"* Trackio UI launched at: {dashboard_url}")
162
+ webbrowser.open(dashboard_url)
163
+ block_except_in_notebook()
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (4.96 kB). View file
 
__pycache__/__init__.cpython-312.pyc ADDED
Binary file (7.1 kB). View file
 
__pycache__/cli.cpython-312.pyc ADDED
Binary file (1.11 kB). View file
 
__pycache__/context.cpython-312.pyc ADDED
Binary file (440 Bytes). View file
 
__pycache__/deploy.cpython-310.pyc ADDED
Binary file (1.72 kB). View file
 
__pycache__/deploy.cpython-312.pyc ADDED
Binary file (2.27 kB). View file
 
__pycache__/dummy_commit_scheduler.cpython-310.pyc ADDED
Binary file (936 Bytes). View file
 
__pycache__/run.cpython-310.pyc ADDED
Binary file (1.01 kB). View file
 
__pycache__/run.cpython-312.pyc ADDED
Binary file (1.28 kB). View file
 
__pycache__/sqlite_storage.cpython-310.pyc ADDED
Binary file (5.37 kB). View file
 
__pycache__/sqlite_storage.cpython-312.pyc ADDED
Binary file (6.72 kB). View file
 
__pycache__/storage.cpython-312.pyc ADDED
Binary file (4.6 kB). View file
 
__pycache__/ui.cpython-310.pyc ADDED
Binary file (7.83 kB). View file
 
__pycache__/ui.cpython-312.pyc ADDED
Binary file (12.5 kB). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.62 kB). View file
 
__pycache__/utils.cpython-312.pyc ADDED
Binary file (3.19 kB). View file
 
cli.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from trackio import show
4
+
5
+
6
+ def main():
7
+ parser = argparse.ArgumentParser(description="Trackio CLI")
8
+ subparsers = parser.add_subparsers(dest="command")
9
+
10
+ ui_parser = subparsers.add_parser(
11
+ "show", help="Show the Trackio dashboard UI for a project"
12
+ )
13
+ ui_parser.add_argument(
14
+ "--project", required=False, help="Project name to show in the dashboard"
15
+ )
16
+
17
+ args = parser.parse_args()
18
+
19
+ if args.command == "show":
20
+ show(args.project)
21
+ else:
22
+ parser.print_help()
23
+
24
+
25
+ if __name__ == "__main__":
26
+ main()
deploy.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ from importlib.resources import files
4
+ from pathlib import Path
5
+
6
+ import gradio
7
+ import huggingface_hub
8
+
9
+
10
+ def deploy_as_space(
11
+ title: str,
12
+ persistent_dataset: str | None = None,
13
+ persistent_dataset_dir: str | None = None,
14
+ ):
15
+ if (
16
+ os.getenv("SYSTEM") == "spaces"
17
+ ): # in case a repo with this function is uploaded to spaces
18
+ return
19
+
20
+ trackio_path = files("trackio")
21
+
22
+ hf_api = huggingface_hub.HfApi()
23
+ whoami = None
24
+ login = False
25
+ try:
26
+ whoami = hf_api.whoami()
27
+ if whoami["auth"]["accessToken"]["role"] != "write":
28
+ login = True
29
+ except OSError:
30
+ login = True
31
+ if login:
32
+ print("Need 'write' access token to create a Spaces repo.")
33
+ huggingface_hub.login(add_to_git_credential=False)
34
+ whoami = hf_api.whoami()
35
+
36
+ space_id = huggingface_hub.create_repo(
37
+ title,
38
+ space_sdk="gradio",
39
+ repo_type="space",
40
+ exist_ok=True,
41
+ ).repo_id
42
+ assert space_id == title # not sure why these would differ
43
+
44
+ with open(Path(trackio_path, "README.md"), "r") as f:
45
+ readme_content = f.read()
46
+ readme_content = readme_content.replace("{GRADIO_VERSION}", gradio.__version__)
47
+ readme_buffer = io.BytesIO(readme_content.encode("utf-8"))
48
+ hf_api.upload_file(
49
+ path_or_fileobj=readme_buffer,
50
+ path_in_repo="README.md",
51
+ repo_id=space_id,
52
+ repo_type="space",
53
+ )
54
+
55
+ huggingface_hub.utils.disable_progress_bars()
56
+ hf_api.upload_folder(
57
+ repo_id=space_id,
58
+ repo_type="space",
59
+ folder_path=trackio_path,
60
+ ignore_patterns=["README.md"],
61
+ )
62
+
63
+ # add HF_TOKEN so we have access to dataset to persist data
64
+ HF_TOKEN = os.environ.get("HF_TOKEN")
65
+ if HF_TOKEN is not None:
66
+ huggingface_hub.add_space_secret(space_id, "HF_TOKEN", HF_TOKEN)
67
+ if persistent_dataset is not None:
68
+ huggingface_hub.add_space_variable(
69
+ space_id, "PERSIST_TO_DATASET", persistent_dataset
70
+ )
71
+ if persistent_dataset_dir is not None:
72
+ huggingface_hub.add_space_variable(
73
+ space_id, "PERSIST_TO_DATASET_DIR", persistent_dataset_dir
74
+ )
dummy_commit_scheduler.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A dummy object to fit the interface of huggingface_hub's CommitScheduler
2
+ class DummyCommitSchedulerLock:
3
+ def __enter__(self):
4
+ return None
5
+
6
+ def __exit__(self, exception_type, exception_value, exception_traceback):
7
+ pass
8
+
9
+
10
+ class DummyCommitScheduler:
11
+ def __init__(self):
12
+ self.lock = DummyCommitSchedulerLock()
run.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio_client import Client
2
+
3
+ from trackio.utils import generate_readable_name
4
+
5
+
6
+ class Run:
7
+ def __init__(
8
+ self,
9
+ project: str,
10
+ client: Client,
11
+ name: str | None = None,
12
+ config: dict | None = None,
13
+ ):
14
+ self.project = project
15
+ self.client = client
16
+ self.name = name or generate_readable_name()
17
+ self.config = config or {}
18
+
19
+ def log(self, metrics: dict):
20
+ self.client.predict(
21
+ api_name="/log", project=self.project, run=self.name, metrics=metrics
22
+ )
23
+
24
+ def finish(self):
25
+ pass
sqlite_storage.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import sqlite3
4
+
5
+ from huggingface_hub import CommitScheduler
6
+
7
+ try:
8
+ from trackio.dummy_commit_scheduler import DummyCommitScheduler
9
+ from trackio.utils import RESERVED_KEYS, TRACKIO_DIR
10
+ except: # noqa: E722
11
+ from dummy_commit_scheduler import DummyCommitScheduler
12
+ from utils import RESERVED_KEYS, TRACKIO_DIR
13
+
14
+ HF_TOKEN = os.environ.get("HF_TOKEN")
15
+ PERSIST_TO_DATASET = os.environ.get("PERSIST_TO_DATASET")
16
+ PERSIST_TO_DATASET_DIR = os.environ.get("PERSIST_TO_DATASET_DIR")
17
+ if PERSIST_TO_DATASET is None:
18
+ scheduler = DummyCommitScheduler()
19
+ else:
20
+ scheduler = CommitScheduler(
21
+ repo_id=PERSIST_TO_DATASET,
22
+ repo_type="dataset",
23
+ folder_path=TRACKIO_DIR,
24
+ path_in_repo=PERSIST_TO_DATASET_DIR,
25
+ private=True,
26
+ )
27
+
28
+
29
+ class SQLiteStorage:
30
+ def __init__(self, project: str, name: str, config: dict):
31
+ self.project = project
32
+ self.name = name
33
+ self.config = config
34
+ self.db_path = os.path.join(TRACKIO_DIR, "trackio.db")
35
+
36
+ os.makedirs(TRACKIO_DIR, exist_ok=True)
37
+
38
+ self._init_db()
39
+ self._save_config()
40
+
41
+ def _init_db(self):
42
+ """Initialize the SQLite database with required tables."""
43
+ with scheduler.lock:
44
+ with sqlite3.connect(self.db_path) as conn:
45
+ cursor = conn.cursor()
46
+
47
+ cursor.execute("""
48
+ CREATE TABLE IF NOT EXISTS metrics (
49
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
50
+ timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
51
+ project_name TEXT NOT NULL,
52
+ run_name TEXT NOT NULL,
53
+ metrics TEXT NOT NULL
54
+ )
55
+ """)
56
+
57
+ cursor.execute("""
58
+ CREATE TABLE IF NOT EXISTS configs (
59
+ project_name TEXT NOT NULL,
60
+ run_name TEXT NOT NULL,
61
+ config TEXT NOT NULL,
62
+ created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
63
+ PRIMARY KEY (project_name, run_name)
64
+ )
65
+ """)
66
+
67
+ conn.commit()
68
+
69
+ def _save_config(self):
70
+ """Save the run configuration to the database."""
71
+ with scheduler.lock:
72
+ with sqlite3.connect(self.db_path) as conn:
73
+ cursor = conn.cursor()
74
+ cursor.execute(
75
+ "INSERT OR REPLACE INTO configs (project_name, run_name, config) VALUES (?, ?, ?)",
76
+ (self.project, self.name, json.dumps(self.config)),
77
+ )
78
+ conn.commit()
79
+
80
+ def log(self, metrics: dict):
81
+ """Log metrics to the database."""
82
+ for k in metrics.keys():
83
+ if k in RESERVED_KEYS or k.startswith("__"):
84
+ raise ValueError(
85
+ f"Please do not use this reserved key as a metric: {k}"
86
+ )
87
+
88
+ with scheduler.lock:
89
+ with sqlite3.connect(self.db_path) as conn:
90
+ cursor = conn.cursor()
91
+ cursor.execute(
92
+ """
93
+ INSERT INTO metrics
94
+ (project_name, run_name, metrics)
95
+ VALUES (?, ?, ?)
96
+ """,
97
+ (self.project, self.name, json.dumps(metrics)),
98
+ )
99
+ conn.commit()
100
+
101
+ def get_metrics(self, project: str, run: str) -> list[dict]:
102
+ """Retrieve metrics for a specific run."""
103
+ with sqlite3.connect(self.db_path) as conn:
104
+ cursor = conn.cursor()
105
+ cursor.execute(
106
+ """
107
+ SELECT timestamp, metrics
108
+ FROM metrics
109
+ WHERE project_name = ? AND run_name = ?
110
+ ORDER BY timestamp
111
+ """,
112
+ (project, run),
113
+ )
114
+ rows = cursor.fetchall()
115
+
116
+ results = []
117
+ for row in rows:
118
+ timestamp, metrics_json = row
119
+ metrics = json.loads(metrics_json)
120
+ metrics["timestamp"] = timestamp
121
+ results.append(metrics)
122
+
123
+ return results
124
+
125
+ def get_projects(self) -> list[str]:
126
+ """Get list of all projects."""
127
+ with sqlite3.connect(self.db_path) as conn:
128
+ cursor = conn.cursor()
129
+ cursor.execute("SELECT DISTINCT project_name FROM metrics")
130
+ return [row[0] for row in cursor.fetchall()]
131
+
132
+ def get_runs(self, project: str) -> list[str]:
133
+ """Get list of all runs for a project."""
134
+ with sqlite3.connect(self.db_path) as conn:
135
+ cursor = conn.cursor()
136
+ cursor.execute(
137
+ "SELECT DISTINCT run_name FROM metrics WHERE project_name = ?",
138
+ (project,),
139
+ )
140
+ return [row[0] for row in cursor.fetchall()]
141
+
142
+ def finish(self):
143
+ """Cleanup when run is finished."""
144
+ pass
trackio_logo.png ADDED

Git LFS Details

  • SHA256: 3922c4d1e465270ad4d8abb12023f3beed5d9f7f338528a4c0ac21dcf358a1c8
  • Pointer size: 131 Bytes
  • Size of remote file: 487 kB
ui.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ import gradio as gr
4
+ import pandas as pd
5
+
6
+ try:
7
+ from trackio.sqlite_storage import SQLiteStorage
8
+ from trackio.utils import RESERVED_KEYS, TRACKIO_LOGO_PATH
9
+ except: # noqa: E722
10
+ from sqlite_storage import SQLiteStorage
11
+ from utils import RESERVED_KEYS, TRACKIO_LOGO_PATH
12
+
13
+ css = """
14
+ #run-cb .wrap {
15
+ gap: 2px;
16
+ }
17
+ #run-cb .wrap label {
18
+ line-height: 1;
19
+ padding: 6px;
20
+ }
21
+ """
22
+
23
+ COLOR_PALETTE = [
24
+ "#3B82F6",
25
+ "#EF4444",
26
+ "#10B981",
27
+ "#F59E0B",
28
+ "#8B5CF6",
29
+ "#EC4899",
30
+ "#06B6D4",
31
+ "#84CC16",
32
+ "#F97316",
33
+ "#6366F1",
34
+ ]
35
+
36
+
37
+ def get_color_mapping(runs: list[str], smoothing: bool) -> dict[str, str]:
38
+ """Generate color mapping for runs, with transparency for original data when smoothing is enabled."""
39
+ color_map = {}
40
+
41
+ for i, run in enumerate(runs):
42
+ base_color = COLOR_PALETTE[i % len(COLOR_PALETTE)]
43
+
44
+ if smoothing:
45
+ color_map[f"{run}_smoothed"] = base_color
46
+ color_map[f"{run}_original"] = base_color + "4D"
47
+ else:
48
+ color_map[run] = base_color
49
+
50
+ return color_map
51
+
52
+
53
+ def get_projects(request: gr.Request):
54
+ storage = SQLiteStorage("", "", {})
55
+ projects = storage.get_projects()
56
+ if project := request.query_params.get("project"):
57
+ interactive = False
58
+ else:
59
+ interactive = True
60
+ project = projects[0] if projects else None
61
+ return gr.Dropdown(
62
+ label="Project",
63
+ choices=projects,
64
+ value=project,
65
+ allow_custom_value=True,
66
+ interactive=interactive,
67
+ )
68
+
69
+
70
+ def get_runs(project):
71
+ if not project:
72
+ return []
73
+ storage = SQLiteStorage("", "", {})
74
+ return storage.get_runs(project)
75
+
76
+
77
+ def load_run_data(project: str | None, run: str | None, smoothing: bool):
78
+ if not project or not run:
79
+ return None
80
+ storage = SQLiteStorage("", "", {})
81
+ metrics = storage.get_metrics(project, run)
82
+ if not metrics:
83
+ return None
84
+ df = pd.DataFrame(metrics)
85
+
86
+ if "step" not in df.columns:
87
+ df["step"] = range(len(df))
88
+
89
+ if smoothing:
90
+ numeric_cols = df.select_dtypes(include="number").columns
91
+ numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS]
92
+
93
+ df_original = df.copy()
94
+ df_original["run"] = f"{run}_original"
95
+ df_original["data_type"] = "original"
96
+
97
+ df_smoothed = df.copy()
98
+ df_smoothed[numeric_cols] = df_smoothed[numeric_cols].ewm(alpha=0.1).mean()
99
+ df_smoothed["run"] = f"{run}_smoothed"
100
+ df_smoothed["data_type"] = "smoothed"
101
+
102
+ combined_df = pd.concat([df_original, df_smoothed], ignore_index=True)
103
+ return combined_df
104
+ else:
105
+ df["run"] = run
106
+ df["data_type"] = "original"
107
+ return df
108
+
109
+
110
+ def update_runs(project, filter_text, user_interacted_with_runs=False):
111
+ if project is None:
112
+ runs = []
113
+ num_runs = 0
114
+ else:
115
+ runs = get_runs(project)
116
+ num_runs = len(runs)
117
+ if filter_text:
118
+ runs = [r for r in runs if filter_text in r]
119
+ if not user_interacted_with_runs:
120
+ return gr.CheckboxGroup(
121
+ choices=runs, value=[runs[0]] if runs else []
122
+ ), gr.Textbox(label=f"Runs ({num_runs})")
123
+ else:
124
+ return gr.CheckboxGroup(choices=runs), gr.Textbox(label=f"Runs ({num_runs})")
125
+
126
+
127
+ def filter_runs(project, filter_text):
128
+ runs = get_runs(project)
129
+ runs = [r for r in runs if filter_text in r]
130
+ return gr.CheckboxGroup(choices=runs, value=runs)
131
+
132
+
133
+ def toggle_timer(cb_value):
134
+ if cb_value:
135
+ return gr.Timer(active=True)
136
+ else:
137
+ return gr.Timer(active=False)
138
+
139
+
140
+ def log(project: str, run: str, metrics: dict[str, Any]) -> None:
141
+ storage = SQLiteStorage(project, run, {})
142
+ storage.log(metrics)
143
+
144
+
145
+ def sort_metrics_by_prefix(metrics: list[str]) -> list[str]:
146
+ """
147
+ Sort metrics by grouping prefixes together.
148
+ Metrics without prefixes come first, then grouped by prefix.
149
+
150
+ Example:
151
+ Input: ["train/loss", "loss", "train/acc", "val/loss"]
152
+ Output: ["loss", "train/acc", "train/loss", "val/loss"]
153
+ """
154
+ no_prefix = []
155
+ with_prefix = []
156
+
157
+ for metric in metrics:
158
+ if "/" in metric:
159
+ with_prefix.append(metric)
160
+ else:
161
+ no_prefix.append(metric)
162
+
163
+ no_prefix.sort()
164
+
165
+ prefix_groups = {}
166
+ for metric in with_prefix:
167
+ prefix = metric.split("/")[0]
168
+ if prefix not in prefix_groups:
169
+ prefix_groups[prefix] = []
170
+ prefix_groups[prefix].append(metric)
171
+
172
+ sorted_with_prefix = []
173
+ for prefix in sorted(prefix_groups.keys()):
174
+ sorted_with_prefix.extend(sorted(prefix_groups[prefix]))
175
+
176
+ return no_prefix + sorted_with_prefix
177
+
178
+
179
+ def configure(request: gr.Request):
180
+ if metrics := request.query_params.get("metrics"):
181
+ return metrics.split(",")
182
+ else:
183
+ return []
184
+
185
+
186
+ with gr.Blocks(theme="citrus", title="Trackio Dashboard", css=css) as demo:
187
+ with gr.Sidebar() as sidebar:
188
+ gr.Markdown(
189
+ f"<div style='display: flex; align-items: center; gap: 8px;'><img src='/gradio_api/file={TRACKIO_LOGO_PATH}' width='32' height='32'><span style='font-size: 2em; font-weight: bold;'>Trackio</span></div>"
190
+ )
191
+ project_dd = gr.Dropdown(label="Project")
192
+ run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...")
193
+ run_cb = gr.CheckboxGroup(
194
+ label="Runs", choices=[], interactive=True, elem_id="run-cb"
195
+ )
196
+ with gr.Sidebar(position="right", open=False) as settings_sidebar:
197
+ gr.Markdown("### ⚙️ Settings")
198
+ realtime_cb = gr.Checkbox(label="Refresh realtime", value=True)
199
+ smoothing_cb = gr.Checkbox(label="Smoothing", value=True)
200
+
201
+ timer = gr.Timer(value=1)
202
+ metrics_subset = gr.State([])
203
+ user_interacted_with_run_cb = gr.State(False)
204
+
205
+ gr.on(
206
+ [demo.load],
207
+ fn=configure,
208
+ outputs=metrics_subset,
209
+ )
210
+ gr.on(
211
+ [demo.load],
212
+ fn=get_projects,
213
+ outputs=project_dd,
214
+ show_progress="hidden",
215
+ )
216
+ gr.on(
217
+ [timer.tick],
218
+ fn=update_runs,
219
+ inputs=[project_dd, run_tb, user_interacted_with_run_cb],
220
+ outputs=[run_cb, run_tb],
221
+ show_progress="hidden",
222
+ )
223
+ gr.on(
224
+ [demo.load, project_dd.change],
225
+ fn=update_runs,
226
+ inputs=[project_dd, run_tb],
227
+ outputs=[run_cb, run_tb],
228
+ show_progress="hidden",
229
+ )
230
+
231
+ realtime_cb.change(
232
+ fn=toggle_timer,
233
+ inputs=realtime_cb,
234
+ outputs=timer,
235
+ api_name="toggle_timer",
236
+ )
237
+ run_cb.input(
238
+ fn=lambda: True,
239
+ outputs=user_interacted_with_run_cb,
240
+ )
241
+ run_tb.input(
242
+ fn=filter_runs,
243
+ inputs=[project_dd, run_tb],
244
+ outputs=run_cb,
245
+ )
246
+
247
+ gr.api(
248
+ fn=log,
249
+ api_name="log",
250
+ )
251
+
252
+ x_lim = gr.State(None)
253
+
254
+ def update_x_lim(select_data: gr.SelectData):
255
+ return select_data.index
256
+
257
+ @gr.render(
258
+ triggers=[
259
+ demo.load,
260
+ run_cb.change,
261
+ timer.tick,
262
+ smoothing_cb.change,
263
+ x_lim.change,
264
+ ],
265
+ inputs=[project_dd, run_cb, smoothing_cb, metrics_subset, x_lim],
266
+ )
267
+ def update_dashboard(project, runs, smoothing, metrics_subset, x_lim_value):
268
+ dfs = []
269
+ original_runs = runs.copy()
270
+
271
+ for run in runs:
272
+ df = load_run_data(project, run, smoothing)
273
+ if df is not None:
274
+ dfs.append(df)
275
+
276
+ if dfs:
277
+ master_df = pd.concat(dfs, ignore_index=True)
278
+ else:
279
+ master_df = pd.DataFrame()
280
+
281
+ if master_df.empty:
282
+ return
283
+
284
+ numeric_cols = master_df.select_dtypes(include="number").columns
285
+ numeric_cols = [
286
+ c for c in numeric_cols if c not in RESERVED_KEYS and c != "step"
287
+ ]
288
+ if metrics_subset:
289
+ numeric_cols = [c for c in numeric_cols if c in metrics_subset]
290
+ numeric_cols = sort_metrics_by_prefix(list(numeric_cols))
291
+
292
+ color_map = get_color_mapping(original_runs, smoothing)
293
+
294
+ plots: list[gr.LinePlot] = []
295
+ for col in range((len(numeric_cols) + 1) // 2):
296
+ with gr.Row(key=f"row-{col}"):
297
+ for i in range(2):
298
+ metric_idx = 2 * col + i
299
+ if metric_idx < len(numeric_cols):
300
+ metric_name = numeric_cols[metric_idx]
301
+
302
+ metric_df = master_df.dropna(subset=[metric_name])
303
+
304
+ if not metric_df.empty:
305
+ plot = gr.LinePlot(
306
+ metric_df,
307
+ x="step",
308
+ y=metric_name,
309
+ color="run" if "run" in metric_df.columns else None,
310
+ color_map=color_map,
311
+ title=metric_name,
312
+ key=f"plot-{col}-{i}",
313
+ preserved_by_key=None,
314
+ x_lim=x_lim_value,
315
+ y_lim=[
316
+ metric_df[metric_name].min(),
317
+ metric_df[metric_name].max(),
318
+ ],
319
+ show_fullscreen_button=True,
320
+ )
321
+ plots.append(plot)
322
+
323
+ for plot in plots:
324
+ plot.select(update_x_lim, outputs=x_lim)
325
+ plot.double_click(lambda: None, outputs=x_lim)
326
+
327
+
328
+ if __name__ == "__main__":
329
+ demo.launch(allowed_paths=[TRACKIO_LOGO_PATH], show_api=False)
utils.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import sys
4
+ import time
5
+ from pathlib import Path
6
+
7
+ from huggingface_hub.constants import HF_HOME
8
+
9
+ RESERVED_KEYS = ["project", "run", "timestamp", "step"]
10
+ TRACKIO_DIR = os.path.join(HF_HOME, "trackio")
11
+
12
+ TRACKIO_LOGO_PATH = str(Path(__file__).parent.joinpath("trackio_logo.png"))
13
+
14
+
15
+ def generate_readable_name():
16
+ """
17
+ Generates a random, readable name like "dainty-sunset-1"
18
+ """
19
+ adjectives = [
20
+ "dainty",
21
+ "brave",
22
+ "calm",
23
+ "eager",
24
+ "fancy",
25
+ "gentle",
26
+ "happy",
27
+ "jolly",
28
+ "kind",
29
+ "lively",
30
+ "merry",
31
+ "nice",
32
+ "proud",
33
+ "quick",
34
+ "silly",
35
+ "tidy",
36
+ "witty",
37
+ "zealous",
38
+ "bright",
39
+ "shy",
40
+ "bold",
41
+ "clever",
42
+ "daring",
43
+ "elegant",
44
+ "faithful",
45
+ "graceful",
46
+ "honest",
47
+ "inventive",
48
+ "jovial",
49
+ "keen",
50
+ "lucky",
51
+ "modest",
52
+ "noble",
53
+ "optimistic",
54
+ "patient",
55
+ "quirky",
56
+ "resourceful",
57
+ "sincere",
58
+ "thoughtful",
59
+ "upbeat",
60
+ "valiant",
61
+ "warm",
62
+ "youthful",
63
+ "zesty",
64
+ "adventurous",
65
+ "breezy",
66
+ "cheerful",
67
+ "delightful",
68
+ "energetic",
69
+ "fearless",
70
+ "glad",
71
+ "hopeful",
72
+ "imaginative",
73
+ "joyful",
74
+ "kindly",
75
+ "luminous",
76
+ "mysterious",
77
+ "neat",
78
+ "outgoing",
79
+ "playful",
80
+ "radiant",
81
+ "spirited",
82
+ "tranquil",
83
+ "unique",
84
+ "vivid",
85
+ "wise",
86
+ "zany",
87
+ "artful",
88
+ "bubbly",
89
+ "charming",
90
+ "dazzling",
91
+ "earnest",
92
+ "festive",
93
+ "gentlemanly",
94
+ "hearty",
95
+ "intrepid",
96
+ "jubilant",
97
+ "knightly",
98
+ "lively",
99
+ "magnetic",
100
+ "nimble",
101
+ "orderly",
102
+ "peaceful",
103
+ "quick-witted",
104
+ "robust",
105
+ "sturdy",
106
+ "trusty",
107
+ "upstanding",
108
+ "vibrant",
109
+ "whimsical",
110
+ ]
111
+ nouns = [
112
+ "sunset",
113
+ "forest",
114
+ "river",
115
+ "mountain",
116
+ "breeze",
117
+ "meadow",
118
+ "ocean",
119
+ "valley",
120
+ "sky",
121
+ "field",
122
+ "cloud",
123
+ "star",
124
+ "rain",
125
+ "leaf",
126
+ "stone",
127
+ "flower",
128
+ "bird",
129
+ "tree",
130
+ "wave",
131
+ "trail",
132
+ "island",
133
+ "desert",
134
+ "hill",
135
+ "lake",
136
+ "pond",
137
+ "grove",
138
+ "canyon",
139
+ "reef",
140
+ "bay",
141
+ "peak",
142
+ "glade",
143
+ "marsh",
144
+ "cliff",
145
+ "dune",
146
+ "spring",
147
+ "brook",
148
+ "cave",
149
+ "plain",
150
+ "ridge",
151
+ "wood",
152
+ "blossom",
153
+ "petal",
154
+ "root",
155
+ "branch",
156
+ "seed",
157
+ "acorn",
158
+ "pine",
159
+ "willow",
160
+ "cedar",
161
+ "elm",
162
+ "falcon",
163
+ "eagle",
164
+ "sparrow",
165
+ "robin",
166
+ "owl",
167
+ "finch",
168
+ "heron",
169
+ "crane",
170
+ "duck",
171
+ "swan",
172
+ "fox",
173
+ "wolf",
174
+ "bear",
175
+ "deer",
176
+ "moose",
177
+ "otter",
178
+ "beaver",
179
+ "lynx",
180
+ "hare",
181
+ "badger",
182
+ "butterfly",
183
+ "bee",
184
+ "ant",
185
+ "beetle",
186
+ "dragonfly",
187
+ "firefly",
188
+ "ladybug",
189
+ "moth",
190
+ "spider",
191
+ "worm",
192
+ "coral",
193
+ "kelp",
194
+ "shell",
195
+ "pebble",
196
+ "boulder",
197
+ "cobble",
198
+ "sand",
199
+ "wavelet",
200
+ "tide",
201
+ "current",
202
+ ]
203
+ adjective = random.choice(adjectives)
204
+ noun = random.choice(nouns)
205
+ number = random.randint(1, 99)
206
+ return f"{adjective}-{noun}-{number}"
207
+
208
+
209
+ def block_except_in_notebook():
210
+ in_notebook = bool(getattr(sys, "ps1", sys.flags.interactive))
211
+ if in_notebook:
212
+ return
213
+ try:
214
+ while True:
215
+ time.sleep(0.1)
216
+ except (KeyboardInterrupt, OSError):
217
+ print("Keyboard interruption in main thread... closing dashboard.")
version.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0.0.10