File size: 14,461 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator, Sequence
from itertools import product
from math import ceil, floor, sqrt
from typing import Any, List, Optional, Union, no_type_check

import numpy as np
import torch
from torch import Tensor

from torchmetrics.utilities.imports import _LATEX_AVAILABLE, _MATPLOTLIB_AVAILABLE, _SCIENCEPLOT_AVAILABLE

# Matplotlib is an optional dependency. The plotting type aliases and the `style_change` helper are bound to the real
# matplotlib objects when it is available and to inert stand-ins otherwise.
if _MATPLOTLIB_AVAILABLE:
    import matplotlib
    import matplotlib.axes
    import matplotlib.pyplot as plt

    # Return annotation of the plot functions: a figure plus either a single axis or an array of axes.
    _PLOT_OUT_TYPE = tuple[plt.Figure, Union[matplotlib.axes.Axes, np.ndarray]]
    _AX_TYPE = matplotlib.axes.Axes
    # NOTE(review): `matplotlib.colors` is not imported explicitly; presumably it is loaded as a side effect of
    # importing `matplotlib.pyplot` - confirm.
    _CMAP_TYPE = Union[matplotlib.colors.Colormap, str]

    # Applied as `@style_change(_style)` on the plot functions of this module.
    style_change = plt.style.context
else:
    # Stand-ins so the annotations in this module still resolve without matplotlib.
    _PLOT_OUT_TYPE = tuple[object, object]  # type: ignore[misc]
    _AX_TYPE = object
    _CMAP_TYPE = object  # type: ignore[misc]

    from contextlib import contextmanager

    @contextmanager
    def style_change(*args: Any, **kwargs: Any) -> Generator:
        """No-op stand-in for ``plt.style.context`` used when matplotlib is not installed.

        All arguments are accepted and ignored; the context manager yields once without doing anything.

        """
        yield


if _SCIENCEPLOT_AVAILABLE:
    # Imported only for its side effect; presumably this registers the "science" style with matplotlib - confirm.
    import scienceplots  # noqa: F401

# Style handed to the `@style_change(_style)` decorators in this module: use the "science" style only when both
# scienceplots and LaTeX are available, otherwise fall back to matplotlib's "default" style.
_style = ["science"] if _SCIENCEPLOT_AVAILABLE and _LATEX_AVAILABLE else ["default"]


def _error_on_missing_matplotlib() -> None:
    """Guard for the plot functions: fail with an install hint when matplotlib is unavailable.

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed

    """
    if _MATPLOTLIB_AVAILABLE:
        return
    raise ModuleNotFoundError(
        "Plot function expects `matplotlib` to be installed. Please install with `pip install matplotlib`"
    )


