# jamtur01's picture
# Upload folder using huggingface_hub
# 9c6594c verified
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator, Sequence
from itertools import product
from math import ceil, floor, sqrt
from typing import Any, List, Optional, Union, no_type_check
import numpy as np
import torch
from torch import Tensor
from torchmetrics.utilities.imports import _LATEX_AVAILABLE, _MATPLOTLIB_AVAILABLE, _SCIENCEPLOT_AVAILABLE
# Define the plotting type aliases and the `style_change` context manager. Real matplotlib
# objects are used when matplotlib is available; otherwise inert placeholders are defined so
# that this module (and the `@style_change(...)` decorators below) can still be imported.
if _MATPLOTLIB_AVAILABLE:
    import matplotlib
    import matplotlib.axes
    import matplotlib.pyplot as plt

    # Return type of the plot functions: (figure, single axis or array of axes).
    _PLOT_OUT_TYPE = tuple[plt.Figure, Union[matplotlib.axes.Axes, np.ndarray]]
    _AX_TYPE = matplotlib.axes.Axes
    _CMAP_TYPE = Union[matplotlib.colors.Colormap, str]

    # The plot functions are decorated with `style_change(_style)`.
    style_change = plt.style.context
else:
    # Placeholders so that annotations referencing these names still evaluate.
    _PLOT_OUT_TYPE = tuple[object, object]  # type: ignore[misc]
    _AX_TYPE = object
    _CMAP_TYPE = object  # type: ignore[misc]

    from contextlib import contextmanager

    @contextmanager
    def style_change(*args: Any, **kwargs: Any) -> Generator:
        """No-op stand-in for ``plt.style.context`` when matplotlib is not installed.

        All arguments are accepted and ignored; the wrapped code runs unchanged.
        """
        yield
# Select the matplotlib style list that the plot functions pass to `style_change`.
if _SCIENCEPLOT_AVAILABLE:
    # Imported only for its side effects (hence the unused-import suppression).
    # Presumably this registers the "science" styles with matplotlib — TODO confirm.
    import scienceplots  # noqa: F401

    # NOTE(review): this value is unconditionally overwritten by the statement below, so
    # ["science", "no-latex"] is never the effective style. Confirm whether it was meant
    # to be the fallback when scienceplots is installed but LaTeX is not.
    _style = ["science", "no-latex"]

# Effective value: "science" only when both scienceplots and LaTeX are available,
# otherwise matplotlib's "default" style.
_style = ["science"] if _SCIENCEPLOT_AVAILABLE and _LATEX_AVAILABLE else ["default"]
def _error_on_missing_matplotlib() -> None:
    """Fail fast when a plot function is called without `matplotlib` installed.

    Raises:
        ModuleNotFoundError:
            If the module-level ``_MATPLOTLIB_AVAILABLE`` flag is false

    """
    if _MATPLOTLIB_AVAILABLE:
        return
    raise ModuleNotFoundError(
        "Plot function expects `matplotlib` to be installed. Please install with `pip install matplotlib`"
    )
