Spaces:
Sleeping
Sleeping
import contextlib
import json
import os
import sqlite3

try:
    from trackio.utils import RESERVED_KEYS, TRACKIO_DIR
except:  # noqa: E722
    from utils import RESERVED_KEYS, TRACKIO_DIR
class SQLiteStorage:
    """SQLite-backed store for run configs and logged metrics.

    All projects and runs share one database file, ``<TRACKIO_DIR>/trackio.db``,
    with two tables:

    * ``metrics`` -- one row per :meth:`log` call. The metrics dict is stored
      as a JSON string next to the project/run names and an insertion
      timestamp.
    * ``configs`` -- one row per (project, run) holding the run config as JSON.

    Every method opens its own short-lived connection and closes it before
    returning.
    """

    def __init__(self, project: str, name: str, config: dict):
        """Create storage for one run and persist its config.

        Args:
            project: Project name the run belongs to.
            name: Run name within the project.
            config: JSON-serializable run configuration.
        """
        self.project = project
        self.name = name
        self.config = config
        self.db_path = os.path.join(TRACKIO_DIR, "trackio.db")
        os.makedirs(TRACKIO_DIR, exist_ok=True)
        self._init_db()
        self._save_config()

    @contextlib.contextmanager
    def _connect(self):
        """Yield a connection inside a transaction, then always close it.

        Using ``sqlite3.Connection`` as a context manager only commits (or
        rolls back on error); it does NOT close the connection. Wrapping it
        here commits on success, rolls back on exception, and closes the
        handle so connections are not left open until garbage collection.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def _init_db(self):
        """Initialize the SQLite database with required tables."""
        with self._connect() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    project_name TEXT NOT NULL,
                    run_name TEXT NOT NULL,
                    metrics TEXT NOT NULL
                )
                """
            )
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS configs (
                    project_name TEXT NOT NULL,
                    run_name TEXT NOT NULL,
                    config TEXT NOT NULL,
                    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                    PRIMARY KEY (project_name, run_name)
                )
                """
            )

    def _save_config(self):
        """Save the run configuration to the database.

        Uses ``INSERT OR REPLACE``, so re-creating a run with the same
        (project, name) overwrites its previous config row.
        """
        with self._connect() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO configs (project_name, run_name, config) VALUES (?, ?, ?)",
                (self.project, self.name, json.dumps(self.config)),
            )

    def log(self, metrics: dict):
        """Log one metrics dict for this run.

        Args:
            metrics: JSON-serializable mapping of metric name to value.

        Raises:
            ValueError: If a key is in ``RESERVED_KEYS`` or starts with "__".
            TypeError: If ``metrics`` is not JSON-serializable (from
                ``json.dumps``).
        """
        for k in metrics.keys():
            if k in RESERVED_KEYS or k.startswith("__"):
                raise ValueError(
                    f"Please do not use this reserved key as a metric: {k}"
                )
        with self._connect() as conn:
            conn.execute(
                """
                INSERT INTO metrics
                (project_name, run_name, metrics)
                VALUES (?, ?, ?)
                """,
                (self.project, self.name, json.dumps(metrics)),
            )

    def get_metrics(self, project: str, run: str) -> list[dict]:
        """Retrieve metrics for a specific run, oldest first.

        Args:
            project: Project name to filter on.
            run: Run name to filter on.

        Returns:
            One dict per logged row: the stored metrics plus a ``"timestamp"``
            key holding the row's insertion time as stored by SQLite.
        """
        with self._connect() as conn:
            # CURRENT_TIMESTAMP has one-second resolution, so several log()
            # calls can share a timestamp; the AUTOINCREMENT id breaks ties
            # in insertion order, keeping the result deterministic.
            rows = conn.execute(
                """
                SELECT timestamp, metrics
                FROM metrics
                WHERE project_name = ? AND run_name = ?
                ORDER BY timestamp, id
                """,
                (project, run),
            ).fetchall()
        return [
            {**json.loads(metrics_json), "timestamp": timestamp}
            for timestamp, metrics_json in rows
        ]

    def get_projects(self) -> list[str]:
        """Get list of all projects that have logged at least one metric row."""
        with self._connect() as conn:
            rows = conn.execute("SELECT DISTINCT project_name FROM metrics").fetchall()
        return [project_name for (project_name,) in rows]

    def get_runs(self, project: str) -> list[str]:
        """Get list of all runs for a project that have logged metrics."""
        with self._connect() as conn:
            rows = conn.execute(
                "SELECT DISTINCT run_name FROM metrics WHERE project_name = ?",
                (project,),
            ).fetchall()
        return [run_name for (run_name,) in rows]

    def finish(self):
        """Cleanup when run is finished.

        Nothing to do: every method already closes its own connection.
        """
        pass