{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CO$_2$ heat of adsorption: MOF class assignment per MLIP\n",
    "\n",
    "For every MLIP with a `<model>.pkl` result file, compare the predicted $-Q_{st}$ of each MOF against the capture-class windows (General < 35 <= Flue Gas < 50 <= DAC < 100 kJ/mol) and summarise how many structures are missing or misclassified, and by how much."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.colors as pcolors\n",
    "import seaborn as sns\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from mlip_arena.models import MLIPEnum\n",
    "\n",
    "# Model names in enum order; each gets a stable colour in the per-model plots.\n",
    "mlip_methods = [model.name for model in MLIPEnum]\n",
    "\n",
    "# Every qualitative palette that plotly exposes as a list of colours.\n",
    "color_palettes = {\n",
    "    attr: getattr(pcolors.qualitative, attr)\n",
    "    for attr in dir(pcolors.qualitative)\n",
    "    if isinstance(getattr(pcolors.qualitative, attr), list)\n",
    "}\n",
    "color_palettes.pop(\"__all__\", None)  # `__all__` (if a list) holds names, not colours\n",
    "\n",
    "palette_name = \"Plotly\"\n",
    "color_sequence = color_palettes[palette_name]  # type: ignore\n",
    "\n",
    "# Cycle through the palette so every model gets a colour even when there are\n",
    "# more models than palette entries.\n",
    "method_color_mapping = {\n",
    "    method: color_sequence[i % len(color_sequence)]\n",
    "    for i, method in enumerate(mlip_methods)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classification windows and helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colour per capture class, shared by the points and the shaded class regions.\n",
    "color_mapping = {\n",
    "    \"DAC\": \"#e41a1c\",\n",
    "    \"Flue Gas\": \"#377eb8\",\n",
    "    \"General\": \"#4daf4a\",\n",
    "}\n",
    "\n",
    "# Decision windows on -Q_st [kJ/mol] as (lower, upper): lower bound inclusive,\n",
    "# upper bound exclusive, None = unbounded. Insertion order is the plotting order.\n",
    "thresholds = {\n",
    "    \"General\": (None, 35),\n",
    "    \"Flue Gas\": (35, 50),\n",
    "    \"DAC\": (50, 100),\n",
    "}\n",
    "\n",
    "# Number of structures each model is expected to report. Structures absent from\n",
    "# a model's pickle, with non-positive -Q_st, or excluded below count as missing.\n",
    "# NOTE(review): confirm 18 matches the benchmark set after the exclusions.\n",
    "N_EXPECTED_STRUCTURES = 18\n",
    "\n",
    "# Structures left out of the comparison for every model.\n",
    "EXCLUDED_STRUCTURES = {\"MIL-96-Al\"}\n",
    "\n",
    "VALUE_COLUMNS = [\"henry_coefficient\", \"averaged_interaction_energy\", \"heat_of_adsorption\"]\n",
    "\n",
    "\n",
    "def point_misclassified(row: pd.Series) -> bool:\n",
    "    \"\"\"Return True when ``row[\"neg_heat\"]`` lies outside the window of ``row[\"class\"]``.\"\"\"\n",
    "    val = row[\"neg_heat\"]\n",
    "    lower, upper = thresholds[row[\"class\"]]\n",
    "    return (lower is not None and val < lower) or (upper is not None and val >= upper)\n",
    "\n",
    "\n",
    "def distance_to_boundary(row: pd.Series) -> float:\n",
    "    \"\"\"Return the distance [kJ/mol] from ``row[\"neg_heat\"]`` to the nearest finite edge of its class window.\"\"\"\n",
    "    val = row[\"neg_heat\"]\n",
    "    lower, upper = thresholds[row[\"class\"]]\n",
    "    return min(abs(val - bound) for bound in (lower, upper) if bound is not None)\n",
    "\n",
    "\n",
    "def load_model_results(fpath: Path, model_name: str) -> pd.DataFrame:\n",
    "    \"\"\"Load one model's pickled results and average each structure over its repeats.\n",
    "\n",
    "    Returns one row per (model, name, class) with the mean of VALUE_COLUMNS plus\n",
    "    ``model_name`` and ``neg_heat`` (= -heat_of_adsorption). Rows with\n",
    "    non-positive ``neg_heat`` and structures in EXCLUDED_STRUCTURES are dropped.\n",
    "    \"\"\"\n",
    "    # pd.read_pickle can execute arbitrary code: only load trusted result files.\n",
    "    raw = pd.read_pickle(fpath)\n",
    "    deduped = raw.drop_duplicates(subset=[\"model\", \"name\", \"class\"], keep=\"last\")\n",
    "    # explode() leaves object dtype; cast so the groupby mean is numeric.\n",
    "    exploded = deduped.explode(VALUE_COLUMNS).astype({col: float for col in VALUE_COLUMNS})\n",
    "    grouped = (\n",
    "        exploded.groupby([\"model\", \"name\", \"class\"])[VALUE_COLUMNS]\n",
    "        .mean()\n",
    "        .reset_index()\n",
    "        .assign(model_name=model_name)\n",
    "    )\n",
    "    grouped[\"neg_heat\"] = -grouped[\"heat_of_adsorption\"]  # negated so it can go on a log axis\n",
    "    is_valid = grouped[\"neg_heat\"] > 0  # non-positive values cannot be drawn on a log axis\n",
    "    is_kept = ~grouped[\"name\"].isin(EXCLUDED_STRUCTURES)\n",
    "    return grouped[is_valid & is_kept]\n",
    "\n",
    "\n",
    "def misclassification_summary(df_group: pd.DataFrame) -> tuple[float, int]:\n",
    "    \"\"\"Return (margin, count) for one model's per-structure results.\n",
    "\n",
    "    margin: mean distance [kJ/mol] of misclassified points to their class\n",
    "    boundary (0.0 when nothing is misclassified).\n",
    "    count: misclassified structures plus structures missing relative to\n",
    "    N_EXPECTED_STRUCTURES.\n",
    "    \"\"\"\n",
    "    num_missing = N_EXPECTED_STRUCTURES - len(df_group)\n",
    "    if df_group.empty:\n",
    "        return 0.0, num_missing\n",
    "    misclassified = df_group[df_group.apply(point_misclassified, axis=1)]\n",
    "    if misclassified.empty:\n",
    "        # Missing structures still count when nothing is misclassified.\n",
    "        return 0.0, num_missing\n",
    "    margin = misclassified.apply(distance_to_boundary, axis=1).mean()\n",
    "    return float(margin), len(misclassified) + num_missing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load per-model results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_data = []\n",
    "margins = []\n",
    "\n",
    "for model in MLIPEnum:\n",
    "    fpath = Path(f\"{model.name}.pkl\")\n",
    "    if not fpath.exists():\n",
    "        continue  # no results for this model\n",
    "\n",
    "    df_group = load_model_results(fpath, model.name)\n",
    "    all_data.append(df_group)\n",
    "\n",
    "    margin, num_misclassified = misclassification_summary(df_group)\n",
    "    margins.append((model.name, margin, num_misclassified))\n",
    "\n",
    "if not all_data:\n",
    "    raise FileNotFoundError(\"No <model>.pkl result files found in the working directory.\")\n",
    "\n",
    "combined_df = pd.concat(all_data, ignore_index=True)\n",
    "margins_df = pd.DataFrame(margins, columns=[\"model_name\", \"misclassification_margin\", \"num_misclassified\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure: class assignment and misclassification margin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def qst_label(value: float) -> str:\n",
    "    \"\"\"Legend label for the experimental CO2 Q_st boundary at ``value`` kJ/mol.\"\"\"\n",
    "    return f\"Exp. $\\\\mathregular{{CO_2}}$ $Q_\\\\text{{st}}$ = {value} kJ/mol\"\n",
    "\n",
    "\n",
    "with plt.style.context(\"default\"):\n",
    "\n",
    "    LARGE_SIZE = 10\n",
    "    MEDIUM_SIZE = 8\n",
    "    SMALL_SIZE = 6\n",
    "\n",
    "    plt.rcParams.update({\n",
    "        \"font.size\": SMALL_SIZE,\n",
    "        \"axes.titlesize\": MEDIUM_SIZE,\n",
    "        \"axes.labelsize\": MEDIUM_SIZE,\n",
    "        \"xtick.labelsize\": SMALL_SIZE,\n",
    "        \"ytick.labelsize\": SMALL_SIZE,\n",
    "        \"legend.fontsize\": SMALL_SIZE,\n",
    "        \"figure.titlesize\": LARGE_SIZE,\n",
    "    })\n",
    "\n",
    "    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 4), gridspec_kw={\"height_ratios\": [3, 1.5]})\n",
    "\n",
    "    # --- Top panel: per-structure -Q_st for every model, coloured by class ---\n",
    "    sns.stripplot(\n",
    "        data=combined_df,\n",
    "        x=\"neg_heat\",\n",
    "        y=\"model_name\",\n",
    "        hue=\"class\",\n",
    "        size=2,\n",
    "        palette=color_mapping,\n",
    "        dodge=True,\n",
    "        jitter=0.1,\n",
    "        alpha=1,\n",
    "        ax=ax1,\n",
    "    )\n",
    "\n",
    "    # Switch to log scale before reading the limits: the linear autoscaled lower\n",
    "    # limit can be <= 0, which set_xlim rejects on a log axis.\n",
    "    ax1.set_xscale(\"log\")\n",
    "    xmin, xmax = ax1.get_xlim()\n",
    "\n",
    "    # Shade each class window, then mark each window's upper boundary.\n",
    "    for cls, (lower, upper) in thresholds.items():\n",
    "        ax1.axvspan(xmin if lower is None else lower, upper, color=color_mapping[cls], alpha=0.1, label=cls)\n",
    "    for _, upper in thresholds.values():\n",
    "        ax1.axvline(x=upper, linestyle=\"--\", color=\"gray\", label=qst_label(upper))\n",
    "\n",
    "    ax1.set_xlabel(\"Heat of $\\\\mathregular{CO_2}$ Adsorption $Q_\\\\text{st}$ [kJ/mol]\")\n",
    "    ax1.set_ylabel(\"\")\n",
    "    ax1.set_xlim(xmin, xmax)  # undo any autoscaling triggered by the spans and lines\n",
    "\n",
    "    # Dotted separators halfway between neighbouring model rows.\n",
    "    yticks = np.asarray(ax1.get_yticks(), dtype=float)\n",
    "    yticks = yticks[np.isfinite(yticks)]\n",
    "    for y in yticks[:-1] + np.diff(yticks) / 2:\n",
    "        ax1.axhline(y=y, color=\"gray\", linestyle=\":\", linewidth=0.7, alpha=0.5, zorder=0)\n",
    "\n",
    "    # Legend: each class window followed by its upper boundary line.\n",
    "    handles, labels = ax1.get_legend_handles_labels()\n",
    "    legend_dict = dict(zip(labels, handles, strict=False))\n",
    "    desired_order = [\n",
    "        label\n",
    "        for cls, (_, upper) in thresholds.items()\n",
    "        for label in (cls, qst_label(upper))\n",
    "    ]\n",
    "    # Filter labels and handles together so they stay aligned if an entry is absent.\n",
    "    ordered_labels = [label for label in desired_order if label in legend_dict]\n",
    "    ax1.legend(\n",
    "        [legend_dict[label] for label in ordered_labels],\n",
    "        ordered_labels,\n",
    "        loc=\"lower center\",\n",
    "        bbox_to_anchor=(0.5, 1),\n",
    "        ncol=3,\n",
    "        frameon=True,\n",
    "    )\n",
    "\n",
    "    ax1.spines[\"top\"].set_visible(False)\n",
    "    ax1.spines[\"right\"].set_visible(False)\n",
    "\n",
    "    # --- Bottom panel: misclassification margin vs. missing + misclassified count ---\n",
    "    margins_df_sorted = margins_df.sort_values(by=\"misclassification_margin\", ascending=True)\n",
    "\n",
    "    sns.scatterplot(\n",
    "        data=margins_df_sorted,\n",
    "        x=\"num_misclassified\",\n",
    "        y=\"misclassification_margin\",\n",
    "        hue=\"model_name\",\n",
    "        palette=method_color_mapping,\n",
    "        ax=ax2,\n",
    "    )\n",
    "\n",
    "    # Annotate each point with its margin, in the model's colour.\n",
    "    for row in margins_df_sorted.itertuples(index=False):\n",
    "        ax2.text(\n",
    "            row.num_misclassified + 0.1,\n",
    "            row.misclassification_margin,\n",
    "            f\"{row.misclassification_margin:.2f}\",\n",
    "            fontsize=SMALL_SIZE,\n",
    "            ha=\"left\",\n",
    "            va=\"bottom\",\n",
    "            color=method_color_mapping[row.model_name],\n",
    "            alpha=0.9,\n",
    "        )\n",
    "\n",
    "    ax2.set_ylabel(\"Misclass. margin [kJ/mol]\")\n",
    "    ax2.set_xlabel(\"Missing + misclass. count\")\n",
    "    ax2.spines[\"top\"].set_visible(False)\n",
    "    ax2.spines[\"right\"].set_visible(False)\n",
    "    # NOTE: a model with no misclassified points has margin 0.0, which cannot be\n",
    "    # drawn on this log-scaled axis.\n",
    "    ax2.set_yscale(\"log\")\n",
    "\n",
    "    # Deduplicate legend entries by label.\n",
    "    handles, labels = ax2.get_legend_handles_labels()\n",
    "    legend_dict = dict(zip(labels, handles, strict=False))\n",
    "    ax2.legend(\n",
    "        list(legend_dict.values()),\n",
    "        list(legend_dict.keys()),\n",
    "        loc=\"upper left\",\n",
    "        bbox_to_anchor=(0, 1),\n",
    "        ncol=3,\n",
    "        frameon=True,\n",
    "    )\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(\"mof-misclassification_margin.pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlip-arena",
   "language": "python",
   "name": "mlip-arena"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}