File size: 5,824 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
from __future__ import annotations
import numbers
from typing import TYPE_CHECKING, Sequence
import wandb
from wandb import util
from wandb.plot.custom_chart import plot_table
from wandb.plot.utils import test_missing, test_types
if TYPE_CHECKING:
from wandb.plot.custom_chart import CustomChart
def roc_curve(
    y_true: Sequence[numbers.Number],
    y_probas: Sequence[Sequence[float]] | None = None,
    labels: list[str] | None = None,
    classes_to_plot: list[numbers.Number] | None = None,
    title: str = "ROC Curve",
    split_table: bool = False,
) -> CustomChart:
    """Constructs Receiver Operating Characteristic (ROC) curve chart.

    Args:
        y_true (Sequence[numbers.Number]): The true class labels (ground truth)
            for the target variable. Shape should be (num_samples,).
        y_probas (Sequence[Sequence[float]]): The predicted probabilities or
            decision scores for each class. Shape should be (num_samples, num_classes).
        labels (list[str]): Human-readable labels corresponding to the class
            indices in `y_true`. For example, if `labels=['dog', 'cat']`,
            class 0 will be displayed as 'dog' and class 1 as 'cat' in the plot.
            If None, the raw class indices from `y_true` will be used.
            Labels are only applied to integer-valued classes; any other
            class value is shown as-is. Default is None.
        classes_to_plot (list[numbers.Number]): A subset of unique class labels
            to include in the ROC curve. If None, all classes in `y_true` will
            be plotted. Default is None.
        title (str): Title of the ROC curve plot. Default is "ROC Curve".
        split_table (bool): Whether the table should be split into a separate
            section in the W&B UI. If `True`, the table will be displayed in a
            section named "Custom Chart Tables". Default is `False`.

    Returns:
        CustomChart: A custom chart object that can be logged to W&B. To log the
            chart, pass it to `wandb.log()`. `None` is returned instead when
            the `test_missing` or `test_types` input checks do not pass.

    Raises:
        wandb.Error: If numpy, pandas, or scikit-learn are not found.
        ValueError: If none of `classes_to_plot` occurs in `y_true`, so there
            is no curve to draw.

    Example:
        ```
        import numpy as np
        import wandb

        # Simulate a medical diagnosis classification problem with three diseases
        n_samples = 200
        n_classes = 3

        # True labels: assign "Diabetes", "Hypertension", or "Heart Disease" to
        # each sample
        disease_labels = ["Diabetes", "Hypertension", "Heart Disease"]
        # 0: Diabetes, 1: Hypertension, 2: Heart Disease
        y_true = np.random.choice([0, 1, 2], size=n_samples)

        # Predicted probabilities: simulate predictions, ensuring they sum to 1
        # for each sample
        y_probas = np.random.dirichlet(np.ones(n_classes), size=n_samples)

        # Specify classes to plot (plotting all three diseases)
        classes_to_plot = [0, 1, 2]

        # Initialize a W&B run and log a ROC curve plot for disease classification
        with wandb.init(project="medical_diagnosis") as run:
            roc_plot = wandb.plot.roc_curve(
                y_true=y_true,
                y_probas=y_probas,
                labels=disease_labels,
                classes_to_plot=classes_to_plot,
                title="ROC Curve for Disease Classification",
            )
            run.log({"roc-curve": roc_plot})
        ```
    """
    # Heavy dependencies are resolved lazily so that importing this module
    # does not require them to be installed.
    np = util.get_module(
        "numpy",
        required="roc requires the numpy library, install with `pip install numpy`",
    )
    pd = util.get_module(
        "pandas",
        required="roc requires the pandas library, install with `pip install pandas`",
    )
    sklearn_metrics = util.get_module(
        "sklearn.metrics",
        "roc requires the scikit library, install with `pip install scikit-learn`",
    )
    sklearn_utils = util.get_module(
        "sklearn.utils",
        "roc requires the scikit library, install with `pip install scikit-learn`",
    )

    y_true = np.array(y_true)
    y_probas = np.array(y_probas)

    # Bail out (returning None) when the shared input checks reject the data.
    if not test_missing(y_true=y_true, y_probas=y_probas):
        return
    if not test_types(y_true=y_true, y_probas=y_probas):
        return

    classes = np.unique(y_true)
    if classes_to_plot is None:
        classes_to_plot = classes

    fpr = {}
    tpr = {}
    # Positions, within the sorted unique classes, of the classes requested.
    indices_to_plot = np.where(np.isin(classes, classes_to_plot))[0]
    for i in indices_to_plot:
        cls = classes[i]
        # Only integer-valued classes can index into `labels`; the check must
        # look at the class being plotted, not at the first class.
        if labels is not None and isinstance(cls, (int, np.integer)):
            class_label = labels[cls]
        else:
            class_label = cls

        # NOTE(review): column `i` is the position of the class among the
        # unique values of `y_true`; this presumably relies on every class
        # appearing in `y_true` so positions line up with `y_probas` columns.
        fpr[class_label], tpr[class_label], _ = sklearn_metrics.roc_curve(
            y_true, y_probas[..., i], pos_label=cls
        )

    if not fpr:
        # Without this guard np.hstack([]) fails with an opaque
        # "need at least one array to concatenate" ValueError.
        raise ValueError(
            "roc_curve has nothing to plot: none of `classes_to_plot` "
            "were found in `y_true`."
        )

    # Long-format table: one row per (class, fpr, tpr) point.
    df = pd.DataFrame(
        {
            "class": np.hstack([[k] * len(v) for k, v in fpr.items()]),
            "fpr": np.hstack(list(fpr.values())),
            "tpr": np.hstack(list(tpr.values())),
        }
    ).round(3)

    if len(df) > wandb.Table.MAX_ROWS:
        wandb.termwarn(
            f"wandb uses only {wandb.Table.MAX_ROWS} data points to create the plots."
        )
        # different sampling could be applied, possibly to ensure endpoints are kept
        df = sklearn_utils.resample(
            df,
            replace=False,
            n_samples=wandb.Table.MAX_ROWS,
            random_state=42,
            stratify=df["class"],
        ).sort_values(["fpr", "tpr", "class"])

    return plot_table(
        data_table=wandb.Table(dataframe=df),
        vega_spec_name="wandb/area-under-curve/v0",
        fields={
            "x": "fpr",
            "y": "tpr",
            "class": "class",
        },
        string_fields={
            "title": title,
            "x-axis-title": "False positive rate",
            "y-axis-title": "True positive rate",
        },
        split_table=split_table,
    )
|