File size: 5,824 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from __future__ import annotations

import numbers
from typing import TYPE_CHECKING, Sequence

import wandb
from wandb import util
from wandb.plot.custom_chart import plot_table
from wandb.plot.utils import test_missing, test_types

if TYPE_CHECKING:
    from wandb.plot.custom_chart import CustomChart


def roc_curve(
    y_true: Sequence[numbers.Number],
    y_probas: Sequence[Sequence[float]] | None = None,
    labels: list[str] | None = None,
    classes_to_plot: list[numbers.Number] | None = None,
    title: str = "ROC Curve",
    split_table: bool = False,
) -> CustomChart:
    """Constructs Receiver Operating Characteristic (ROC) curve chart.

    Args:
        y_true (Sequence[numbers.Number]): The true class labels (ground truth)
            for the target variable. Shape should be (num_samples,).
        y_probas (Sequence[Sequence[float]]): The predicted probabilities or
            decision scores for each class. Shape should be (num_samples, num_classes).
            Columns are matched to classes by position: column `j` is used as
            the scores for the `j`-th smallest distinct value in `y_true`.
        labels (list[str]): Human-readable labels corresponding to the class
            indices in `y_true`. For example, if `labels=['dog', 'cat']`,
            class 0 will be displayed as 'dog' and class 1 as 'cat' in the plot.
            Only integer class values are mapped through `labels`; other class
            values are displayed as-is. If None, the raw class indices from
            `y_true` will be used. Default is None.
        classes_to_plot (list[numbers.Number]): A subset of unique class labels
            to include in the ROC curve. If None, all classes in `y_true` will
            be plotted. Default is None.
        title (str): Title of the ROC curve plot. Default is "ROC Curve".
        split_table (bool): Whether the table should be split into a separate
            section in the W&B UI. If `True`, the table will be displayed in a
            section named "Custom Chart Tables". Default is `False`.

    Returns:
        CustomChart: A custom chart object that can be logged to W&B. To log the
            chart, pass it to `wandb.log()`. Returns None (without raising)
            if `y_true` / `y_probas` fail the `test_missing` or `test_types`
            input checks.

    Raises:
        wandb.Error: If numpy, pandas, or scikit-learn are not found.
        ValueError: If none of the values in `classes_to_plot` occur in
            `y_true`, so there is no curve to draw.

    Example:
        ```
        import numpy as np
        import wandb

        # Simulate a medical diagnosis classification problem with three diseases
        n_samples = 200
        n_classes = 3

        # True labels: assign "Diabetes", "Hypertension", or "Heart Disease" to
        # each sample
        disease_labels = ["Diabetes", "Hypertension", "Heart Disease"]
        # 0: Diabetes, 1: Hypertension, 2: Heart Disease
        y_true = np.random.choice([0, 1, 2], size=n_samples)

        # Predicted probabilities: simulate predictions, ensuring they sum to 1
        # for each sample
        y_probas = np.random.dirichlet(np.ones(n_classes), size=n_samples)

        # Specify classes to plot (plotting all three diseases)
        classes_to_plot = [0, 1, 2]

        # Initialize a W&B run and log a ROC curve plot for disease classification
        with wandb.init(project="medical_diagnosis") as run:
            roc_plot = wandb.plot.roc_curve(
                y_true=y_true,
                y_probas=y_probas,
                labels=disease_labels,
                classes_to_plot=classes_to_plot,
                title="ROC Curve for Disease Classification",
            )
            run.log({"roc-curve": roc_plot})
        ```
    """
    np = util.get_module(
        "numpy",
        required="roc requires the numpy library, install with `pip install numpy`",
    )
    pd = util.get_module(
        "pandas",
        required="roc requires the pandas library, install with `pip install pandas`",
    )
    sklearn_metrics = util.get_module(
        "sklearn.metrics",
        required="roc requires the scikit library, install with `pip install scikit-learn`",
    )
    sklearn_utils = util.get_module(
        "sklearn.utils",
        required="roc requires the scikit library, install with `pip install scikit-learn`",
    )

    y_true = np.array(y_true)
    y_probas = np.array(y_probas)

    # Invalid input is reported by the check helpers; the function then
    # returns None instead of raising.
    if not test_missing(y_true=y_true, y_probas=y_probas):
        return
    if not test_types(y_true=y_true, y_probas=y_probas):
        return

    # np.unique returns the distinct classes sorted; position i in this array
    # is also the column of y_probas used for that class below.
    classes = np.unique(y_true)
    if classes_to_plot is None:
        classes_to_plot = classes

    indices_to_plot = np.where(np.isin(classes, classes_to_plot))[0]
    if len(indices_to_plot) == 0:
        # Without this guard np.hstack([]) below fails with an opaque
        # "need at least one array to concatenate" error.
        raise ValueError(
            "roc_curve: none of the classes in `classes_to_plot` appear in `y_true`."
        )

    fpr = {}
    tpr = {}
    for i in indices_to_plot:
        cls = classes[i]
        # Integer class values double as indices into `labels`; any other
        # class value is used directly as the displayed label.
        if labels is not None and isinstance(cls, (int, np.integer)):
            class_label = labels[cls]
        else:
            class_label = cls

        fpr[class_label], tpr[class_label], _ = sklearn_metrics.roc_curve(
            y_true, y_probas[..., i], pos_label=cls
        )

    # Long format: one row per (class, fpr, tpr) point on each class's curve.
    df = pd.DataFrame(
        {
            "class": np.hstack([[k] * len(v) for k, v in fpr.items()]),
            "fpr": np.hstack(list(fpr.values())),
            "tpr": np.hstack(list(tpr.values())),
        }
    ).round(3)

    if len(df) > wandb.Table.MAX_ROWS:
        wandb.termwarn(
            f"wandb uses only {wandb.Table.MAX_ROWS} data points to create the plots."
        )
        # different sampling could be applied, possibly to ensure endpoints are kept
        # Stratifying on "class" keeps each class's share of the rows; a fixed
        # random_state makes the downsampling reproducible across calls.
        df = sklearn_utils.resample(
            df,
            replace=False,
            n_samples=wandb.Table.MAX_ROWS,
            random_state=42,
            stratify=df["class"],
        ).sort_values(["fpr", "tpr", "class"])

    return plot_table(
        data_table=wandb.Table(dataframe=df),
        vega_spec_name="wandb/area-under-curve/v0",
        fields={
            "x": "fpr",
            "y": "tpr",
            "class": "class",
        },
        string_fields={
            "title": title,
            "x-axis-title": "False positive rate",
            "y-axis-title": "True positive rate",
        },
        split_table=split_table,
    )