Spaces:
Sleeping
Sleeping
from typing import Any | |
import gradio as gr | |
import pandas as pd | |
try: | |
from trackio.sqlite_storage import SQLiteStorage | |
from trackio.utils import RESERVED_KEYS, TRACKIO_LOGO_PATH | |
except: # noqa: E722 | |
from sqlite_storage import SQLiteStorage | |
from utils import RESERVED_KEYS, TRACKIO_LOGO_PATH | |
def get_projects(request: gr.Request):
    """Build the project dropdown for the current page request.

    When the URL carries a ``project`` query parameter, that project is
    preselected and the dropdown is made non-interactive. Otherwise the
    first entry returned by ``SQLiteStorage.get_projects()`` is preselected
    (``None`` when there are no projects) and the user may change it.
    """
    all_projects = SQLiteStorage("", "", {}).get_projects()
    pinned = request.query_params.get("project")
    if pinned:
        selected, editable = pinned, False
    else:
        selected = all_projects[0] if all_projects else None
        editable = True
    return gr.Dropdown(
        label="Project",
        choices=all_projects,
        value=selected,
        allow_custom_value=True,
        interactive=editable,
    )
def get_runs(project):
    """Return the runs recorded for ``project``, or ``[]`` when it is falsy."""
    return SQLiteStorage("", "", {}).get_runs(project) if project else []
def load_run_data(project: str | None, run: str | None, smoothing: bool): | |
if not project or not run: | |
return None | |
storage = SQLiteStorage("", "", {}) | |
metrics = storage.get_metrics(project, run) | |
if not metrics: | |
return None | |
df = pd.DataFrame(metrics) | |
if smoothing: | |
numeric_cols = df.select_dtypes(include="number").columns | |
numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS] | |
df[numeric_cols] = df[numeric_cols].ewm(alpha=0.1).mean() | |
if "step" not in df.columns: | |
df["step"] = range(len(df)) | |
return df | |
def update_runs(project):
    """Refresh the run dropdown for ``project``, preselecting every run.

    ``None`` (no project chosen) yields an empty dropdown.
    """
    runs = [] if project is None else get_runs(project)
    return gr.Dropdown(choices=runs, value=runs)
def toggle_timer(cb_value):
    """Return a Timer update that is active exactly when ``cb_value`` is truthy."""
    return gr.Timer(active=bool(cb_value))
def log(project: str, run: str, metrics: dict[str, Any]) -> None:
    """Write one ``metrics`` record for ``run`` in ``project`` via SQLiteStorage."""
    SQLiteStorage(project, run, {}).log(metrics)
def configure(request: gr.Request):
    """Return the metric names requested via the ``metrics`` query parameter.

    ``?metrics=loss,acc`` yields ``["loss", "acc"]``. Each name is stripped of
    surrounding whitespace, and empty entries (e.g. from a trailing comma)
    are dropped, so ``?metrics=loss, acc,`` also yields ``["loss", "acc"]``.
    A missing or empty parameter yields ``[]``.
    """
    raw = request.query_params.get("metrics")
    if not raw:
        return []
    # Stripping/dropping blanks keeps stray spaces or commas from producing
    # names that can never match a metric column.
    return [name.strip() for name in raw.split(",") if name.strip()]
