File size: 3,168 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from typing import Any, Optional
import numpy as np
from tqdm.auto import tqdm
from ultralytics.engine.results import Results
from ultralytics.models.yolo.classify import ClassificationPredictor
import wandb
def plot_classification_predictions(
result: Results,
model_name: str,
table: Optional[wandb.Table] = None,
original_image: Optional[np.array] = None,
):
"""Plot classification prediction results to a `wandb.Table` if the table is passed otherwise return the data."""
result = result.to("cpu")
probabilities = result.probs
probabilities_list = probabilities.data.numpy().tolist()
class_id_to_label = {int(k): str(v) for k, v in result.names.items()}
original_image = (
wandb.Image(original_image)
if original_image is not None
else wandb.Image(result.orig_img)
)
table_row = [
model_name,
original_image,
class_id_to_label[int(probabilities.top1)],
probabilities.top1conf,
[class_id_to_label[int(class_idx)] for class_idx in list(probabilities.top5)],
[probabilities_list[int(class_idx)] for class_idx in list(probabilities.top5)],
{
class_id_to_label[int(class_idx)]: probability
for class_idx, probability in enumerate(probabilities_list)
},
result.speed,
]
if table is not None:
table.add_data(*table_row)
return table
return class_id_to_label, table_row
def plot_classification_validation_results(
dataloader: Any,
model_name: str,
predictor: ClassificationPredictor,
table: wandb.Table,
max_validation_batches: int,
epoch: Optional[int] = None,
) -> wandb.Table:
"""Plot classification results to a `wandb.Table`."""
data_idx = 0
num_dataloader_batches = len(dataloader.dataset) // dataloader.batch_size
max_validation_batches = min(max_validation_batches, num_dataloader_batches)
for batch_idx, batch in enumerate(dataloader):
image_batch = [
image for image in np.transpose(batch["img"].numpy(), (0, 2, 3, 1))
]
ground_truth = batch["cls"].numpy().tolist()
progress_bar_result_iterable = tqdm(
range(max_validation_batches),
desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}",
)
for img_idx in progress_bar_result_iterable:
try:
prediction_result = predictor(image_batch[img_idx])[0]
class_id_to_label, table_row = plot_classification_predictions(
prediction_result, model_name, original_image=image_batch[img_idx]
)
table_row = [data_idx, batch_idx] + table_row[1:]
table_row.insert(3, class_id_to_label[ground_truth[img_idx]])
table_row = [epoch] + table_row if epoch is not None else table_row
table_row = [model_name] + table_row
table.add_data(*table_row)
data_idx += 1
except Exception:
pass
if batch_idx + 1 == max_validation_batches:
break
return table
|