@style_change(_style)
def plot_single_or_multi_val(
    val: Union[Tensor, Sequence[Tensor], dict[str, Tensor], Sequence[dict[str, Tensor]]],
    ax: Optional[_AX_TYPE] = None,  # type: ignore[valid-type]
    higher_is_better: Optional[bool] = None,
    lower_bound: Optional[float] = None,
    upper_bound: Optional[float] = None,
    legend_name: Optional[str] = None,
    name: Optional[str] = None,
) -> _PLOT_OUT_TYPE:
    """Plot one or several metric values, together with the bounds of the metric when they are given.

    Args:
        val: The value(s) to plot. Either a tensor holding one or several values (multiclass/label/output format),
            a dict mapping names to tensors, or a sequence of tensors or of such dicts. A sequence is interpreted
            as a time series of evolving values with one entry per step.
        ax: Axis from a figure. If not provided, a new figure and axis are created.
        higher_is_better: If set, an "Optimal value" text is placed next to the upper bound (when true) or next to
            the lower bound (when false), provided that bound is given.
        lower_bound: lower value that the metric can take
        upper_bound: upper value that the metric can take
        legend_name: for class based metrics specify the legend prefix e.g. Class or Label to use when multiple values
            are provided
        name: Name of the metric to use for the y-axis label

    Returns:
        A tuple consisting of the figure and the ax object that was drawn on. The figure is ``None`` when ``ax``
        was provided by the caller.

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed
        ValueError:
            If `val` is neither a tensor, a dict nor a sequence

    """
    _error_on_missing_matplotlib()
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = None
    # the x-axis only carries meaning for time series; those branches switch it back on
    ax.get_xaxis().set_visible(False)

    marker = {"marker": "o", "markersize": 10}

    if isinstance(val, Tensor):
        if val.numel() == 1:
            ax.plot([val.detach().cpu()], **marker)
        else:
            prefix = f"{legend_name} " if legend_name else ""
            for idx, entry in enumerate(val):
                ax.plot(idx, entry.detach().cpu(), linestyle="None", label=f"{prefix}{idx}", **marker)
    elif isinstance(val, dict):
        for idx, (key, entry) in enumerate(val.items()):
            if entry.numel() == 1:
                ax.plot(idx, entry.detach().cpu(), label=key, **marker)
                continue
            # entries with several elements are drawn as a line over steps
            ax.plot(entry.detach().cpu(), linestyle="-", label=key, **marker)
            ax.get_xaxis().set_visible(True)
            ax.set_xlabel("Step")
            ax.set_xticks(torch.arange(len(entry)))
    elif isinstance(val, Sequence):
        n_steps = len(val)
        if isinstance(val[0], dict):
            # regroup the per-step dicts into one stacked tensor per key
            per_key = {key: torch.stack([step[key] for step in val]) for key in val[0]}  # type: ignore
            for key, series in per_key.items():
                ax.plot(series.detach().cpu(), linestyle="-", label=key, **marker)
        else:
            stacked = torch.stack(val, 0)  # type: ignore
            multi_series = stacked.ndim != 1
            # iterate over one row per series: transpose when each step holds several values
            all_series = stacked.T if multi_series else stacked.unsqueeze(0)
            for idx, series in enumerate(all_series):
                if multi_series:
                    label = f"{legend_name} {idx}" if legend_name else f"{idx}"
                else:
                    label = ""
                ax.plot(series.detach().cpu(), linestyle="-", label=label, **marker)
        ax.get_xaxis().set_visible(True)
        ax.set_xlabel("Step")
        ax.set_xticks(torch.arange(n_steps))
    else:
        raise ValueError("Got unknown format for argument `val`.")

    handles, labels = ax.get_legend_handles_labels()
    if handles and labels:
        ax.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.15), ncol=3, fancybox=True, shadow=True)

    # pad the y-range by 10% of the metric range when both bounds are known, else 10% of the current view
    y_low, y_high = ax.get_ylim()
    if lower_bound is not None and upper_bound is not None:
        y_pad = 0.1 * (upper_bound - lower_bound)
    else:
        y_pad = 0.1 * (y_high - y_low)
    ax.set_ylim(
        bottom=(y_low if lower_bound is None else lower_bound) - y_pad,
        top=(y_high if upper_bound is None else upper_bound) + y_pad,
    )

    ax.grid(True)
    ax.set_ylabel(name)

    x_low, x_high = ax.get_xlim()
    x_pad = 0.1 * (x_high - x_low)

    # dashed horizontal lines mark whichever bounds are given
    bound_lines = [bound for bound in (lower_bound, upper_bound) if bound is not None]
    ax.hlines(bound_lines, x_low, x_high, linestyles="dashed", colors="k")

    if higher_is_better is not None:
        optimum = upper_bound if higher_is_better else lower_bound
        if optimum is not None:
            # widen the view to the left so the text fits next to the bound line
            ax.set_xlim(x_low - x_pad, x_high)
            ax.text(x_low, optimum, s="Optimal \n value", horizontalalignment="center", verticalalignment="center")
    return fig, ax


def _get_col_row_split(n: int) -> tuple[int, int]:
    """Return a near-square grid shape with at least `n` cells.

    The result is ``(r, r)`` for a perfect square, otherwise ``(r, r + 1)`` if that grid already offers `n` cells and
    ``(r + 1, r + 1)`` if it does not, where ``r`` is the square root of `n` rounded down.

    """
    root = sqrt(n)
    lower, upper = floor(root), ceil(root)
    if lower == upper:  # square number
        return lower, lower
    if lower * upper >= n:
        return lower, upper
    return upper, upper


