File size: 6,753 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
from __future__ import annotations

import numbers
from typing import TYPE_CHECKING, Iterable, TypeVar

import wandb
from wandb import util
from wandb.plot.custom_chart import plot_table
from wandb.plot.utils import test_missing, test_types

if TYPE_CHECKING:
    from wandb.plot.custom_chart import CustomChart


T = TypeVar("T")


def pr_curve(
    y_true: Iterable[T] | None = None,
    y_probas: Iterable[numbers.Number] | None = None,
    labels: list[str] | None = None,
    classes_to_plot: list[T] | None = None,
    interp_size: int = 21,
    title: str = "Precision-Recall Curve",
    split_table: bool = False,
) -> CustomChart:
    """Constructs a Precision-Recall (PR) curve.

    The Precision-Recall curve is particularly useful for evaluating classifiers
    on imbalanced datasets. A high area under the PR curve signifies both high
    precision (a low false positive rate) and high recall (a low false negative
    rate). The curve provides insights into the balance between false positives
    and false negatives at various threshold levels, aiding in the assessment of
    a model's performance.

    Args:
        y_true (Iterable): True binary labels. The shape should be (`num_samples`,).
        y_probas (Iterable): Predicted scores or probabilities for each class.
            These can be probability estimates, confidence scores, or non-thresholded
            decision values. The shape should be (`num_samples`, `num_classes`).
        labels (list[str] | None): Optional list of class names to replace
            numeric values in `y_true` for easier plot interpretation.
            For example, `labels = ['dog', 'cat', 'owl']` will replace 0 with
            'dog', 1 with 'cat', and 2 with 'owl' in the plot. If not provided,
            numeric values from `y_true` will be used.
        classes_to_plot (list | None): Optional list of unique class values from
            y_true to be included in the plot. If not specified, all unique
            classes in y_true will be plotted.
        interp_size (int): Number of points to interpolate recall values. The
            recall values will be fixed to `interp_size` uniformly distributed
            points in the range [0, 1], and the precision will be interpolated
            accordingly.
        title (str): Title of the plot. Defaults to "Precision-Recall Curve".
        split_table (bool): Whether the table should be split into a separate section
            in the W&B UI. If `True`, the table will be displayed in a section named
            "Custom Chart Tables". Default is `False`.

    Returns:
        CustomChart: A custom chart object that can be logged to W&B. To log the
            chart, pass it to `wandb.log()`. Returns `None` if the input data
            fails the missing-value or type validation checks.

    Raises:
        wandb.Error: If numpy, pandas, or scikit-learn is not installed.


    Example:
        ```
        import wandb

        # Example for spam detection (binary classification)
        y_true = [0, 1, 1, 0, 1]  # 0 = not spam, 1 = spam
        y_probas = [
            [0.9, 0.1],  # Predicted probabilities for the first sample (not spam)
            [0.2, 0.8],  # Second sample (spam), and so on
            [0.1, 0.9],
            [0.8, 0.2],
            [0.3, 0.7],
        ]

        labels = ["not spam", "spam"]  # Optional class names for readability

        with wandb.init(project="spam-detection") as run:
            pr_curve = wandb.plot.pr_curve(
                y_true=y_true,
                y_probas=y_probas,
                labels=labels,
                title="Precision-Recall Curve for Spam Detection",
            )
            run.log({"pr-curve": pr_curve})
        ```
    """
    np = util.get_module(
        "numpy",
        required="pr_curve requires the numpy library, install with `pip install numpy`",
    )
    pd = util.get_module(
        "pandas",
        required="pr_curve requires the pandas library, install with `pip install pandas`",
    )
    sklearn_metrics = util.get_module(
        "sklearn.metrics",
        "pr_curve requires the scikit-learn library, install with `pip install scikit-learn`",
    )
    sklearn_utils = util.get_module(
        "sklearn.utils",
        "pr_curve requires the scikit-learn library, install with `pip install scikit-learn`",
    )

    def _step(x):
        # Make the precision sequence monotonically non-decreasing by taking
        # a running maximum (precision envelope of the raw curve).
        y = np.array(x)
        for i in range(1, len(y)):
            y[i] = max(y[i], y[i - 1])
        return y

    y_true = np.array(y_true)
    y_probas = np.array(y_probas)

    # Bail out (returning None) if the inputs contain missing values or have
    # incompatible types; the helpers emit their own warnings.
    if not test_missing(y_true=y_true, y_probas=y_probas):
        return
    if not test_types(y_true=y_true, y_probas=y_probas):
        return

    classes = np.unique(y_true)
    if classes_to_plot is None:
        classes_to_plot = classes

    precision = {}
    # Fixed recall grid, in descending order to match the orientation of
    # sklearn's precision_recall_curve output before reversal below.
    interp_recall = np.linspace(0, 1, interp_size)[::-1]
    indices_to_plot = np.where(np.isin(classes, classes_to_plot))[0]
    for i in indices_to_plot:
        # Only integer class values can index into `labels`; check the class
        # actually being plotted (both Python int and numpy integer dtypes).
        if labels is not None and (
            isinstance(classes[i], int) or isinstance(classes[i], np.integer)
        ):
            class_label = labels[classes[i]]
        else:
            class_label = classes[i]

        cur_precision, cur_recall, _ = sklearn_metrics.precision_recall_curve(
            y_true, y_probas[:, i], pos_label=classes[i]
        )
        # smooth the precision (monotonically increasing)
        cur_precision = _step(cur_precision)

        # reverse order so that recall in ascending
        cur_precision = cur_precision[::-1]
        cur_recall = cur_recall[::-1]
        # Map each fixed recall grid point onto the nearest-from-the-right
        # position of the measured curve and take its (smoothed) precision.
        indices = np.searchsorted(cur_recall, interp_recall, side="left")
        precision[class_label] = cur_precision[indices]

    df = pd.DataFrame(
        {
            "class": np.hstack([[k] * len(v) for k, v in precision.items()]),
            "precision": np.hstack(list(precision.values())),
            "recall": np.tile(interp_recall, len(precision)),
        }
    ).round(3)

    if len(df) > wandb.Table.MAX_ROWS:
        wandb.termwarn(
            f"Table has a limit of {wandb.Table.MAX_ROWS} rows. Resampling to fit."
        )
        # different sampling could be applied, possibly to ensure endpoints are kept
        df = sklearn_utils.resample(
            df,
            replace=False,
            n_samples=wandb.Table.MAX_ROWS,
            random_state=42,
            stratify=df["class"],
        ).sort_values(["precision", "recall", "class"])

    return plot_table(
        data_table=wandb.Table(dataframe=df),
        vega_spec_name="wandb/area-under-curve/v0",
        fields={
            "x": "recall",
            "y": "precision",
            "class": "class",
        },
        string_fields={
            "title": title,
            "x-axis-title": "Recall",
            "y-axis-title": "Precision",
        },
        split_table=split_table,
    )