from typing import Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
from tqdm.auto import tqdm
from ultralytics.engine.results import Results
from ultralytics.models.yolo.segment import SegmentationPredictor
from ultralytics.utils.ops import scale_image

import wandb
from wandb.integration.ultralytics.bbox_utils import (
    get_ground_truth_bbox_annotations,
    get_mean_confidence_map,
)


def instance_mask_to_semantic_mask(
    instance_mask: np.ndarray, class_indices: List[int]
) -> np.ndarray:
    """Collapse an (H, W, N) stack of binary instance masks into a single
    (H, W) semantic mask by writing each instance's class index into the
    pixels it covers. Later instances overwrite earlier ones on overlap."""
    height, width, num_instances = instance_mask.shape
    semantic_mask = np.zeros((height, width), dtype=np.uint8)
    for i in range(num_instances):
        instance_map = instance_mask[:, :, i]
        class_index = class_indices[i]
        semantic_mask[instance_map == 1] = class_index
    return semantic_mask
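
# Example (a minimal sketch with made-up values): two 2x2 instance channels
# mapped to class indices 2 and 5 collapse into one semantic map.
#
#     instances = np.stack(
#         [np.array([[1, 0], [0, 0]]), np.array([[0, 0], [0, 1]])], axis=-1
#     )
#     instance_mask_to_semantic_mask(instances, [2, 5])
#     # -> array([[2, 0], [0, 5]], dtype=uint8)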


def get_boxes_and_masks(result: Results) -> Tuple[Dict, Optional[Dict], Dict]:
    """Convert a YOLO segmentation `Results` object into the box and mask
    dictionaries expected by `wandb.Image`, plus a per-class mean-confidence
    map. `masks` is None when the result carries no segmentation masks."""
    boxes = result.boxes.xywh.long().numpy()
    classes = result.boxes.cls.long().numpy()
    confidence = result.boxes.conf.numpy()
    class_id_to_label = {int(k): str(v) for k, v in result.names.items()}
    # Reserve one extra class index for the background.
    class_id_to_label.update({len(result.names): "background"})
    mean_confidence_map = get_mean_confidence_map(
        classes, confidence, class_id_to_label
    )
    masks = None
    if result.masks is not None:
        # Masks come back as (N, H, W); scale_image expects (H, W, N).
        scaled_instance_mask = scale_image(
            np.transpose(result.masks.data.numpy(), (1, 2, 0)),
            result.orig_img[:, :, ::-1].shape,
        )
        scaled_semantic_mask = instance_mask_to_semantic_mask(
            scaled_instance_mask, classes.tolist()
        )
        # Pixels still at 0 are treated as background.
        scaled_semantic_mask[scaled_semantic_mask == 0] = len(result.names)
        masks = {
            "predictions": {
                "mask_data": scaled_semantic_mask,
                "class_labels": class_id_to_label,
            }
        }
    box_data = []
    for idx in range(len(boxes)):
        box_data.append(
            {
                "position": {
                    "middle": [int(boxes[idx][0]), int(boxes[idx][1])],
                    "width": int(boxes[idx][2]),
                    "height": int(boxes[idx][3]),
                },
                "domain": "pixel",
                "class_id": int(classes[idx]),
                "box_caption": class_id_to_label[int(classes[idx])],
                "scores": {"confidence": float(confidence[idx])},
            }
        )

    boxes = {
        "predictions": {
            "box_data": box_data,
            "class_labels": class_id_to_label,
        },
    }
    return boxes, masks, mean_confidence_map
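
# Shapes of the returned values (illustrative, mirroring the code above):
#     boxes -> {"predictions": {"box_data": [...], "class_labels": {...}}}
#     masks -> {"predictions": {"mask_data": <HxW uint8>, "class_labels": {...}}}
#              or None if the result has no masks
#     mean_confidence_map -> {class_label: mean confidence over its detections}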


def plot_mask_predictions(
    result: Results, model_name: str, table: Optional[wandb.Table] = None
) -> Union[wandb.Table, Tuple[wandb.Image, Dict, Dict, Dict]]:
    """Build a `wandb.Image` with predicted boxes and masks overlaid. If a
    table is passed, append a row to it and return the table; otherwise
    return the image and its underlying annotation dictionaries."""
    result = result.to("cpu")
    boxes, masks, mean_confidence_map = get_boxes_and_masks(result)
    # orig_img is BGR; reverse the channel axis to get RGB for wandb.
    image = wandb.Image(result.orig_img[:, :, ::-1], boxes=boxes, masks=masks)
    if table is not None:
        table.add_data(
            model_name,
            image,
            len(boxes["predictions"]["box_data"]),
            mean_confidence_map,
            result.speed,
        )
        return table
    return image, masks, boxes["predictions"], mean_confidence_map
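
# Calling modes (illustrative; "yolov8n-seg" is a placeholder model name):
# with a table the row is appended in place and the table is returned;
# without one, the raw objects come back for custom logging.
#
#     table = plot_mask_predictions(result, "yolov8n-seg", table=table)
#     image, masks, box_data, conf_map = plot_mask_predictions(result, "yolov8n-seg")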


def structure_prompts_and_image(
    image: np.ndarray, prompt: Dict
) -> Tuple[np.ndarray, Dict]:
    """Render SAM-style prompts for logging: a box prompt becomes a wandb box
    annotation and a point prompt is drawn directly onto a copy of the image.
    Returns the (possibly annotated) image and the box dictionary."""
    wb_box_data = []
    if prompt["bboxes"] is not None:
        wb_box_data.append(
            {
                "position": {
                    "middle": [prompt["bboxes"][0], prompt["bboxes"][1]],
                    "width": prompt["bboxes"][2],
                    "height": prompt["bboxes"][3],
                },
                "domain": "pixel",
                "class_id": 1,
                "box_caption": "Prompt-Box",
            }
        )
    if prompt["points"] is not None:
        # Draw the point prompt as a filled green dot.
        image = image.copy().astype(np.uint8)
        image = cv2.circle(
            image, tuple(prompt["points"]), 5, (0, 255, 0), -1, lineType=cv2.LINE_AA
        )
    wb_boxes = {
        "prompts": {
            "box_data": wb_box_data,
            "class_labels": {1: "Prompt-Box"},
        }
    }
    return image, wb_boxes


def plot_sam_predictions(
    result: Results, prompt: Dict, table: wandb.Table
) -> wandb.Table:
    """Overlay a SAM prediction mask and its prompt on the source image and
    append the visualization to the given table."""
    result = result.to("cpu")
    image = result.orig_img[:, :, ::-1]
    image, wb_box_data = structure_prompts_and_image(image, prompt)
    image = wandb.Image(
        image,
        boxes=wb_box_data,
        masks={
            "predictions": {
                "mask_data": np.squeeze(result.masks.data.numpy().astype(int)),
                "class_labels": {0: "Background", 1: "Prediction"},
            }
        },
    )
    table.add_data(image)
    return table
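
# Expected prompt layout (a sketch inferred from what
# structure_prompts_and_image reads; the values are illustrative):
#
#     prompt = {"bboxes": [center_x, center_y, width, height], "points": None}
#     prompt = {"bboxes": None, "points": [x, y]}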


def plot_segmentation_validation_results(
    dataloader,
    class_label_map: Dict,
    model_name: str,
    predictor: SegmentationPredictor,
    table: wandb.Table,
    max_validation_batches: int,
    epoch: Optional[int] = None,
) -> wandb.Table:
    """Run the predictor over up to `max_validation_batches` validation
    batches and append one table row per image, pairing ground-truth boxes
    with predicted boxes and masks."""
    data_idx = 0
    num_dataloader_batches = len(dataloader.dataset) // dataloader.batch_size
    max_validation_batches = min(max_validation_batches, num_dataloader_batches)
    for batch_idx, batch in enumerate(dataloader):
        prediction_results = predictor(batch["im_file"])
        progress_bar_result_iterable = tqdm(
            enumerate(prediction_results),
            total=len(prediction_results),
            desc=f"Generating visualizations for batch {batch_idx + 1}/{max_validation_batches}",
        )
        for img_idx, prediction_result in progress_bar_result_iterable:
            prediction_result = prediction_result.to("cpu")
            (
                _,
                prediction_mask_data,
                prediction_box_data,
                mean_confidence_map,
            ) = plot_mask_predictions(prediction_result, model_name)
            try:
                ground_truth_data = get_ground_truth_bbox_annotations(
                    img_idx, batch["im_file"][img_idx], batch, class_label_map
                )
                wandb_image = wandb.Image(
                    batch["im_file"][img_idx],
                    boxes={
                        "ground-truth": {
                            "box_data": ground_truth_data,
                            "class_labels": class_label_map,
                        },
                        "predictions": prediction_box_data,
                    },
                    masks=prediction_mask_data,
                )
                table_rows = [
                    data_idx,
                    batch_idx,
                    wandb_image,
                    mean_confidence_map,
                    prediction_result.speed,
                ]
                table_rows = [epoch] + table_rows if epoch is not None else table_rows
                table_rows = [model_name] + table_rows
                table.add_data(*table_rows)
                data_idx += 1
            except TypeError:
                # Skip images whose ground-truth annotations cannot be parsed.
                pass
        if batch_idx + 1 == max_validation_batches:
            break
    return table
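
# Usage sketch (all names below are illustrative placeholders, not part of
# this module): wire plot_mask_predictions into a prediction loop and log the
# resulting table.
#
#     from ultralytics import YOLO
#
#     wandb.init(project="ultralytics-demo")  # hypothetical project name
#     model = YOLO("yolov8n-seg.pt")  # hypothetical checkpoint
#     table = wandb.Table(
#         columns=["Model", "Image", "Num-Objects", "Mean-Confidence", "Speed"]
#     )
#     for result in model("sample.jpg"):  # hypothetical input image
#         table = plot_mask_predictions(result, "yolov8n-seg", table=table)
#     wandb.log({"Prediction-Table": table})
#     wandb.finish()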