def _get_text_color(patch_color: tuple[float, float, float, float]) -> str:
    """Pick a text color that stays readable on top of a patch of the given color.

    Following Wikipedia's recommendations: https://en.wikipedia.org/wiki/Relative_luminance.

    Args:
        patch_color: RGBA color tuple; the alpha channel is ignored

    Returns:
        ``".1"`` when the relative luminance of the patch is above 0.4, otherwise ``"white"``

    """

    def _to_linear(channel: float) -> float:
        # Convert one color channel to linear color space
        if channel <= 0.04045:
            return channel / 12.92
        return ((channel + 0.055) / 1.055) ** 2.4

    red, green, blue, _alpha = patch_color

    # Get the relative luminance
    luminance = 0.2126 * _to_linear(red) + 0.7152 * _to_linear(green) + 0.0722 * _to_linear(blue)

    if luminance > 0.4:
        return ".1"
    return "white"


def trim_axs(axs: Union[_AX_TYPE, np.ndarray], nb: int) -> Union[np.ndarray, _AX_TYPE]:  # type: ignore[valid-type]
    """Reduce `axs` to `nb` Axes.

    A single axis is returned untouched. An array of axes is flattened, every axis from position `nb` onwards is
    removed from the figure and the first `nb` axes are returned.

    """
    if not isinstance(axs, _AX_TYPE):
        flat_axes = axs.flat  # type: ignore[union-attr]
        for surplus in flat_axes[nb:]:
            surplus.remove()
        return flat_axes[:nb]
    return axs


@style_change(_style)
@no_type_check
def plot_confusion_matrix(
    confmat: Tensor,
    ax: Optional[_AX_TYPE] = None,
    add_text: bool = True,
    labels: Optional[list[Union[int, str]]] = None,
    cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
    """Plot a confusion matrix, or one small confusion matrix per label in the multilabel case.

    Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/confusion_matrix.py.
    Works for both binary, multiclass and multilabel confusion matrices.

    Args:
        confmat: the confusion matrix. Either should be an [N,N] matrix in the binary and multiclass cases or an
            [N, 2, 2] matrix for multilabel classification
        ax: Axis from a figure. If not provided, a new figure and axis will be created.
            NOTE(review): a single provided axis is indexed per label when a multilabel matrix needs more than one
            subplot - confirm the intended usage for that combination.
        add_text: if text should be added to each cell with the given value
        labels: labels to add the x- and y-axis. For a multilabel matrix they are used in the subplot titles instead
            and the tick labels are always ``"0"`` and ``"1"``.
        cmap: matplotlib colormap to use for the confusion matrix
            https://matplotlib.org/stable/users/explain/colors/colormaps.html

    Returns:
        A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed
        ValueError:
            If `labels` is given for a non-multilabel matrix and its length differs from the number of classes

    """
    _error_on_missing_matplotlib()

    is_multilabel = confmat.ndim == 3
    if is_multilabel:
        # one 2x2 matrix per label, each drawn in its own subplot
        nb, n_classes = confmat.shape[0], 2
        rows, cols = _get_col_row_split(nb)
        fig_label = labels or np.arange(nb)
        labels = [str(cls) for cls in range(n_classes)]
    else:
        nb = rows = cols = 1
        n_classes = confmat.shape[0]
        if labels is not None and len(labels) != n_classes:
            raise ValueError(
                "Expected number of elements in arg `labels` to match number of labels in confmat but "
                f"got {len(labels)} and {n_classes}"
            )
        fig_label = None
        labels = labels or np.arange(n_classes).tolist()

    if ax is None:
        fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True)
    else:
        fig, axs = ax.get_figure(), ax
    axs = trim_axs(axs, nb)

    single_panel = rows == 1 and cols == 1
    tick_positions = list(range(n_classes))
    for i in range(nb):
        panel_ax = axs if single_panel else axs[i]
        if fig_label is not None:
            panel_ax.set_title(f"Label {fig_label[i]}", fontsize=15)
        panel = confmat[i] if is_multilabel else confmat
        im = panel_ax.imshow(panel.cpu().detach(), cmap=cmap)
        if i // cols == rows - 1:  # bottom row only
            panel_ax.set_xlabel("Predicted class", fontsize=15)
        if i % cols == 0:  # leftmost column only
            panel_ax.set_ylabel("True class", fontsize=15)
        panel_ax.set_xticks(tick_positions)
        panel_ax.set_yticks(tick_positions)
        panel_ax.set_xticklabels(labels, rotation=45, fontsize=10)
        panel_ax.set_yticklabels(labels, rotation=25, fontsize=10)

        if not add_text:
            continue
        for row, col in product(range(n_classes), repeat=2):
            cell = panel[row, col].item()
            # choose the text color from the color the cell is actually drawn in
            text_color = _get_text_color(im.cmap(im.norm(cell)))
            panel_ax.text(col, row, str(round(cell, 2)), ha="center", va="center", fontsize=15, color=text_color)

    return fig, axs


