mlip-arena / serve /tasks /stability.py
github-actions[ci]
Clean sync from main branch - 2025-09-17 22:25:23
f880d1c
raw
history blame
10.2 kB
from pathlib import Path
from typing import Literal
import numpy as np
import pandas as pd
import plotly.colors as pcolors
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from plotly.subplots import make_subplots
from mlip_arena.models import REGISTRY
st.title("Stability")
DATA_DIR = Path(__file__).parents[2] / "benchmarks" / "stability"
st.markdown("### Methods")
container = st.container(border=True)
# Filter models that have valid parquet results
valid_models = [
model
for model, metadata in REGISTRY.items()
if (
DATA_DIR / REGISTRY[str(model)]["family"].lower() / f"{model}-heating.parquet"
).exists()
]
models = container.multiselect(
"MLIPs",
valid_models,
[
"MACE-MP(M)",
"CHGNet",
"SevenNet",
"ORBv2",
"eqV2(OMat)",
"M3GNet",
"MatterSim",
"MACE-MPA",
],
)
st.markdown("### Settings")
vis = st.container(border=True)
# Build available color palettes from Plotly
color_palettes = {
attr: getattr(pcolors.qualitative, attr)
for attr in dir(pcolors.qualitative)
if isinstance(getattr(pcolors.qualitative, attr), list)
}
color_palettes.pop("__all__", None)
palette_name = vis.selectbox(
"Color sequence", options=list(color_palettes.keys()), index=22
)
color_sequence = color_palettes[palette_name]
if not models:
st.stop()
@st.cache_data
def get_data(model_list, run_type: Literal["heating", "compression"]) -> pd.DataFrame:
"""Load parquet files for selected models."""
dfs = []
for m in model_list:
fpath = (
DATA_DIR / REGISTRY[str(m)]["family"].lower() / f"{m}-{run_type}.parquet"
)
if not fpath.exists():
continue
df_local = pd.read_parquet(fpath)
df_local["method"] = str(m)
dfs.append(df_local)
return pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()
df_nvt = get_data(models, run_type="heating")
df_npt = get_data(models, run_type="compression")
# Map model → color
method_color_mapping = {
method: color_sequence[i % len(color_sequence)]
for i, method in enumerate(df_nvt["method"].unique())
}
@st.cache_data
def prepare_scatter_df(df_in: pd.DataFrame, max_points: int = 20000) -> pd.DataFrame:
"""Prepare scatter dataframe with marker sizes scaled by total steps."""
dfp = df_in.dropna(subset=["natoms", "steps_per_second"]).copy()
if dfp.empty:
return dfp
# Downsample if too many points
if len(dfp) > max_points:
dfp = dfp.sample(max_points, random_state=1)
if "total_steps" in dfp.columns:
ts_local = dfp["total_steps"].fillna(dfp["total_steps"].median()).astype(float)
ts_range = ts_local.max() - ts_local.min()
scaled = (ts_local - ts_local.min()) / (ts_range if ts_range != 0 else 1.0)
dfp["_marker_size"] = (scaled * 40) + 5
else:
dfp["_marker_size"] = 8
return dfp
@st.cache_data
def compute_power_law_fits(df_in: pd.DataFrame) -> dict:
"""Fit power-law scaling: steps/s ~ a * N^(-n)."""
fits = {}
for name, grp in df_in.groupby("method"):
grp_clean = grp.dropna(subset=["natoms", "steps_per_second"])
grp_clean = grp_clean[
(grp_clean["natoms"] > 0) & (grp_clean["steps_per_second"] > 0)
]
if len(grp_clean) < 3:
continue
try:
logsx = np.log(grp_clean["natoms"].astype(float))
logsy = np.log(grp_clean["steps_per_second"].astype(float))
slope, intercept = np.polyfit(logsx, logsy, 1)
fits[name] = (float(np.exp(intercept)), float(-slope)) # (a, n)
except Exception:
continue
return fits
@st.cache_data
def build_speed_figure(
df_in: pd.DataFrame, color_map: dict, show_scatter: bool
) -> go.Figure:
"""Build scatter plot of inference speed vs number of atoms with power-law fits."""
fig = go.Figure()
# Optionally add scatter points
if show_scatter:
dfp = prepare_scatter_df(df_in)
scatter_fig = px.scatter(
dfp,
x="natoms",
y="steps_per_second",
color="method",
size="_marker_size",
hover_data=[c for c in ["material_id", "formula"] if c in dfp.columns],
color_discrete_map=color_map,
log_x=True,
log_y=True,
render_mode="webgl",
labels={
"steps_per_second": "Steps per second",
"natoms": "Number of atoms",
},
)
for trace in scatter_fig.data:
fig.add_trace(trace)
# Overlay fits
fits = compute_power_law_fits(df_in)
for method, (a, n) in fits.items():
grp = df_in[df_in["method"] == method]
if grp["natoms"].dropna().empty:
continue
xs = np.logspace(
np.log10(grp["natoms"].min()), np.log10(grp["natoms"].max()), 200
)
ys = a * xs ** (-n)
fig.add_trace(
go.Scatter(
x=xs,
y=ys,
mode="lines",
line=dict(color=color_map.get(method, "black"), width=2),
showlegend=not show_scatter,
name=f"{method}",
# zorder=0,
# text=hover_text,
# hoverinfo='text', # use the custom text
)
)
fig.update_layout(
height=520,
title="Inference speed (steps/s)",
xaxis=dict(type="log", title="Number of atoms"),
yaxis=dict(type="log", title="Steps per second"),
)
return fig
@st.cache_data
def build_nvt_figure(
df_in: pd.DataFrame, color_map: dict, show_scatter: bool
) -> go.Figure:
"""Build subplot: NVT valid runs (cumulative) + speed scaling plot."""
fig = make_subplots(
rows=1,
cols=2,
column_widths=[0.4, 0.6],
subplot_titles=("Valid runs", "Inference speed: steps/s vs N"),
)
# Right panel: speed scaling
speed_fig = build_speed_figure(df_in, color_map, show_scatter)
for trace in speed_fig.data:
fig.add_trace(trace, row=1, col=2)
# Left panel: cumulative valid runs
for method, df_model in df_in.groupby("method"):
df_model_grp = df_model.drop_duplicates(["formula"])
hist, bin_edges = np.histogram(
df_model_grp["normalized_final_step"], bins=np.linspace(0, 1, 50)
)
cumulative_population = np.cumsum(hist)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
fig.add_trace(
go.Scatter(
x=bin_centers[:-1],
y=(cumulative_population[-1] - cumulative_population[:-1]) / 120 * 100,
mode="lines",
line=dict(color=color_map.get(method)),
name=str(method),
showlegend=False,
),
row=1,
col=1,
)
fig.update_xaxes(title_text="Normalized time", row=1, col=1, range=[0, 1])
fig.update_yaxes(title_text="Valid runs (%)", row=1, col=1)
fig.update_xaxes(type="log", row=1, col=2, title_text="Number of atoms")
fig.update_yaxes(type="log", row=1, col=2, title_text="Steps per second")
fig.update_layout(height=520, width=1000)
return fig
@st.cache_data
def build_npt_figure(
df_in: pd.DataFrame, color_map: dict, show_scatter: bool
) -> go.Figure:
"""Build subplot: NPT valid runs (cumulative) + speed scaling plot."""
fig = make_subplots(
rows=1,
cols=2,
column_widths=[0.4, 0.6],
subplot_titles=("Valid runs", "Inference speed: steps/s vs N"),
)
# Right panel: speed scaling
speed_fig = build_speed_figure(df_in, color_map, show_scatter)
for trace in speed_fig.data:
fig.add_trace(trace, row=1, col=2)
# Left panel: cumulative valid runs
for method, df_model in df_in.groupby("method"):
df_model_grp = df_model.drop_duplicates(["formula"])
hist, bin_edges = np.histogram(
df_model_grp["normalized_final_step"], bins=np.linspace(0, 1, 50)
)
cumulative_population = np.cumsum(hist)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
fig.add_trace(
go.Scatter(
x=bin_centers[:-1],
y=(cumulative_population[-1] - cumulative_population[:-1]) / 80 * 100,
mode="lines",
line=dict(color=color_map.get(method)),
name=str(method),
showlegend=False,
),
row=1,
col=1,
)
fig.update_xaxes(title_text="Normalized time", row=1, col=1, range=[0, 1])
fig.update_yaxes(title_text="Valid runs (%)", row=1, col=1)
fig.update_xaxes(type="log", row=1, col=2, title_text="Number of atoms")
fig.update_yaxes(type="log", row=1, col=2, title_text="Steps per second")
fig.update_layout(height=520, width=1000)
return fig
if df_nvt.empty and df_npt.empty:
st.info("No data available to display for selected models.")
else:
st.markdown("""
## Heating
Isochoric-isothermal (NVT) MD simulations on RM24 structures, with temperature ramp from 300K to 3000K over 10 ps.
""")
show_scatter_nvt = st.toggle(
"Show scatter points", key="show_scatter_nvt", value=True
)
# Toggle for scatter points
# show_scatter = vis.checkbox("Show scatter points", value=True)
st.plotly_chart(
build_nvt_figure(df_nvt, method_color_mapping, show_scatter_nvt),
use_container_width=True,
)
st.markdown("""
## Compression
Isothermal-isobaric (NPT) MD simulations on RM24 structures, with pressure ramp from 0 GPa to 500 GPa and temperature ramp from 300K to 3000K over 10 ps.
""")
show_scatter_npt = st.toggle(
"Show scatter points", key="show_scatter_npt", value=True
)
# Toggle for scatter points
# show_scatter = vis.checkbox("Show scatter points", value=True)
st.plotly_chart(
build_npt_figure(df_npt, method_color_mapping, show_scatter_npt),
use_container_width=True,
)