@style_change(_style)
def plot_single_or_multi_val(
    val: Union[Tensor, Sequence[Tensor], dict[str, Tensor], Sequence[dict[str, Tensor]]],
    ax: Optional[_AX_TYPE] = None,  # type: ignore[valid-type]
    higher_is_better: Optional[bool] = None,
    lower_bound: Optional[float] = None,
    upper_bound: Optional[float] = None,
    legend_name: Optional[str] = None,
    name: Optional[str] = None,
) -> _PLOT_OUT_TYPE:
    """Plot one metric value, several values, or a time series of values, together with the metric bounds.

    Args:
        val: What to plot. A tensor is drawn as isolated markers (one per element when it holds more than one
            value). A dict is drawn with one entry per key: single-element tensors as a marker, longer tensors as
            a line over steps. A sequence of tensors or of dicts is stacked and drawn as line(s) over steps.
        ax: Axis to draw on. If ``None`` a new figure and axis are created.
        higher_is_better: If not ``None``, an "Optimal value" text is placed at ``upper_bound`` (when true) or at
            ``lower_bound`` (when false), provided that bound is given.
        lower_bound: Lowest value the metric can take; drawn as a dashed horizontal line.
        upper_bound: Highest value the metric can take; drawn as a dashed horizontal line.
        legend_name: Legend prefix (e.g. Class or Label) used when several unnamed values are plotted.
        name: Name of the metric, used as the y-axis label.

    Returns:
        A tuple ``(fig, ax)``. ``fig`` is ``None`` when an existing ``ax`` was passed in.

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed
        ValueError:
            If ``val`` is neither a tensor, a dict nor a sequence

    """
    _error_on_missing_matplotlib()
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = None
    # The x-axis carries no meaning for isolated points; it is re-enabled below for step series.
    ax.get_xaxis().set_visible(False)

    marker_style = {"marker": "o", "markersize": 10}

    if isinstance(val, Tensor):
        if val.numel() == 1:
            ax.plot([val.detach().cpu()], **marker_style)
        else:
            for idx, entry in enumerate(val):
                entry_label = f"{legend_name} {idx}" if legend_name else f"{idx}"
                ax.plot(idx, entry.detach().cpu(), linestyle="None", label=entry_label, **marker_style)
    elif isinstance(val, dict):
        for idx, (key, entry) in enumerate(val.items()):
            if entry.numel() == 1:
                ax.plot(idx, entry.detach().cpu(), label=key, **marker_style)
                continue
            # Multi-element entries are treated as a series over steps.
            ax.plot(entry.detach().cpu(), linestyle="-", label=key, **marker_style)
            ax.get_xaxis().set_visible(True)
            ax.set_xlabel("Step")
            ax.set_xticks(torch.arange(len(entry)))
    elif isinstance(val, Sequence):
        n_steps = len(val)
        if isinstance(val[0], dict):
            # Turn a list of per-step dicts into one stacked series per key.
            per_key = {key: torch.stack([step[key] for step in val]) for key in val[0]}  # type: ignore
            for key, series in per_key.items():
                ax.plot(series.detach().cpu(), linestyle="-", label=key, **marker_style)
        else:
            stacked = torch.stack(val, 0)  # type: ignore
            multi_series = stacked.ndim != 1
            # Iterate over series (not steps): transpose when each step holds several values.
            all_series = stacked.T if multi_series else stacked.unsqueeze(0)
            for idx, series in enumerate(all_series):
                if not multi_series:
                    series_label = ""
                elif legend_name:
                    series_label = f"{legend_name} {idx}"
                else:
                    series_label = f"{idx}"
                ax.plot(series.detach().cpu(), linestyle="-", label=series_label, **marker_style)
        ax.get_xaxis().set_visible(True)
        ax.set_xlabel("Step")
        ax.set_xticks(torch.arange(n_steps))
    else:
        raise ValueError("Got unknown format for argument `val`.")

    handles, legend_labels = ax.get_legend_handles_labels()
    if handles and legend_labels:
        ax.legend(
            handles, legend_labels, loc="upper center", bbox_to_anchor=(0.5, 1.15), ncol=3, fancybox=True, shadow=True
        )

    # Pad the y-range by 10%: of the bound interval when both bounds are known, else of the current data range.
    y_low, y_high = ax.get_ylim()
    if lower_bound is None or upper_bound is None:
        y_pad = 0.1 * (y_high - y_low)
    else:
        y_pad = 0.1 * (upper_bound - lower_bound)
    ax.set_ylim(
        bottom=(y_low if lower_bound is None else lower_bound) - y_pad,
        top=(y_high if upper_bound is None else upper_bound) + y_pad,
    )

    ax.grid(True)
    ax.set_ylabel(name)

    x_low, x_high = ax.get_xlim()
    x_pad = 0.1 * (x_high - x_low)

    bound_lines = [bound for bound in (lower_bound, upper_bound) if bound is not None]
    ax.hlines(bound_lines, x_low, x_high, linestyles="dashed", colors="k")

    if higher_is_better is not None:
        # The optimum sits on the upper bound when higher is better, otherwise on the lower bound.
        optimum = upper_bound if higher_is_better else lower_bound
        if optimum is not None:
            # Widen the x-range to the left so the text does not cover the data.
            ax.set_xlim(x_low - x_pad, x_high)
            ax.text(
                x_low, optimum, s="Optimal \n value", horizontalalignment="center", verticalalignment="center"
            )
    return fig, ax