@style_change(_style)
def plot_curve(
    curve: Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]],
    score: Optional[Tensor] = None,
    ax: Optional[_AX_TYPE] = None,  # type: ignore[valid-type]
    label_names: Optional[tuple[str, str]] = None,
    legend_name: Optional[str] = None,
    name: Optional[str] = None,
    labels: Optional[list[Union[int, str]]] = None,
) -> _PLOT_OUT_TYPE:
    """Plot a curve object, either a single curve or one curve per class.

    Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/roc_curve.py.

    Args:
        curve: a tuple of (x, y, t) where x and y are the coordinates of the curve and t are the thresholds used
            to compute the curve. Only x and y are used for plotting. Two 1D tensors give a single curve; two lists
            of tensors or two 2D tensors give one curve per entry.
        score: optional area under the curve added as label to the plot
        ax: Axis from a figure. If not provided, a new figure and axis are created.
        label_names: Tuple containing the names of the x and y axis
        legend_name: Name of the curve to be used in the legend. When several curves are plotted it is used as a
            prefix and takes precedence over ``labels``.
        name: Custom name to describe the metric, used as the title of the plot
        labels: Optional labels for the different curves that will be added to the plot

    Returns:
        A tuple consisting of the figure and the ax object that was drawn on. The figure is ``None`` when ``ax``
        was provided by the caller.

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed
        ValueError:
            If `curve` has fewer than 2 elements, if `x` and `y` are in an unsupported format or if the number of
            `labels` does not match the number of curves
    """
    n_elements = len(curve)
    if n_elements < 2:
        raise ValueError(f"Expected 2 or 3 elements in curve but got {n_elements}")
    x, y = curve[:2]

    _error_on_missing_matplotlib()
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = None

    both_tensors = isinstance(x, Tensor) and isinstance(y, Tensor)
    both_lists = isinstance(x, list) and isinstance(y, list)

    if both_tensors and x.ndim == 1 and y.ndim == 1:
        # single curve
        label = None if score is None else f"AUC={score.item():0.3f}"
        ax.plot(x.detach().cpu(), y.detach().cpu(), linestyle="-", linewidth=2, label=label)
        if label is not None:
            ax.legend()
    elif both_lists or (both_tensors and x.ndim == 2 and y.ndim == 2):
        # one curve per entry of x and y
        n_classes = len(x)
        if labels is not None and len(labels) != n_classes:
            raise ValueError(
                "Expected number of elements in arg `labels` to match number of labels in roc curves but "
                f"got {len(labels)} and {n_classes}"
            )

        for i, (x_, y_) in enumerate(zip(x, y)):
            if legend_name is not None:
                label = f"{legend_name}_{i}"
            elif labels is None:
                label = str(i)
            else:
                label = str(labels[i])
            if score is not None:
                label += f" AUC={score[i].item():0.3f}"
            ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label)
            ax.legend()
    else:
        raise ValueError(
            f"Unknown format for argument `x` and `y`. Expected either list or tensors but got {type(x)} and {type(y)}."
        )

    if label_names is not None:
        x_name, y_name = label_names[0], label_names[1]
        ax.set_xlabel(x_name)
        ax.set_ylabel(y_name)
    ax.grid(True)
    ax.set_title(name)

    return fig, ax