File size: 3,709 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
from typing import Any, Optional
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from ultralytics.engine.results import Results
from ultralytics.models.yolo.pose import PosePredictor
from ultralytics.utils.plotting import Annotator
import wandb
from wandb.integration.ultralytics.bbox_utils import (
get_boxes,
get_ground_truth_bbox_annotations,
)
def annotate_keypoint_results(result: Results, visualize_skeleton: bool):
    """Render the predicted keypoints of a pose `Results` object onto its image.

    The stored `orig_img` has its channel axis reversed here (`[:, :, ::-1]`,
    presumably BGR -> RGB for display — confirm against the caller) and is
    copied into contiguous memory before being handed to the annotator.

    Returns the annotated image held by the `Annotator`.
    """
    rgb_frame = np.ascontiguousarray(result.orig_img[:, :, ::-1])
    annotator = Annotator(rgb_frame)
    # One keypoint set per detected instance; draw each in turn.
    for instance_keypoints in result.keypoints.data.numpy():
        annotator.kpts(instance_keypoints, kpt_line=visualize_skeleton)
    return annotator.im
def annotate_keypoint_batch(image_path: str, keypoints: Any, visualize_skeleton: bool):
    """Draw a single image's ground-truth keypoints loaded from disk.

    `keypoints` is expected to expose `.numpy()` (a torch tensor from the
    dataloader batch — TODO confirm). Returns the annotated image.
    """
    with Image.open(image_path) as source:
        # Force-load the pixels into a contiguous array before annotating.
        frame = np.ascontiguousarray(source)
        annotator = Annotator(frame)
        annotator.kpts(keypoints.numpy(), kpt_line=visualize_skeleton)
        return annotator.im
def plot_pose_predictions(
    result: Results,
    model_name: str,
    visualize_skeleton: bool,
    table: Optional[wandb.Table] = None,
):
    """Visualize one pose-prediction result for Weights & Biases logging.

    Builds a row of the form
    ``[model_name, wandb.Image, num_boxes, mean_confidence_map, speed]``.
    If `table` is given, the row is appended to it and the table is
    returned; otherwise the raw row list is returned to the caller.
    """
    result = result.to("cpu")
    boxes, mean_confidence_map = get_boxes(result)
    overlay = annotate_keypoint_results(result, visualize_skeleton)
    row = [
        model_name,
        wandb.Image(overlay, boxes=boxes),
        len(boxes["predictions"]["box_data"]),
        mean_confidence_map,
        result.speed,
    ]
    if table is None:
        return row
    table.add_data(*row)
    return table
def plot_pose_validation_results(
    dataloader,
    class_label_map,
    model_name: str,
    predictor: PosePredictor,
    visualize_skeleton: bool,
    table: wandb.Table,
    max_validation_batches: int,
    epoch: Optional[int] = None,
) -> wandb.Table:
    """Log pose predictions and ground truth for validation batches to a table.

    Runs `predictor` over up to `max_validation_batches` batches from
    `dataloader`, and for every image appends one row combining the
    ground-truth keypoint/bbox visualization with the prediction row from
    `plot_pose_predictions`.

    Args:
        dataloader: Validation dataloader yielding dicts with at least
            "im_file" and "keypoints" entries.
        class_label_map: Mapping of class id -> label for bbox overlays.
        model_name: Identifier stored in the first column of each row.
        predictor: Ultralytics pose predictor invoked on each batch's files.
        visualize_skeleton: Whether to draw keypoint connection lines.
        table: The `wandb.Table` rows are appended to.
        max_validation_batches: Upper bound on the number of batches logged.
        epoch: Optional epoch number inserted after the model name.

    Returns:
        The same `table`, with the new rows appended.
    """
    data_idx = 0
    # Ceiling division: a trailing partial batch still counts as a batch.
    # The previous floor division undercounted, and when
    # batch_size > len(dataset) it produced 0 — min(...) then made
    # max_validation_batches 0, so the `batch_idx + 1 == ...` stop
    # condition below could never fire and EVERY batch was processed.
    num_dataloader_batches = -(-len(dataloader.dataset) // dataloader.batch_size)
    max_validation_batches = min(max_validation_batches, num_dataloader_batches)
    for batch_idx, batch in enumerate(dataloader):
        prediction_results = predictor(batch["im_file"])
        progress_bar_result_iterable = tqdm(
            enumerate(prediction_results),
            total=len(prediction_results),
            desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}",
        )
        for img_idx, prediction_result in progress_bar_result_iterable:
            prediction_result = prediction_result.to("cpu")
            # No table passed -> returns the raw row list
            # [model_name, image, num_boxes, confidences, speed].
            table_row = plot_pose_predictions(
                prediction_result, model_name, visualize_skeleton
            )
            ground_truth_image = wandb.Image(
                annotate_keypoint_batch(
                    batch["im_file"][img_idx],
                    batch["keypoints"][img_idx],
                    visualize_skeleton,
                ),
                boxes={
                    "ground-truth": {
                        "box_data": get_ground_truth_bbox_annotations(
                            img_idx, batch["im_file"][img_idx], batch, class_label_map
                        ),
                        "class_labels": class_label_map,
                    },
                },
            )
            # Drop the duplicate model_name at table_row[0]; it is
            # re-prepended below (after the optional epoch column).
            table_row = [data_idx, batch_idx, ground_truth_image] + table_row[1:]
            table_row = [epoch] + table_row if epoch is not None else table_row
            table_row = [model_name] + table_row
            table.add_data(*table_row)
            data_idx += 1
        if batch_idx + 1 == max_validation_batches:
            break
    return table
|