def _get_col_row_split(n: int) -> tuple[int, int]:
"""Split `n` figures into `rows` x `cols` figures."""
nsq = sqrt(n)
if int(nsq) == nsq: # square number
return int(nsq), int(nsq)
if floor(nsq) * ceil(nsq) >= n:
return floor(nsq), ceil(nsq)
return ceil(nsq), ceil(nsq)
def _get_text_color(patch_color: tuple[float, float, float, float]) -> str:
"""Get the text color for a given value and colormap.
Following Wikipedia's recommendations: https://en.wikipedia.org/wiki/Relative_luminance.
Args:
patch_color: RGBA color tuple
"""
# Convert to linear color space
r, g, b, a = patch_color
r, g, b = (c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4 for c in (r, g, b))
# Get the relative luminance
y = 0.2126 * r + 0.7152 * g + 0.0722 * b
return ".1" if y > 0.4 else "white"
def trim_axs(axs: Union[_AX_TYPE, np.ndarray], nb: int) -> Union[np.ndarray, _AX_TYPE]:  # type: ignore[valid-type]
    """Keep only the first `nb` Axes of `axs`.

    Args:
        axs: a single Axes object or an array of Axes
        nb: number of Axes to keep

    Returns:
        `axs` itself when it is a single Axes object (nothing to trim); otherwise the first `nb` entries of the
        flattened array, after every later Axes has been removed from its figure.

    """
    if not isinstance(axs, _AX_TYPE):
        flat_axes = axs.flat  # type: ignore[union-attr]
        for surplus in flat_axes[nb:]:
            surplus.remove()
        return flat_axes[:nb]
    return axs
@style_change(_style)
@no_type_check
def plot_confusion_matrix(
    confmat: Tensor,
    ax: Optional[_AX_TYPE] = None,
    add_text: bool = True,
    labels: Optional[list[Union[int, str]]] = None,
    cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
    """Draw a confusion matrix as a heat map, or as a grid of heat maps for multilabel input.

    Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/confusion_matrix.py.
    Works for binary, multiclass and multilabel confusion matrices. A 3-dimensional `confmat` is treated as
    multilabel: one 2x2 panel is drawn per entry along the first dimension.

    Args:
        confmat: the confusion matrix. Either an [N,N] matrix in the binary and multiclass cases or an
            [N, 2, 2] matrix for multilabel classification
        ax: Axis from a figure. If not provided, a new figure and axis (or grid of axes) will be created
        add_text: if the value of each cell should be written into the cell
        labels: in the binary/multiclass case the tick labels of the x- and y-axis (length must equal the number of
            classes); in the multilabel case the names used in the panel titles
        cmap: matplotlib colormap to use for the confusion matrix
            https://matplotlib.org/stable/users/explain/colors/colormaps.html

    Returns:
        A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed
        ValueError:
            If `labels` is given for a non-multilabel matrix and its length differs from the number of classes

    """
    _error_on_missing_matplotlib()

    multilabel = confmat.ndim == 3
    if multilabel:
        nb, n_classes = confmat.shape[0], 2
        rows, cols = _get_col_row_split(nb)
    else:
        nb, n_classes, rows, cols = 1, confmat.shape[0], 1, 1

    if not multilabel and labels is not None and len(labels) != n_classes:
        raise ValueError(
            "Expected number of elements in arg `labels` to match number of labels in confmat but "
            f"got {len(labels)} and {n_classes}"
        )

    if multilabel:
        # User labels name the panels; the ticks of every 2x2 panel are always "0" and "1".
        fig_label = labels or np.arange(nb)
        labels = [str(cls) for cls in range(n_classes)]
    else:
        fig_label = None
        labels = labels or np.arange(n_classes).tolist()

    if ax is None:
        fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True)
    else:
        fig, axs = ax.get_figure(), ax
    axs = trim_axs(axs, nb)

    single_panel = rows == 1 and cols == 1
    tick_positions = list(range(n_classes))
    for i in range(nb):
        ax = axs if single_panel else axs[i]
        if fig_label is not None:
            ax.set_title(f"Label {fig_label[i]}", fontsize=15)

        matrix = confmat[i] if multilabel else confmat
        im = ax.imshow(matrix.cpu().detach(), cmap=cmap)

        if i // cols == rows - 1:  # bottom row only
            ax.set_xlabel("Predicted class", fontsize=15)
        if i % cols == 0:  # leftmost column only
            ax.set_ylabel("True class", fontsize=15)

        ax.set_xticks(tick_positions)
        ax.set_yticks(tick_positions)
        ax.set_xticklabels(labels, rotation=45, fontsize=10)
        ax.set_yticklabels(labels, rotation=25, fontsize=10)

        if not add_text:
            continue
        for ii, jj in product(range(n_classes), repeat=2):
            cell = matrix[ii, jj].item()
            # Choose the text color from the color actually rendered for this cell.
            text_color = _get_text_color(im.cmap(im.norm(cell)))
            ax.text(jj, ii, str(round(cell, 2)), ha="center", va="center", fontsize=15, color=text_color)
    return fig, axs
