Spaces:
Sleeping
Sleeping
File size: 4,117 Bytes
a9ee714 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import contextlib
import json
import os
import sqlite3

try:
    from trackio.utils import RESERVED_KEYS, TRACKIO_DIR
except:  # noqa: E722
    from utils import RESERVED_KEYS, TRACKIO_DIR
class SQLiteStorage:
    """Persist run configs and logged metrics in one SQLite database file.

    All runs share ``trackio.db`` inside ``TRACKIO_DIR``; rows are keyed by
    ``(project_name, run_name)``. Every public method opens its own short-lived
    connection, so no handle is held open between calls.
    """

    def __init__(self, project: str, name: str, config: dict):
        """Set up storage for one run and record its configuration.

        Args:
            project: Project the run belongs to.
            name: Run name within the project.
            config: Run configuration; serialized with ``json.dumps`` and
                stored (replacing any previous config for the same
                project/run) during construction.
        """
        self.project = project
        self.name = name
        self.config = config
        self.db_path = os.path.join(TRACKIO_DIR, "trackio.db")
        os.makedirs(TRACKIO_DIR, exist_ok=True)
        self._init_db()
        self._save_config()

    @contextlib.contextmanager
    def _connect(self):
        """Yield a connection wrapped in a transaction, closing it afterwards.

        ``sqlite3.Connection`` used as a context manager commits on success
        and rolls back on an exception, but it does NOT close the connection.
        The ``finally`` clause guarantees the file handle is released.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on success, roll back on error
                yield conn
        finally:
            conn.close()

    def _init_db(self):
        """Create the ``metrics`` and ``configs`` tables if they do not exist."""
        with self._connect() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                    project_name TEXT NOT NULL,
                    run_name TEXT NOT NULL,
                    metrics TEXT NOT NULL
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS configs (
                    project_name TEXT NOT NULL,
                    run_name TEXT NOT NULL,
                    config TEXT NOT NULL,
                    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                    PRIMARY KEY (project_name, run_name)
                )
            """)

    def _save_config(self):
        """Store this run's config as JSON, replacing any existing row."""
        with self._connect() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO configs (project_name, run_name, config) VALUES (?, ?, ?)",
                (self.project, self.name, json.dumps(self.config)),
            )

    def log(self, metrics: dict):
        """Append one metrics record for this run.

        Args:
            metrics: Mapping of metric name to value; serialized with
                ``json.dumps``.

        Raises:
            ValueError: If any key is in ``RESERVED_KEYS`` or starts with
                ``"__"``. Checked before anything is written.
        """
        for k in metrics.keys():
            if k in RESERVED_KEYS or k.startswith("__"):
                raise ValueError(
                    f"Please do not use this reserved key as a metric: {k}"
                )
        with self._connect() as conn:
            conn.execute(
                """
                INSERT INTO metrics
                (project_name, run_name, metrics)
                VALUES (?, ?, ?)
                """,
                (self.project, self.name, json.dumps(metrics)),
            )

    def get_metrics(self, project: str, run: str) -> list[dict]:
        """Return every metrics record logged for ``project``/``run``.

        Each record is the decoded metrics dict with an added ``"timestamp"``
        key holding the stored row timestamp. That key overwrites any
        ``"timestamp"`` metric; presumably it is in ``RESERVED_KEYS`` (confirm).
        Records come back in chronological order.
        """
        with self._connect() as conn:
            # CURRENT_TIMESTAMP only has one-second resolution, so several
            # records can share a timestamp; the AUTOINCREMENT id breaks ties
            # and keeps them in the order they were logged.
            rows = conn.execute(
                """
                SELECT timestamp, metrics
                FROM metrics
                WHERE project_name = ? AND run_name = ?
                ORDER BY timestamp, id
                """,
                (project, run),
            ).fetchall()
        results = []
        for timestamp, metrics_json in rows:
            metrics = json.loads(metrics_json)
            metrics["timestamp"] = timestamp
            results.append(metrics)
        return results

    def get_projects(self) -> list[str]:
        """Return the distinct project names that have logged metrics."""
        with self._connect() as conn:
            rows = conn.execute(
                "SELECT DISTINCT project_name FROM metrics"
            ).fetchall()
        return [row[0] for row in rows]

    def get_runs(self, project: str) -> list[str]:
        """Return the distinct run names that have logged metrics in ``project``."""
        with self._connect() as conn:
            rows = conn.execute(
                "SELECT DISTINCT run_name FROM metrics WHERE project_name = ?",
                (project,),
            ).fetchall()
        return [row[0] for row in rows]

    def finish(self):
        """Finish the run.

        Nothing to release: every method opens and closes its own connection.
        """
        pass
|