File size: 3,709 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from typing import Any, Optional

import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from ultralytics.engine.results import Results
from ultralytics.models.yolo.pose import PosePredictor
from ultralytics.utils.plotting import Annotator

import wandb
from wandb.integration.ultralytics.bbox_utils import (
    get_boxes,
    get_ground_truth_bbox_annotations,
)


def annotate_keypoint_results(result: Results, visualize_skeleton: bool):
    """Draw every predicted keypoint set of ``result`` onto its original image.

    Args:
        result: Prediction result providing ``orig_img`` (indexed here as a
            3-D array) and ``keypoints.data`` (a tensor iterated along its
            first axis, one entry per ``Annotator.kpts`` call).
        visualize_skeleton: Forwarded to ``Annotator.kpts`` as ``kpt_line``.

    Returns:
        The annotator's image (``annotator.im``) after all keypoints are drawn.
    """
    # Reversing the last axis flips the channel order (presumably BGR -> RGB;
    # confirm against the producer of ``orig_img``). The reversed slice is a
    # strided view, so it is materialized as a contiguous array first.
    annotator = Annotator(np.ascontiguousarray(result.orig_img[:, :, ::-1]))
    # ``.cpu()`` leaves CPU tensors untouched but keeps ``.numpy()`` from
    # raising when the result still lives on an accelerator device.
    for instance_keypoints in result.keypoints.data.cpu().numpy():
        annotator.kpts(instance_keypoints, kpt_line=visualize_skeleton)
    return annotator.im


def annotate_keypoint_batch(image_path: str, keypoints: Any, visualize_skeleton: bool):
    """Load an image from disk and draw one set of keypoints on it.

    Args:
        image_path: Path of the image file to open with PIL.
        keypoints: Object exposing ``.numpy()``; the resulting array is passed
            to ``Annotator.kpts`` (presumably a CPU tensor from the
            dataloader batch; confirm against the caller).
        visualize_skeleton: Forwarded to ``Annotator.kpts`` as ``kpt_line``.

    Returns:
        The annotator's image (``annotator.im``) with the keypoints drawn.
    """
    with Image.open(image_path) as image_file:
        # Normalize to three-channel RGB so grayscale, palette and RGBA files
        # all produce an (H, W, 3) array, like the prediction images do.
        pixels = np.ascontiguousarray(image_file.convert("RGB"))
        annotator = Annotator(pixels)
        annotator.kpts(keypoints.numpy(), kpt_line=visualize_skeleton)
        return annotator.im


def plot_pose_predictions(
    result: Results,
    model_name: str,
    visualize_skeleton: bool,
    table: Optional[wandb.Table] = None,
):
    """Build the W&B table row for one pose prediction result.

    The result is moved to the CPU, its boxes are extracted with
    ``get_boxes``, and its keypoints are drawn on the original image.

    Args:
        result: Prediction result to visualize.
        model_name: Value stored in the first column of the row.
        visualize_skeleton: Whether keypoints are connected with lines.
        table: Optional table; when given, the row is appended to it.

    Returns:
        ``table`` (with the new row added) when a table was supplied,
        otherwise the row itself as a list of
        ``[model_name, image, number_of_boxes, mean_confidence_map, speed]``.
    """
    cpu_result = result.to("cpu")
    box_overlays, mean_confidences = get_boxes(cpu_result)
    overlay_image = wandb.Image(
        annotate_keypoint_results(cpu_result, visualize_skeleton),
        boxes=box_overlays,
    )
    num_boxes = len(box_overlays["predictions"]["box_data"])
    row = [
        model_name,
        overlay_image,
        num_boxes,
        mean_confidences,
        cpu_result.speed,
    ]
    # Without a table the caller wants the raw row to post-process itself.
    if table is None:
        return row
    table.add_data(*row)
    return table


def plot_pose_validation_results(
    dataloader,
    class_label_map,
    model_name: str,
    predictor: PosePredictor,
    visualize_skeleton: bool,
    table: wandb.Table,
    max_validation_batches: int,
    epoch: Optional[int] = None,
) -> wandb.Table:
    """Log ground-truth vs. predicted pose visualizations for validation data.

    For each image of the first ``max_validation_batches`` batches, the
    predictor is run on the image file, and one row is appended to ``table``
    holding the ground-truth overlay followed by the prediction columns
    produced by ``plot_pose_predictions``.

    Args:
        dataloader: Sized iterable of batches; each batch is indexed with the
            keys ``"im_file"`` and ``"keypoints"``.
        class_label_map: Class labels attached to the ground-truth boxes.
        model_name: Value stored in the first column of every row.
        predictor: Callable invoked on the batch's image file paths.
        visualize_skeleton: Whether keypoints are connected with lines.
        table: Table that receives the rows.
        max_validation_batches: Upper bound on the number of batches to log;
            clamped to the number of batches the dataloader yields.
        epoch: When not ``None``, inserted as the second column of every row.

    Returns:
        The same ``table`` object, with the new rows added.
    """
    data_idx = 0
    # ``len(dataloader)`` is the real batch count. Deriving it from
    # ``len(dataset) // batch_size`` ignored the trailing partial batch and
    # produced 0 for datasets smaller than a single batch.
    max_validation_batches = min(max_validation_batches, len(dataloader))
    for batch_idx, batch in enumerate(dataloader):
        image_paths = batch["im_file"]
        prediction_results = predictor(image_paths)
        progress_bar_result_iterable = tqdm(
            enumerate(prediction_results),
            total=len(prediction_results),
            desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}",
        )
        for img_idx, prediction_result in progress_bar_result_iterable:
            prediction_result = prediction_result.to("cpu")
            prediction_row = plot_pose_predictions(
                prediction_result, model_name, visualize_skeleton
            )
            image_path = image_paths[img_idx]
            # NOTE(review): keypoints are indexed by the image's position in
            # the batch; confirm the dataloader stores one keypoint entry per
            # image rather than one per annotated instance.
            ground_truth_image = wandb.Image(
                annotate_keypoint_batch(
                    image_path,
                    batch["keypoints"][img_idx],
                    visualize_skeleton,
                ),
                boxes={
                    "ground-truth": {
                        "box_data": get_ground_truth_bbox_annotations(
                            img_idx, image_path, batch, class_label_map
                        ),
                        "class_labels": class_label_map,
                    },
                },
            )
            # Row layout: model name, optional epoch, running image index,
            # batch index, ground truth, then the prediction columns without
            # their leading model-name entry.
            table_row = [model_name]
            if epoch is not None:
                table_row.append(epoch)
            table_row += [data_idx, batch_idx, ground_truth_image]
            table_row += prediction_row[1:]
            table.add_data(*table_row)
            data_idx += 1
        # The check runs after a batch is processed, so at least one batch is
        # always logged.
        if batch_idx + 1 == max_validation_batches:
            break
    return table