# MMaDA/venv/lib/python3.11/site-packages/wandb/integration/ultralytics/classification_utils.py
from typing import Any, Optional | |
import numpy as np | |
from tqdm.auto import tqdm | |
from ultralytics.engine.results import Results | |
from ultralytics.models.yolo.classify import ClassificationPredictor | |
import wandb | |
def plot_classification_predictions(
    result: Results,
    model_name: str,
    table: Optional[wandb.Table] = None,
    original_image: Optional[np.ndarray] = None,
):
    """Plot classification prediction results to a `wandb.Table` if the table is passed otherwise return the data.

    Args:
        result: A single Ultralytics classification ``Results`` object. It is
            moved to CPU before its probabilities are read.
        model_name: Name written into the first column of the row.
        table: Optional table. When given, the row is appended to it and the
            table is returned.
        original_image: Image to log instead of ``result.orig_img``.

    Returns:
        ``table`` if one was passed. Otherwise a tuple
        ``(class_id_to_label, table_row)``, where ``class_id_to_label`` maps
        int class ids to string labels and ``table_row`` is::

            [model_name, wandb.Image, top-1 label, top-1 confidence,
             top-5 labels, top-5 confidences, {label: probability},
             result.speed]
    """
    result = result.to("cpu")
    probabilities = result.probs
    probabilities_list = probabilities.data.numpy().tolist()
    class_id_to_label = {int(k): str(v) for k, v in result.names.items()}
    image = wandb.Image(
        original_image if original_image is not None else result.orig_img
    )
    top1_idx = int(probabilities.top1)
    # Materialize the top-5 indices once; they are used for labels and scores.
    top5_indices = [int(class_idx) for class_idx in probabilities.top5]
    table_row = [
        model_name,
        image,
        class_id_to_label[top1_idx],
        # Plain Python float (same value as ``probabilities.top1conf``) so the
        # top-1 cell has the same type as the top-5 / per-class confidences.
        probabilities_list[top1_idx],
        [class_id_to_label[idx] for idx in top5_indices],
        [probabilities_list[idx] for idx in top5_indices],
        {
            class_id_to_label[class_idx]: probability
            for class_idx, probability in enumerate(probabilities_list)
        },
        result.speed,
    ]
    if table is not None:
        table.add_data(*table_row)
        return table
    return class_id_to_label, table_row
def plot_classification_validation_results(
    dataloader: Any,
    model_name: str,
    predictor: ClassificationPredictor,
    table: wandb.Table,
    max_validation_batches: int,
    epoch: Optional[int] = None,
) -> wandb.Table:
    """Plot classification results on validation batches to a `wandb.Table`.

    Runs ``predictor`` on every image of the first ``max_validation_batches``
    batches of ``dataloader`` and appends one row per image to ``table``::

        [model_name, (epoch,) data_idx, batch_idx, image, ground-truth label,
         top-1 label, top-1 confidence, top-5 labels, top-5 confidences,
         {label: probability}, speed]

    The ``epoch`` column is only present when ``epoch`` is not ``None``.
    Images whose prediction or row construction fails are skipped and a
    warning (with traceback) is logged.

    Args:
        dataloader: Iterable of batch dicts exposing ``dataset`` and
            ``batch_size``. ``batch["img"]`` is transposed with axes
            ``(0, 2, 3, 1)`` (presumably ``(B, C, H, W)`` -> ``(B, H, W, C)``;
            TODO confirm) and ``batch["cls"]`` holds one class id per image.
        model_name: Name written into the first column of every row.
        predictor: Ultralytics classification predictor, called per image.
        table: Table that rows are appended to.
        max_validation_batches: Maximum number of batches to visualize. It is
            clipped to the number of batches in the dataloader; values <= 0
            visualize nothing.
        epoch: Optional epoch number inserted after ``model_name``.

    Returns:
        The same ``table`` with the new rows appended.
    """
    import logging

    logger = logging.getLogger(__name__)
    data_idx = 0
    batch_size = dataloader.batch_size
    # Ceiling division: a trailing partial batch, or a dataset smaller than
    # one batch, still counts as a batch.
    num_dataloader_batches = (len(dataloader.dataset) + batch_size - 1) // batch_size
    max_validation_batches = min(max_validation_batches, num_dataloader_batches)
    if max_validation_batches <= 0:
        return table
    for batch_idx, batch in enumerate(dataloader):
        image_batch = list(np.transpose(batch["img"].numpy(), (0, 2, 3, 1)))
        # Flatten so both (B,) and (B, 1) label tensors give one id per image.
        ground_truth = batch["cls"].numpy().reshape(-1).tolist()
        progress_bar_result_iterable = tqdm(
            enumerate(image_batch),
            total=len(image_batch),
            desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}",
        )
        for img_idx, image in progress_bar_result_iterable:
            try:
                prediction_result = predictor(image)[0]
                class_id_to_label, prediction_row = plot_classification_predictions(
                    prediction_result, model_name, original_image=image
                )
                ground_truth_label = class_id_to_label[int(ground_truth[img_idx])]
                # prediction_row[0] is model_name, [1] the image, [2:] scores.
                table_row = [
                    data_idx,
                    batch_idx,
                    prediction_row[1],
                    ground_truth_label,
                    *prediction_row[2:],
                ]
                if epoch is not None:
                    table_row.insert(0, epoch)
                table.add_data(model_name, *table_row)
                data_idx += 1
            except Exception:
                # Best effort: a single bad image must not abort the whole
                # visualization pass, but the failure should not be silent.
                logger.warning(
                    "Skipping image %d of batch %d: classification visualization failed.",
                    img_idx,
                    batch_idx,
                    exc_info=True,
                )
        if batch_idx + 1 == max_validation_batches:
            break
    return table