File size: 6,700 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
from __future__ import annotations
from typing import TYPE_CHECKING, Sequence, TypeVar
import wandb
from wandb import util
from wandb.plot.custom_chart import plot_table
if TYPE_CHECKING:
from wandb.plot.custom_chart import CustomChart
T = TypeVar("T")
def confusion_matrix(
probs: Sequence[Sequence[float]] | None = None,
y_true: Sequence[T] | None = None,
preds: Sequence[T] | None = None,
class_names: Sequence[str] | None = None,
title: str = "Confusion Matrix Curve",
split_table: bool = False,
) -> CustomChart:
"""Constructs a confusion matrix from a sequence of probabilities or predictions.
Args:
probs (Sequence[Sequence[float]] | None): A sequence of predicted probabilities for each
class. The sequence shape should be (N, K) where N is the number of samples
and K is the number of classes. If provided, `preds` should not be provided.
y_true (Sequence[T] | None): A sequence of true labels.
preds (Sequence[T] | None): A sequence of predicted class labels. If provided,
`probs` should not be provided.
class_names (Sequence[str] | None): Sequence of class names. If not
provided, class names will be defined as "Class_1", "Class_2", etc.
title (str): Title of the confusion matrix chart.
split_table (bool): Whether the table should be split into a separate section
in the W&B UI. If `True`, the table will be displayed in a section named
"Custom Chart Tables". Default is `False`.
Returns:
CustomChart: A custom chart object that can be logged to W&B. To log the
chart, pass it to `wandb.log()`.
Raises:
ValueError: If both `probs` and `preds` are provided or if the number of
predictions and true labels are not equal. If the number of unique
predicted classes exceeds the number of class names or if the number of
unique true labels exceeds the number of class names.
wandb.Error: If numpy is not installed.
Examples:
1. Logging a confusion matrix with random probabilities for wildlife
classification:
```
import numpy as np
import wandb
# Define class names for wildlife
wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]
# Generate random true labels (0 to 3 for 10 samples)
wildlife_y_true = np.random.randint(0, 4, size=10)
# Generate random probabilities for each class (10 samples x 4 classes)
wildlife_probs = np.random.rand(10, 4)
wildlife_probs = np.exp(wildlife_probs) / np.sum(
np.exp(wildlife_probs),
axis=1,
keepdims=True,
)
# Initialize W&B run and log confusion matrix
with wandb.init(project="wildlife_classification") as run:
confusion_matrix = wandb.plot.confusion_matrix(
probs=wildlife_probs,
y_true=wildlife_y_true,
class_names=wildlife_class_names,
title="Wildlife Classification Confusion Matrix",
)
run.log({"wildlife_confusion_matrix": confusion_matrix})
```
In this example, random probabilities are used to generate a confusion
matrix.
2. Logging a confusion matrix with simulated model predictions and 85%
accuracy:
```
import numpy as np
import wandb
# Define class names for wildlife
wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]
# Simulate true labels for 200 animal images (imbalanced distribution)
wildlife_y_true = np.random.choice(
[0, 1, 2, 3],
size=200,
p=[0.2, 0.3, 0.25, 0.25],
)
# Simulate model predictions with 85% accuracy
wildlife_preds = [
y_t
if np.random.rand() < 0.85
else np.random.choice([x for x in range(4) if x != y_t])
for y_t in wildlife_y_true
]
# Initialize W&B run and log confusion matrix
with wandb.init(project="wildlife_classification") as run:
confusion_matrix = wandb.plot.confusion_matrix(
preds=wildlife_preds,
y_true=wildlife_y_true,
class_names=wildlife_class_names,
title="Simulated Wildlife Classification Confusion Matrix"
)
run.log({"wildlife_confusion_matrix": confusion_matrix})
```
In this example, predictions are simulated with 85% accuracy to generate a
confusion matrix.
"""
np = util.get_module(
"numpy",
required=(
"numpy is required to use wandb.plot.confusion_matrix, "
"install with `pip install numpy`",
),
)
if probs is not None and preds is not None:
raise ValueError("Only one of `probs` or `preds` should be provided, not both.")
if probs is not None:
preds = np.argmax(probs, axis=1).tolist()
if len(preds) != len(y_true):
raise ValueError("The number of predictions and true labels must be equal.")
if class_names is not None:
n_classes = len(class_names)
class_idx = list(range(n_classes))
if len(set(preds)) > len(class_names):
raise ValueError(
"The number of unique predicted classes exceeds the number of class names."
)
if len(set(y_true)) > len(class_names):
raise ValueError(
"The number of unique true labels exceeds the number of class names."
)
else:
class_idx = set(preds).union(set(y_true))
n_classes = len(class_idx)
class_names = [f"Class_{i + 1}" for i in range(n_classes)]
# Create a mapping from class name to index
class_mapping = {val: i for i, val in enumerate(sorted(list(class_idx)))}
counts = np.zeros((n_classes, n_classes))
for i in range(len(preds)):
counts[class_mapping[y_true[i]], class_mapping[preds[i]]] += 1
data = [
[class_names[i], class_names[j], counts[i, j]]
for i in range(n_classes)
for j in range(n_classes)
]
return plot_table(
data_table=wandb.Table(
columns=["Actual", "Predicted", "nPredictions"],
data=data,
),
vega_spec_name="wandb/confusion_matrix/v1",
fields={
"Actual": "Actual",
"Predicted": "Predicted",
"nPredictions": "nPredictions",
},
string_fields={"title": title},
split_table=split_table,
)
|