# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator, Sequence
from itertools import product
from math import ceil, floor, sqrt
from typing import Any, List, Optional, Union, no_type_check

import numpy as np
import torch
from torch import Tensor

from torchmetrics.utilities.imports import _LATEX_AVAILABLE, _MATPLOTLIB_AVAILABLE, _SCIENCEPLOT_AVAILABLE

if _MATPLOTLIB_AVAILABLE:
    import matplotlib
    import matplotlib.axes
    import matplotlib.pyplot as plt

    _PLOT_OUT_TYPE = tuple[plt.Figure, Union[matplotlib.axes.Axes, np.ndarray]]
    _AX_TYPE = matplotlib.axes.Axes
    _CMAP_TYPE = Union[matplotlib.colors.Colormap, str]

    style_change = plt.style.context
else:
    # Placeholders so that the annotations and decorators below still evaluate without matplotlib.
    _PLOT_OUT_TYPE = tuple[object, object]  # type: ignore[misc]
    _AX_TYPE = object
    _CMAP_TYPE = object  # type: ignore[misc]

    from contextlib import contextmanager

    @contextmanager
    def style_change(*args: Any, **kwargs: Any) -> Generator:
        """No-op stand-in for ``plt.style.context`` used when matplotlib is not installed."""
        yield


if _SCIENCEPLOT_AVAILABLE:
    import scienceplots  # noqa: F401

# The "science" style is only selected when both scienceplots and LaTeX are available,
# otherwise matplotlib's default style is used.
# NOTE(review): scienceplots without LaTeX currently falls back to "default" rather than a
# ["science", "no-latex"] combination - confirm that this is the intended fallback.
_style = ["science"] if _SCIENCEPLOT_AVAILABLE and _LATEX_AVAILABLE else ["default"]


def _error_on_missing_matplotlib() -> None:
    """Raise error if matplotlib is not installed.

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed

    """
    if not _MATPLOTLIB_AVAILABLE:
        raise ModuleNotFoundError(
            "Plot function expects `matplotlib` to be installed. Please install with `pip install matplotlib`"
        )


