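# NOTE: page residue from where this file was captured (apparently a Hugging Face
# Spaces listing whose status read "Build error"); not part of the module.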
from functools import reduce
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import requests
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torchvision.transforms.functional import to_pil_image
def visualize_segmentation_map(
    semantic_map: torch.Tensor, original_image: Image.Image
) -> Image.Image:
    """
    Blend a colorized segmentation map with the original image.

    Every class id found in `semantic_map` is painted with the entry of
    `ade_palette()` at the same index, and the colored map is then averaged
    50/50 with `original_image`.

    Args:
        semantic_map (torch.Tensor): Segmentation map tensor. Its first two
            dimensions are used as (height, width) and its values are
            compared against palette indices.
        original_image (Image.Image): Original image.

    Returns:
        Image.Image: Overlay image with segmentation map.
    """
    height, width = semantic_map.shape[0], semantic_map.shape[1]
    overlay = np.zeros((height, width, 3), dtype=np.uint8)

    # Paint all pixels of a class with that class's palette color; ids without
    # a palette entry keep the initial zeros (black).
    for class_id, rgb in enumerate(np.array(ade_palette())):
        class_pixels = semantic_map == class_id
        overlay[class_pixels, :] = rgb

    # Reverse the channel axis (the palette is declared as RGB, so the colors
    # end up in BGR order before blending).
    overlay = overlay[..., ::-1]

    # Equal-weight blend of image and colored map, cast back to 8-bit.
    blended = np.array(original_image) * 0.5 + overlay * 0.5
    return Image.fromarray(blended.astype(np.uint8))
def get_masks_from_segmentation_map(
    semantic_map: torch.Tensor,
) -> Tuple[List[np.array], List[int], List[str]]:
    """
    Build one binary mask per class that occurs in a segmentation map.

    Args:
        semantic_map (torch.Tensor): Segmentation map tensor. Its first two
            dimensions are used as (height, width).

    Returns:
        Tuple[List[np.array], List[int], List[str]]: Three parallel lists: the
        masks (uint8; 0 where the class is present, 1 everywhere else), the
        class ids, and the class names looked up in ADE_LABELS by `str(id)`.
    """
    height, width = semantic_map.shape[0], semantic_map.shape[1]
    class_masks = []
    class_ids = []
    class_names = []

    # One candidate class per palette entry.
    for class_id in range(len(ade_palette())):
        present = semantic_map == class_id
        if not (present.sum() > 0):
            # Class does not appear anywhere in the map: nothing to emit.
            continue

        # Start from all ones and zero out the pixels belonging to the class.
        class_mask = np.ones((height, width), dtype=np.uint8)
        class_mask[present] = 0

        class_masks.append(class_mask)
        class_ids.append(class_id)
        class_names.append(ADE_LABELS[str(class_id)])

    return class_masks, class_ids, class_names
def get_mask_from_coordinates(
    segmentation_maps: List[np.array], coordinates: List[Tuple[int, int]]
) -> np.ndarray:
    """
    Combine the masks that are hit by any of the given coordinates.

    A mask counts as "hit" when its value at a coordinate is 0. All hit masks
    are multiplied element-wise, so the result is 0 wherever any hit mask is 0.

    Args:
        segmentation_maps (List[np.array]): Masks to choose from. They must
            all share one shape for the element-wise product to work.
        coordinates (List[Tuple[int, int]]): Points used to pick masks. Each
            point is applied as an index into a mask, so for 2-D masks it is
            (row, column). Points given as lists are accepted as well.

    Returns:
        np.array: Element-wise product of the hit masks. When no mask is hit
        (including an empty `coordinates`), an all-ones array shaped like the
        first mask is returned, i.e. nothing is selected.

    Raises:
        ValueError: If `segmentation_maps` is empty.
    """
    if len(segmentation_maps) == 0:
        raise ValueError("segmentation_maps must contain at least one mask")

    # tuple() makes a [row, col] list index a single element; a bare list
    # would fancy-index whole rows and make the comparison below ambiguous.
    points = [tuple(point) for point in coordinates]

    # any() keeps each mask at most once even when several points land on it.
    selected = [
        seg_map
        for seg_map in segmentation_maps
        if any(seg_map[point] == 0 for point in points)
    ]

    if not selected:
        # reduce() without an initial value raises on an empty sequence;
        # return the neutral element of the product instead.
        return np.ones_like(segmentation_maps[0])
    return reduce(np.multiply, selected)
def get_masked_images(
    control_image: Image.Image,
    semantic_map: torch.Tensor,
    coordinates: List[Tuple[int, int]],
    return_tensors: bool = False,
) -> Tuple[Union[torch.Tensor, Image.Image], Image.Image]:
    """
    Build a selection mask and a masked copy of the control image.

    The classes of `semantic_map` that are hit by `coordinates` are selected.
    Two results are produced: a mask that is 1 on the selected classes and 0
    elsewhere, and the control image with the selected classes multiplied by 0.

    Args:
        control_image (Image.Image): Control image. Its size and channels must
            be compatible with a 3-channel mask of the segmentation map's
            height and width, since the two are multiplied element-wise.
        semantic_map (torch.Tensor): Segmentation map tensor.
        coordinates (List[Tuple[int, int]]): List of coordinates used to pick
            the classes to select.
        return_tensors (bool, optional): If True, the selection mask is
            returned as a tensor instead of a PIL image. Defaults to False.

    Returns:
        Tuple[Union[torch.Tensor, Image.Image], Image.Image]: A pair
        `(mask_image, masked_control_image)`. `mask_image` is a 3-channel
        tensor when `return_tensors` is True and a PIL image otherwise.
        `masked_control_image` is always a PIL image, whatever
        `return_tensors` is.
    """
    # Only the masks are needed here; the labels and names are ignored.
    masks, _, _ = get_masks_from_segmentation_map(semantic_map)

    # 0 on the selected classes, non-zero elsewhere.
    keep_mask = get_mask_from_coordinates(masks, coordinates)

    # Inverted copy: 1 on the selected classes, 0 elsewhere.
    selection = np.logical_not(keep_mask).astype(int)

    # Repeat both single-channel masks across three channels.
    mask_image = torch.Tensor(selection).repeat(3, 1, 1)
    keep_mask = torch.Tensor(keep_mask).repeat(3, 1, 1)

    # Zero out the selected classes in the control image.
    control_tensor = transforms.ToTensor()(control_image)
    masked_control_image = transforms.ToPILImage()(keep_mask * control_tensor)

    if not return_tensors:
        mask_image = to_pil_image(mask_image)
    return mask_image, masked_control_image
# ADE20K id-to-label mapping, downloaded once when this module is imported.
# Elsewhere in this file it is looked up with the class id converted to a
# string.
# The timeout keeps an unreachable host from blocking the import forever, and
# raise_for_status() surfaces HTTP errors instead of trying to parse an error
# page as JSON.
_ade_labels_response = requests.get(
    "https://huggingface.co/datasets/huggingface/label-files/raw/main/ade20k-id2label.json",
    timeout=30,
)
_ade_labels_response.raise_for_status()
ADE_LABELS = _ade_labels_response.json()
def ade_palette():
    """
    Return the ADE20K color palette.

    Returns:
        A newly built list with one `[R, G, B]` entry per class; the position
        in the list is the class index.
    """
    class_colors = [
        [120, 120, 120],
        [180, 120, 120],
        [6, 230, 230],
        [80, 50, 50],
        [4, 200, 3],
        [120, 120, 80],
        [140, 140, 140],
        [204, 5, 255],
        [230, 230, 230],
        [4, 250, 7],
        [224, 5, 255],
        [235, 255, 7],
        [150, 5, 61],
        [120, 120, 70],
        [8, 255, 51],
        [255, 6, 82],
        [143, 255, 140],
        [204, 255, 4],
        [255, 51, 7],
        [204, 70, 3],
        [0, 102, 200],
        [61, 230, 250],
        [255, 6, 51],
        [11, 102, 255],
        [255, 7, 71],
        [255, 9, 224],
        [9, 7, 230],
        [220, 220, 220],
        [255, 9, 92],
        [112, 9, 255],
        [8, 255, 214],
        [7, 255, 224],
        [255, 184, 6],
        [10, 255, 71],
        [255, 41, 10],
        [7, 255, 255],
        [224, 255, 8],
        [102, 8, 255],
        [255, 61, 6],
        [255, 194, 7],
        [255, 122, 8],
        [0, 255, 20],
        [255, 8, 41],
        [255, 5, 153],
        [6, 51, 255],
        [235, 12, 255],
        [160, 150, 20],
        [0, 163, 255],
        [140, 140, 140],
        [250, 10, 15],
        [20, 255, 0],
        [31, 255, 0],
        [255, 31, 0],
        [255, 224, 0],
        [153, 255, 0],
        [0, 0, 255],
        [255, 71, 0],
        [0, 235, 255],
        [0, 173, 255],
        [31, 0, 255],
        [11, 200, 200],
        [255, 82, 0],
        [0, 255, 245],
        [0, 61, 255],
        [0, 255, 112],
        [0, 255, 133],
        [255, 0, 0],
        [255, 163, 0],
        [255, 102, 0],
        [194, 255, 0],
        [0, 143, 255],
        [51, 255, 0],
        [0, 82, 255],
        [0, 255, 41],
        [0, 255, 173],
        [10, 0, 255],
        [173, 255, 0],
        [0, 255, 153],
        [255, 92, 0],
        [255, 0, 255],
        [255, 0, 245],
        [255, 0, 102],
        [255, 173, 0],
        [255, 0, 20],
        [255, 184, 184],
        [0, 31, 255],
        [0, 255, 61],
        [0, 71, 255],
        [255, 0, 204],
        [0, 255, 194],
        [0, 255, 82],
        [0, 10, 255],
        [0, 112, 255],
        [51, 0, 255],
        [0, 194, 255],
        [0, 122, 255],
        [0, 255, 163],
        [255, 153, 0],
        [0, 255, 10],
        [255, 112, 0],
        [143, 255, 0],
        [82, 0, 255],
        [163, 255, 0],
        [255, 235, 0],
        [8, 184, 170],
        [133, 0, 255],
        [0, 255, 92],
        [184, 0, 255],
        [255, 0, 31],
        [0, 184, 255],
        [0, 214, 255],
        [255, 0, 112],
        [92, 255, 0],
        [0, 224, 255],
        [112, 224, 255],
        [70, 184, 160],
        [163, 0, 255],
        [153, 0, 255],
        [71, 255, 0],
        [255, 0, 163],
        [255, 204, 0],
        [255, 0, 143],
        [0, 255, 235],
        [133, 255, 0],
        [255, 0, 235],
        [245, 0, 255],
        [255, 0, 122],
        [255, 245, 0],
        [10, 190, 212],
        [214, 255, 0],
        [0, 204, 255],
        [20, 0, 255],
        [255, 255, 0],
        [0, 153, 255],
        [0, 41, 255],
        [0, 255, 204],
        [41, 0, 255],
        [41, 255, 0],
        [173, 0, 255],
        [0, 245, 255],
        [71, 0, 255],
        [122, 0, 255],
        [0, 255, 184],
        [0, 92, 255],
        [184, 255, 0],
        [0, 133, 255],
        [255, 214, 0],
        [25, 194, 194],
        [102, 255, 0],
        [92, 0, 255],
    ]
    return class_colors