with gr.Blocks(theme="citrus", title="Trackio Dashboard") as demo:
    # Sidebar: logo, project picker and display settings.
    with gr.Sidebar() as sidebar:
        gr.Markdown(
            f"<div style='display: flex; align-items: center; gap: 8px;'><img src='/gradio_api/file={TRACKIO_LOGO_PATH}' width='32' height='32'><span style='font-size: 2em; font-weight: bold;'>Trackio</span></div>"
        )
        project_dd = gr.Dropdown(label="Project", allow_custom_value=True)
        gr.Markdown("### ⚙️ Settings")
        realtime_cb = gr.Checkbox(label="Refresh realtime", value=True)
        smoothing_cb = gr.Checkbox(label="Smoothing", value=True)
    with gr.Row():
        run_dd = gr.Dropdown(label="Run", choices=[], multiselect=True)
    # Periodic refresh driver; (de)activated by realtime_cb via toggle_timer.
    timer = gr.Timer(value=1)
    # Metric names parsed from the URL by `configure`; an empty list means
    # no filtering.
    metrics_subset = gr.State([])

    gr.on(
        [demo.load],
        fn=configure,
        outputs=metrics_subset,
    )
    # NOTE(review): get_projects and update_runs also run on every timer tick
    # and return dropdowns with a fresh `value` (projects[0] / all runs),
    # which appears to override the user's current selection each tick.
    # Confirm this is intended.
    gr.on(
        [demo.load, timer.tick],
        fn=get_projects,
        outputs=project_dd,
        show_progress="hidden",
    )
    gr.on(
        [demo.load, project_dd.change, timer.tick],
        fn=update_runs,
        inputs=project_dd,
        outputs=run_dd,
        show_progress="hidden",
    )
    realtime_cb.change(
        fn=toggle_timer,
        inputs=realtime_cb,
        outputs=timer,
        api_name="toggle_timer",
    )
    # Register `log` as an API-only endpoint named "log".
    gr.api(
        fn=log,
        api_name="log",
    )

    # Shared x-axis range for every plot. It is set by selecting a range on
    # a plot and cleared (None = autoscale) by double-clicking a plot.
    x_lim = gr.State(None)

    def update_x_lim(select_data: gr.SelectData):
        """Return the x range selected on a plot, to be stored in ``x_lim``."""
        return select_data.index

    # Without this decorator update_dashboard was never invoked, so no plots
    # appeared. It also creates components, which requires a render context.
    @gr.render(
        triggers=[
            demo.load,
            run_dd.change,
            timer.tick,
            smoothing_cb.change,
            metrics_subset.change,
            x_lim.change,
        ],
        inputs=[project_dd, run_dd, smoothing_cb, metrics_subset, x_lim],
    )
    def update_dashboard(project, runs, smoothing, metrics_subset, x_lim_value):
        """Render one line plot per numeric metric of the selected runs.

        The metrics of every selected run are concatenated into one frame,
        tagged with a ``run`` column, and plotted against ``step``, two plots
        per row. Columns in ``RESERVED_KEYS`` are skipped. When
        ``metrics_subset`` is non-empty, only those metrics are shown.
        """
        dfs = []
        # The multiselect value may be None before any runs are loaded.
        for run in runs or []:
            df = load_run_data(project, run, smoothing)
            if df is not None:
                df["run"] = run
                dfs.append(df)
        master_df = pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()

        numeric_cols = [
            c
            for c in master_df.select_dtypes(include="number").columns
            if c not in RESERVED_KEYS
        ]
        if metrics_subset:
            numeric_cols = [c for c in numeric_cols if c in metrics_subset]

        color = "run" if "run" in master_df.columns else None
        plots: list[gr.LinePlot] = []
        # Two plots per row. Round up so an odd trailing metric still gets a
        # row; floor division used to drop it silently.
        for col in range((len(numeric_cols) + 1) // 2):
            with gr.Row(key=f"row-{col}"):
                for i, metric in enumerate(numeric_cols[2 * col : 2 * col + 2]):
                    values = master_df[metric]
                    # Series.min/max skip NaN, which is present when runs log
                    # different metrics. Builtin min/max could return NaN.
                    lo, hi = values.min(), values.max()
                    y_lim = (
                        None
                        if pd.isna(lo) or pd.isna(hi)
                        else [float(lo), float(hi)]
                    )
                    plot = gr.LinePlot(
                        master_df,
                        x="step",
                        y=metric,
                        color=color,
                        title=metric,
                        key=f"plot-{col}-{i}",
                        preserved_by_key=None,
                        x_lim=x_lim_value,
                        y_lim=y_lim,
                        show_fullscreen_button=True,
                    )
                    plots.append(plot)

        for plot in plots:
            plot.select(update_x_lim, outputs=x_lim)
            plot.double_click(lambda: None, outputs=x_lim)


if __name__ == "__main__":
    # Allow Gradio to serve the logo file referenced by the sidebar <img> tag.
    demo.launch(allowed_paths=[TRACKIO_LOGO_PATH])