@style_change(_style)
def plot_single_or_multi_val(
    val: Union[Tensor, Sequence[Tensor], dict[str, Tensor], Sequence[dict[str, Tensor]]],
    ax: Optional[_AX_TYPE] = None,  # type: ignore[valid-type]
    higher_is_better: Optional[bool] = None,
    lower_bound: Optional[float] = None,
    upper_bound: Optional[float] = None,
    legend_name: Optional[str] = None,
    name: Optional[str] = None,
) -> _PLOT_OUT_TYPE:
    """Plot a single metric value or multiple, including bounds of value if existing.

    Args:
        val: A single tensor with one or multiple values (multiclass/label/output format), a dict of such tensors,
            or a sequence of tensors or dicts. If a sequence is provided the values are interpreted as a time series
            of evolving values.
        ax: Axis from a figure. If not provided, a new figure and axis will be created.
        higher_is_better: Indicates if a label indicating where the optimal value is should be added to the figure.
            The label is only drawn when the corresponding bound is also provided.
        lower_bound: lower value that the metric can take
        upper_bound: upper value that the metric can take
        legend_name: for class based metrics specify the legend prefix e.g. Class or Label to use when multiple values
            are provided
        name: Name of the metric to use for the y-axis label

    Returns:
        A tuple consisting of the figure and respective ax objects of the generated figure. The figure is ``None``
        when ``ax`` was provided by the caller.

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed
        ValueError:
            If ``val`` is an empty sequence or is of an unsupported type

    """
    _error_on_missing_matplotlib()
    fig, ax = plt.subplots() if ax is None else (None, ax)
    # The x-axis only carries meaning for time series; the branches below re-enable it when needed.
    ax.get_xaxis().set_visible(False)

    if isinstance(val, Tensor):
        if val.numel() == 1:
            ax.plot([val.detach().cpu()], marker="o", markersize=10)
        else:
            for i, v in enumerate(val):
                label = f"{legend_name} {i}" if legend_name else f"{i}"
                ax.plot(i, v.detach().cpu(), marker="o", markersize=10, linestyle="None", label=label)
    elif isinstance(val, dict):
        for i, (k, v) in enumerate(val.items()):
            if v.numel() != 1:
                # A multi-element entry is drawn as a line over its elements, indexed as steps
                ax.plot(v.detach().cpu(), marker="o", markersize=10, linestyle="-", label=k)
                ax.get_xaxis().set_visible(True)
                ax.set_xlabel("Step")
                ax.set_xticks(torch.arange(len(v)))
            else:
                ax.plot(i, v.detach().cpu(), marker="o", markersize=10, label=k)
    elif isinstance(val, Sequence):
        n_steps = len(val)
        if n_steps == 0:
            # Fail with a clear message instead of an IndexError from `val[0]` below
            raise ValueError("Expected argument `val` to contain at least one element when a sequence is provided.")
        if isinstance(val[0], dict):
            # Collate the per-step dicts into one stacked tensor per key
            stacked = {k: torch.stack([step[k] for step in val]) for k in val[0]}  # type: ignore
            for k, v in stacked.items():
                ax.plot(v.detach().cpu(), marker="o", markersize=10, linestyle="-", label=k)
        else:
            series = torch.stack(val, 0)  # type: ignore
            multi_series = series.ndim != 1
            # Iterate per output (one line each) instead of per step
            series = series.T if multi_series else series.unsqueeze(0)
            for i, v in enumerate(series):
                label = (f"{legend_name} {i}" if legend_name else f"{i}") if multi_series else ""
                ax.plot(v.detach().cpu(), marker="o", markersize=10, linestyle="-", label=label)
        ax.get_xaxis().set_visible(True)
        ax.set_xlabel("Step")
        ax.set_xticks(torch.arange(n_steps))
    else:
        raise ValueError("Got unknown format for argument `val`.")

    handles, labels = ax.get_legend_handles_labels()
    if handles and labels:
        ax.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.15), ncol=3, fancybox=True, shadow=True)

    # Pad the y-range by 10% of the metric range when both bounds are known, else 10% of the plotted range
    ylim = ax.get_ylim()
    if lower_bound is not None and upper_bound is not None:
        factor = 0.1 * (upper_bound - lower_bound)
    else:
        factor = 0.1 * (ylim[1] - ylim[0])

    ax.set_ylim(
        bottom=lower_bound - factor if lower_bound is not None else ylim[0] - factor,
        top=upper_bound + factor if upper_bound is not None else ylim[1] + factor,
    )

    ax.grid(True)
    ax.set_ylabel(name)

    xlim = ax.get_xlim()
    factor = 0.1 * (xlim[1] - xlim[0])

    # Dashed horizontal lines mark whichever bounds were provided
    y_lines = [bound for bound in (lower_bound, upper_bound) if bound is not None]
    ax.hlines(y_lines, xlim[0], xlim[1], linestyles="dashed", colors="k")

    if higher_is_better is not None:
        # The optimum sits on the upper bound when higher is better and on the lower bound otherwise
        optimal = upper_bound if higher_is_better else lower_bound
        if optimal is not None:
            # Widen the x-range to the left to make room for the annotation
            ax.set_xlim(xlim[0] - factor, xlim[1])
            ax.text(
                xlim[0], optimal, s="Optimal \n value", horizontalalignment="center", verticalalignment="center"
            )
    return fig, ax


def _get_col_row_split(n: int) -> tuple[int, int]:
    """Split `n` figures into `rows` x `cols` figures.

    Returns a square grid when ``n`` is a square number, otherwise the ``floor(sqrt(n)) x ceil(sqrt(n))`` grid if it
    has room for ``n`` figures and the ``ceil(sqrt(n)) x ceil(sqrt(n))`` grid if not.

    """
    nsq = sqrt(n)
    if int(nsq) == nsq:  # square number
        return int(nsq), int(nsq)
    lower, upper = floor(nsq), ceil(nsq)
    if lower * upper >= n:
        return lower, upper
    return upper, upper


