|
from typing import Any, Optional |
|
|
|
import numpy as np |
|
from PIL import Image |
|
from tqdm.auto import tqdm |
|
from ultralytics.engine.results import Results |
|
from ultralytics.models.yolo.pose import PosePredictor |
|
from ultralytics.utils.plotting import Annotator |
|
|
|
import wandb |
|
from wandb.integration.ultralytics.bbox_utils import ( |
|
get_boxes, |
|
get_ground_truth_bbox_annotations, |
|
) |
|
|
|
|
|
def annotate_keypoint_results(result: Results, visualize_skeleton: bool):
    """Draw every predicted keypoint set of ``result`` onto its source image.

    Args:
        result: Prediction result whose ``orig_img`` and ``keypoints`` are
            read. ``keypoints.data`` must support ``.numpy()``.
        visualize_skeleton: Forwarded to ``Annotator.kpts`` as ``kpt_line``.

    Returns:
        The annotator's image (``annotator.im``) after all keypoint sets
        have been drawn.
    """
    # The last axis of the original image is reversed before drawing
    # (presumably BGR -> RGB; confirm against the ultralytics image layout).
    canvas = np.ascontiguousarray(result.orig_img[:, :, ::-1])
    annotator = Annotator(canvas)

    # Iterating the array walks its first axis, one entry per keypoint set.
    for instance_keypoints in result.keypoints.data.numpy():
        annotator.kpts(instance_keypoints, kpt_line=visualize_skeleton)
    return annotator.im
|
|
|
|
|
def annotate_keypoint_batch(image_path: str, keypoints: Any, visualize_skeleton: bool):
    """Draw the given keypoints onto the image stored at ``image_path``.

    Args:
        image_path: Path of the image file to open with PIL.
        keypoints: Keypoint container exposing ``.numpy()``; the converted
            array is handed to ``Annotator.kpts`` in a single call.
        visualize_skeleton: Forwarded to ``Annotator.kpts`` as ``kpt_line``.

    Returns:
        The annotator's image (``annotator.im``) with the keypoints drawn.
    """
    # Materialise the pixels as a contiguous array while the file is open;
    # the drawing itself only needs that array, not the file handle.
    with Image.open(image_path) as image_file:
        pixels = np.ascontiguousarray(image_file)

    annotator = Annotator(pixels)
    annotator.kpts(keypoints.numpy(), kpt_line=visualize_skeleton)
    return annotator.im
|
|
|
|
|
def plot_pose_predictions(
    result: Results,
    model_name: str,
    visualize_skeleton: bool,
    table: Optional[wandb.Table] = None,
):
    """Build one table row describing a pose prediction and optionally log it.

    The row holds, in order: the model name, the annotated prediction image
    (with bounding-box overlays), the number of predicted boxes, the mean
    confidence map returned by ``get_boxes`` and ``result.speed``.

    Args:
        result: Prediction result; it is moved to the CPU before use.
        model_name: Value stored in the first column of the row.
        visualize_skeleton: Whether keypoints are drawn with connecting lines.
        table: Optional table to append the row to.

    Returns:
        ``table`` (after the row was added) when a table is supplied,
        otherwise the row itself as a list.
    """
    cpu_result = result.to("cpu")
    boxes, mean_confidence_map = get_boxes(cpu_result)
    prediction_image = wandb.Image(
        annotate_keypoint_results(cpu_result, visualize_skeleton),
        boxes=boxes,
    )
    num_predicted_boxes = len(boxes["predictions"]["box_data"])
    table_row = [
        model_name,
        prediction_image,
        num_predicted_boxes,
        mean_confidence_map,
        cpu_result.speed,
    ]

    # Without a table the caller gets the raw row back to post-process.
    if table is None:
        return table_row

    table.add_data(*table_row)
    return table
|
|
|
|
|
def plot_pose_validation_results(
    dataloader,
    class_label_map,
    model_name: str,
    predictor: PosePredictor,
    visualize_skeleton: bool,
    table: wandb.Table,
    max_validation_batches: int,
    epoch: Optional[int] = None,
) -> wandb.Table:
    """Log ground-truth vs. predicted pose visualizations for validation batches.

    For each image of each visited batch one row is appended to ``table``:
    model name, ``epoch`` (only when it is not ``None``), a running image
    index, the batch index, the annotated ground-truth image, followed by
    the prediction columns produced by ``plot_pose_predictions`` (minus its
    leading model-name column).

    Args:
        dataloader: Iterable of batches exposing ``dataset`` and
            ``batch_size``; each batch is indexed with ``"im_file"`` and
            ``"keypoints"``.
        class_label_map: Class labels attached to the ground-truth boxes and
            passed to ``get_ground_truth_bbox_annotations``.
        model_name: Value stored in the first column of every row.
        predictor: Callable run on the batch's image files to obtain the
            predictions.
        visualize_skeleton: Whether keypoints are drawn with connecting lines.
        table: Table the rows are appended to.
        max_validation_batches: Upper bound on the number of batches to
            visualize; a value of zero or less visualizes nothing.
        epoch: Optional epoch number stored as the second column.

    Returns:
        The same ``table`` object, with the new rows added.
    """
    # Ceiling division: a trailing partial batch is still a batch, and a
    # dataset smaller than one batch must count as one batch rather than zero.
    dataset_size = len(dataloader.dataset)
    batch_size = dataloader.batch_size
    num_dataloader_batches = (dataset_size + batch_size - 1) // batch_size
    max_validation_batches = min(max_validation_batches, num_dataloader_batches)

    # Nothing to visualize; returning here also avoids running the predictor
    # on a first batch that was never requested.
    if max_validation_batches <= 0:
        return table

    # Leading columns shared by every row of this call.
    leading_columns = [model_name] if epoch is None else [model_name, epoch]

    data_idx = 0
    for batch_idx, batch in enumerate(dataloader):
        image_files = batch["im_file"]
        prediction_results = predictor(image_files)
        progress_bar_result_iterable = tqdm(
            enumerate(prediction_results),
            total=len(prediction_results),
            desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}",
        )
        for img_idx, prediction_result in progress_bar_result_iterable:
            prediction_result = prediction_result.to("cpu")
            prediction_row = plot_pose_predictions(
                prediction_result, model_name, visualize_skeleton
            )
            ground_truth_image = wandb.Image(
                annotate_keypoint_batch(
                    image_files[img_idx],
                    batch["keypoints"][img_idx],
                    visualize_skeleton,
                ),
                boxes={
                    "ground-truth": {
                        "box_data": get_ground_truth_bbox_annotations(
                            img_idx, image_files[img_idx], batch, class_label_map
                        ),
                        "class_labels": class_label_map,
                    },
                },
            )
            # prediction_row[0] is the model name, which already leads the
            # row through leading_columns, so it is dropped here.
            table.add_data(
                *leading_columns,
                data_idx,
                batch_idx,
                ground_truth_image,
                *prediction_row[1:],
            )
            data_idx += 1
        # ">=" rather than "==" so the loop always stops once the cap is hit.
        if batch_idx + 1 >= max_validation_batches:
            break
    return table
|
|