import json
import os
import sqlite3
from contextlib import contextmanager

from huggingface_hub import CommitScheduler

# Prefer the package-qualified imports; if they fail for any reason, fall back
# to importing the same names from top-level modules.
try:
    from trackio.dummy_commit_scheduler import DummyCommitScheduler
    from trackio.utils import RESERVED_KEYS, TRACKIO_DIR
except:  # noqa: E722
    from dummy_commit_scheduler import DummyCommitScheduler
    from utils import RESERVED_KEYS, TRACKIO_DIR


class SQLiteStorage:
    """SQLite-backed storage for the config and logged metrics of one run.

    All data is kept in a single ``trackio.db`` file inside ``TRACKIO_DIR``.
    Every write is performed while holding ``self.scheduler.lock``; reads do
    not take the lock.
    """

    def __init__(
        self, project: str, name: str, config: dict, dataset_id: str | None = None
    ) -> None:
        """Create the database (if needed) and persist this run's config.

        Args:
            project: Project name stored with every row written by this run.
            name: Run name stored with every row written by this run.
            config: Run configuration; must be JSON-serializable because it
                is stored via ``json.dumps``.
            dataset_id: Optional dataset repo id passed to ``CommitScheduler``.
                When ``None``, the ``TRACKIO_DATASET_ID`` environment variable
                is consulted instead.
        """
        self.project = project
        self.name = name
        self.config = config
        self.db_path = os.path.join(TRACKIO_DIR, "trackio.db")
        self.dataset_id = dataset_id
        self.scheduler = self._get_scheduler()

        os.makedirs(TRACKIO_DIR, exist_ok=True)

        self._init_db()
        self._save_config()

    def _get_scheduler(self):
        """Return the scheduler whose ``lock`` guards database writes.

        A real ``CommitScheduler`` over ``TRACKIO_DIR`` is built when a dataset
        id is available (explicit argument first, then the
        ``TRACKIO_DATASET_ID`` environment variable); otherwise a
        ``DummyCommitScheduler`` is returned.
        """
        # Get the token from the environment variable on Spaces.
        hf_token = os.environ.get("HF_TOKEN")
        dataset_id = self.dataset_id or os.environ.get("TRACKIO_DATASET_ID")
        if dataset_id is None:
            return DummyCommitScheduler()
        return CommitScheduler(
            repo_id=dataset_id,
            repo_type="dataset",
            folder_path=TRACKIO_DIR,
            private=True,
            squash_history=True,
            token=hf_token,
        )

    @contextmanager
    def _connection(self):
        """Yield a connection to ``self.db_path`` that is closed on exit.

        ``sqlite3.Connection`` used directly as a context manager only commits
        or rolls back the transaction; it does not close the connection. This
        helper keeps that commit/rollback behavior and additionally closes the
        connection so no handle is left open after each call.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on success, roll back on exception
                yield conn
        finally:
            conn.close()

    def _init_db(self) -> None:
        """Initialize the SQLite database with required tables."""
        with self.scheduler.lock, self._connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    project_name TEXT NOT NULL,
                    run_name TEXT NOT NULL,
                    metrics TEXT NOT NULL
                )
            """)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS configs (
                    project_name TEXT NOT NULL,
                    run_name TEXT NOT NULL,
                    config TEXT NOT NULL,
                    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                    PRIMARY KEY (project_name, run_name)
                )
            """)

    def _save_config(self) -> None:
        """Save the run configuration to the database.

        An existing row for the same ``(project, name)`` pair is replaced.
        """
        with self.scheduler.lock, self._connection() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO configs (project_name, run_name, config) "
                "VALUES (?, ?, ?)",
                (self.project, self.name, json.dumps(self.config)),
            )

    def log(self, metrics: dict) -> None:
        """Log metrics to the database.

        Args:
            metrics: Mapping of metric name to value; stored as one JSON
                document in a new row for this project/run.

        Raises:
            ValueError: If any key is in ``RESERVED_KEYS`` or starts with
                ``"__"``.
        """
        # Validate every key before touching the database so a bad key never
        # results in a partial write.
        for k in metrics:
            if k in RESERVED_KEYS or k.startswith("__"):
                raise ValueError(
                    f"Please do not use this reserved key as a metric: {k}"
                )

        with self.scheduler.lock, self._connection() as conn:
            conn.execute(
                """
                INSERT INTO metrics
                (project_name, run_name, metrics)
                VALUES (?, ?, ?)
                """,
                (self.project, self.name, json.dumps(metrics)),
            )

    def get_metrics(self, project: str, run: str) -> list[dict]:
        """Retrieve metrics for a specific run.

        Returns:
            One dict per logged row, in logging order. Each dict is the
            decoded metrics JSON with an added ``"timestamp"`` entry holding
            the row's timestamp.
        """
        with self._connection() as conn:
            cursor = conn.cursor()
            # CURRENT_TIMESTAMP has one-second resolution, so the
            # autoincrementing id breaks ties between rows logged within the
            # same second and keeps the order deterministic.
            cursor.execute(
                """
                SELECT timestamp, metrics
                FROM metrics
                WHERE project_name = ? AND run_name = ?
                ORDER BY timestamp, id
                """,
                (project, run),
            )
            rows = cursor.fetchall()

        results = []
        for timestamp, metrics_json in rows:
            entry = json.loads(metrics_json)
            entry["timestamp"] = timestamp
            results.append(entry)
        return results

    def get_projects(self) -> list[str]:
        """Get list of all projects that have at least one metrics row."""
        with self._connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT DISTINCT project_name FROM metrics")
            return [row[0] for row in cursor.fetchall()]

    def get_runs(self, project: str) -> list[str]:
        """Get list of all runs of a project that have at least one metrics row."""
        with self._connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT DISTINCT run_name FROM metrics WHERE project_name = ?",
                (project,),
            )
            return [row[0] for row in cursor.fetchall()]

    def finish(self) -> None:
        """Cleanup when run is finished (currently a no-op)."""
        pass