def _get_text_color(patch_color: tuple[float, float, float, float]) -> str:
    """Get the text color to use on top of a patch with the given color.

    Following Wikipedia's recommendations: https://en.wikipedia.org/wiki/Relative_luminance.

    Args:
        patch_color: RGBA color tuple; the alpha channel is ignored

    Returns:
        The matplotlib color ``".1"`` when the relative luminance of the patch is above 0.4, else ``"white"``

    """
    red, green, blue, _ = patch_color
    # Convert to linear color space
    red, green, blue = (
        c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4 for c in (red, green, blue)
    )
    # Get the relative luminance
    luminance = 0.2126 * red + 0.7152 * green + 0.0722 * blue
    return ".1" if luminance > 0.4 else "white"


def trim_axs(axs: Union[_AX_TYPE, np.ndarray], nb: int) -> Union[np.ndarray, _AX_TYPE]:  # type: ignore[valid-type]
    """Reduce `axs` to `nb` Axes.

    All further Axes are removed from the figure.

    Args:
        axs: a single axis or an array of axes
        nb: number of axes to keep

    Returns:
        ``axs`` unchanged if it is a single axis, otherwise the first ``nb`` axes of the flattened array

    """
    if isinstance(axs, _AX_TYPE):
        return axs

    axs = axs.flat  # type: ignore[union-attr]
    for ax in axs[nb:]:
        ax.remove()
    return axs[:nb]


@style_change(_style)
@no_type_check
def plot_confusion_matrix(
    confmat: Tensor,
    ax: Optional[_AX_TYPE] = None,
    add_text: bool = True,
    labels: Optional[list[Union[int, str]]] = None,
    cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
    """Plot a confusion matrix.

    Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/confusion_matrix.py.
    Works for both binary, multiclass and multilabel confusion matrices.

    Args:
        confmat: the confusion matrix. Either should be an [N,N] matrix in the binary and multiclass cases or an
            [N, 2, 2] matrix for multilabel classification
        ax: Axis from a figure. If not provided, a new figure and axis will be created. Must be ``None`` for a
            multilabel confusion matrix with more than one label, since one axis per label is created.
        add_text: if text should be added to each cell with the given value
        labels: labels to add the x- and y-axis. In the multilabel case they are used in the title of each subplot
            instead.
        cmap: matplotlib colormap to use for the confusion matrix
            https://matplotlib.org/stable/users/explain/colors/colormaps.html

    Returns:
        A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed
        ValueError:
            If the number of ``labels`` does not match the confusion matrix, or if ``ax`` is provided for a
            multilabel confusion matrix with more than one label

    """
    _error_on_missing_matplotlib()

    is_multilabel = confmat.ndim == 3
    if is_multilabel:
        nb, n_classes = confmat.shape[0], 2
        rows, cols = _get_col_row_split(nb)
    else:
        nb, n_classes, rows, cols = 1, confmat.shape[0], 1, 1

    if labels is not None and not is_multilabel and len(labels) != n_classes:
        raise ValueError(
            "Expected number of elements in arg `labels` to match number of labels in confmat but "
            f"got {len(labels)} and {n_classes}"
        )
    # An empty list falls back to the default numbering below, so only non-empty lists are checked here.
    # Too few labels would otherwise fail with an IndexError halfway through plotting.
    if labels and is_multilabel and len(labels) < nb:
        raise ValueError(
            "Expected arg `labels` to contain at least one element per label in confmat but "
            f"got {len(labels)} and {nb}"
        )
    # A single user-provided axis cannot hold one subplot per label
    if ax is not None and nb > 1:
        raise ValueError(
            f"Expected argument `ax` to be `None` when plotting a multilabel confusion matrix with {nb} labels, "
            "since one axis per label is created."
        )

    if is_multilabel:
        fig_label = labels or np.arange(nb)
        labels = list(map(str, range(n_classes)))
    else:
        fig_label = None
        labels = labels or np.arange(n_classes).tolist()

    fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True) if ax is None else (ax.get_figure(), ax)
    axs = trim_axs(axs, nb)
    tick_positions = list(range(n_classes))
    for i in range(nb):
        # `axs` is a single axis (not an array) when the grid is 1x1
        cur_ax = axs[i] if (rows != 1 or cols != 1) else axs
        if fig_label is not None:
            cur_ax.set_title(f"Label {fig_label[i]}", fontsize=15)
        im = cur_ax.imshow(confmat[i].cpu().detach() if is_multilabel else confmat.cpu().detach(), cmap=cmap)
        if i // cols == rows - 1:  # bottom row only
            cur_ax.set_xlabel("Predicted class", fontsize=15)
        if i % cols == 0:  # leftmost column only
            cur_ax.set_ylabel("True class", fontsize=15)
        cur_ax.set_xticks(tick_positions)
        cur_ax.set_yticks(tick_positions)
        cur_ax.set_xticklabels(labels, rotation=45, fontsize=10)
        cur_ax.set_yticklabels(labels, rotation=25, fontsize=10)

        if add_text:
            for ii, jj in product(range(n_classes), range(n_classes)):
                val = confmat[i, ii, jj] if is_multilabel else confmat[ii, jj]
                val_item = val.item()
                # Choose a text color that contrasts with the color of the cell it is written on
                patch_color = im.cmap(im.norm(val_item))
                c = _get_text_color(patch_color)
                cur_ax.text(jj, ii, str(round(val_item, 2)), ha="center", va="center", fontsize=15, color=c)

    return fig, axs