@style_change(_style)
def plot_curve(
    curve: Union[tuple[Tensor, Tensor, Tensor], tuple[List[Tensor], List[Tensor], List[Tensor]]],
    score: Optional[Tensor] = None,
    ax: Optional[_AX_TYPE] = None,  # type: ignore[valid-type]
    label_names: Optional[tuple[str, str]] = None,
    legend_name: Optional[str] = None,
    name: Optional[str] = None,
    labels: Optional[list[Union[int, str]]] = None,
) -> _PLOT_OUT_TYPE:
    """Plot a curve object, either a single curve or one curve per class.

    Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/roc_curve.py.

    Args:
        curve: a tuple of (x, y, t) where x and y are the coordinates of the curve and t are the thresholds used
            to compute the curve. Only x and y are used for plotting. 1d tensors describe a single curve; lists of
            tensors or 2d tensors describe one curve per entry.
        score: optional area under the curve added as label to the plot. Indexed per curve when several curves are
            plotted.
        ax: Axis from a figure. If not provided, a new figure and axis will be created
        label_names: Tuple containing the names of the x and y axis
        legend_name: Prefix of the legend entries when several curves are plotted; takes precedence over `labels`
        name: Custom name to describe the metric, used as the plot title
        labels: Optional labels for the different curves that will be added to the plot

    Returns:
        A tuple ``(fig, ax)``. ``fig`` is ``None`` when an existing ``ax`` was passed in.

    Raises:
        ModuleNotFoundError:
            If `matplotlib` is not installed
        ValueError:
            If `curve` has fewer than 2 elements
        ValueError:
            If `labels` is given and its length differs from the number of curves
        ValueError:
            If x and y are neither both lists nor both 1d or both 2d tensors

    """
    n_elements = len(curve)
    if n_elements < 2:
        raise ValueError(f"Expected 2 or 3 elements in curve but got {n_elements}")
    x, y = curve[:2]

    _error_on_missing_matplotlib()
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = None

    both_tensors = isinstance(x, Tensor) and isinstance(y, Tensor)
    both_lists = isinstance(x, list) and isinstance(y, list)

    if both_tensors and x.ndim == 1 and y.ndim == 1:
        # Single curve: a legend is only shown when there is a score to report.
        label = None if score is None else f"AUC={score.item():0.3f}"
        ax.plot(x.detach().cpu(), y.detach().cpu(), linestyle="-", linewidth=2, label=label)
        if label is not None:
            ax.legend()
    elif both_lists or (both_tensors and x.ndim == 2 and y.ndim == 2):
        n_classes = len(x)
        if labels is not None and len(labels) != n_classes:
            raise ValueError(
                "Expected number of elements in arg `labels` to match number of labels in roc curves but "
                f"got {len(labels)} and {n_classes}"
            )

        for i, (x_, y_) in enumerate(zip(x, y)):
            if legend_name is not None:
                label = f"{legend_name}_{i}"
            elif labels is None:
                label = str(i)
            else:
                label = str(labels[i])
            if score is not None:
                label += f" AUC={score[i].item():0.3f}"
            ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label)
        ax.legend()
    else:
        raise ValueError(
            f"Unknown format for argument `x` and `y`. Expected either list or tensors but got {type(x)} and {type(y)}."
        )

    if label_names is not None:
        ax.set_xlabel(label_names[0])
        ax.set_ylabel(label_names[1])
    ax.grid(True)
    ax.set_title(name)
    return fig, ax