@style_change(_style)
def plot_curve(
    curve: Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]],
    score: Optional[Tensor] = None,
    ax: Optional[_AX_TYPE] = None,  # type: ignore[valid-type]
    label_names: Optional[tuple[str, str]] = None,
    legend_name: Optional[str] = None,
    name: Optional[str] = None,
    labels: Optional[list[Union[int, str]]] = None,
) -> _PLOT_OUT_TYPE:
    """Plot a curve object.

    Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/roc_curve.py.

    Args:
        curve: a tuple of (x, y, t) where x and y are the coordinates of the curve and t are the thresholds used to
            compute the curve. Only the first two elements are used for plotting.
        score: optional area under the curve added as label to the plot
        ax: Axis from a figure. If not provided, a new figure and axis will be created.
        label_names: Tuple containing the names of the x and y axis
        legend_name: Name of the curve to be used in the legend. When multiple curves are plotted it is used as a
            prefix for the index of each curve and takes precedence over ``labels``.
        name: Custom name to describe the metric, used as title of the plot
        labels: Optional labels for the different curves that will be added to the plot

    Returns:
        A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure.
        The figure is ``None`` when ``ax`` was provided by the caller.

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed
        ValueError:
            If `curve` has fewer than 2 elements, if the x and y coordinates are in an unsupported format, or if
            the number of ``labels`` does not match the number of curves

    """
    if len(curve) < 2:
        raise ValueError(f"Expected 2 or 3 elements in curve but got {len(curve)}")
    x, y = curve[:2]

    _error_on_missing_matplotlib()
    fig, ax = plt.subplots() if ax is None else (None, ax)

    if isinstance(x, Tensor) and isinstance(y, Tensor) and x.ndim == 1 and y.ndim == 1:
        # Single curve
        label = f"AUC={score.item():0.3f}" if score is not None else None
        ax.plot(x.detach().cpu(), y.detach().cpu(), linestyle="-", linewidth=2, label=label)
        if label is not None:
            ax.legend()
    elif (isinstance(x, list) and isinstance(y, list)) or (
        isinstance(x, Tensor) and isinstance(y, Tensor) and x.ndim == 2 and y.ndim == 2
    ):
        # One curve per element of the lists / per row of the 2d tensors
        n_classes = len(x)
        if labels is not None and len(labels) != n_classes:
            raise ValueError(
                "Expected number of elements in arg `labels` to match number of labels in roc curves but "
                f"got {len(labels)} and {n_classes}"
            )

        for i, (x_, y_) in enumerate(zip(x, y)):
            label = f"{legend_name}_{i}" if legend_name is not None else (str(i) if labels is None else str(labels[i]))
            label += f" AUC={score[i].item():0.3f}" if score is not None else ""
            ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label)
        ax.legend()
    else:
        raise ValueError(
            f"Unknown format for argument `x` and `y`. Expected either list or tensors but got {type(x)} and {type(y)}."
        )

    if label_names is not None:
        ax.set_xlabel(label_names[0])
        ax.set_ylabel(label_names[1])
    ax.grid(True)
    ax.set_title(name)

    return fig, ax