# DesignGenie: src/designgenie/utils/segmentation_utils.py
# Commit 5b2ab1c ("Initial commit") by naderasadi.
from typing import Any, List, Optional, Tuple, Union
from functools import reduce
import numpy as np
from PIL import Image
import requests
import torch
from torch import nn
import torchvision.transforms as transforms
from torchvision.transforms.functional import to_pil_image
def visualize_segmentation_map(
    semantic_map: torch.Tensor, original_image: Image.Image
) -> Image.Image:
    """
    Visualizes a segmentation map by overlaying it on the original image.

    Every pixel whose value is a class id with an entry in ``ade_palette()`` is
    painted with that class's RGB colour; any other value stays black in the
    colour layer. The colour layer and the original image are blended 50/50.

    Args:
        semantic_map (torch.Tensor): 2-D map of per-pixel class ids, indexed as
            (height, width). A NumPy array is accepted too; tensors on any
            device are copied to the CPU first.
        original_image (Image.Image): Original image with the same height and
            width as ``semantic_map``. PIL images are converted to RGB so that
            grayscale / palette / RGBA inputs blend correctly.

    Returns:
        Image.Image: RGB overlay image with the segmentation map.
    """
    if isinstance(semantic_map, torch.Tensor):
        # NumPy cannot be indexed with CUDA tensors; work on a host copy.
        semantic_map = semantic_map.detach().cpu().numpy()
    seg = np.asarray(semantic_map)

    palette = np.asarray(ade_palette(), dtype=np.uint8)
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)

    # Vectorised palette lookup: one pass over the image instead of one full
    # comparison per class. Values outside the palette stay black.
    valid = (seg >= 0) & (seg < len(palette))
    if not np.issubdtype(seg.dtype, np.integer):
        # Non-integral values (e.g. 2.5) never equal a class id.
        valid &= seg == np.trunc(seg)
    color_seg[valid] = palette[seg[valid].astype(np.intp)]

    # The palette is RGB and PIL images are RGB, so no BGR channel flip here
    # (flipping would display every class in swapped colours).
    if isinstance(original_image, Image.Image):
        original_image = original_image.convert("RGB")
    blended = np.array(original_image) * 0.5 + color_seg * 0.5
    return Image.fromarray(blended.astype(np.uint8))
def get_masks_from_segmentation_map(
    semantic_map: torch.Tensor,
) -> Tuple[List[np.array], List[int], List[str]]:
    """
    Extracts masks, labels, and object names from a segmentation map.

    One mask is produced for every class id in ``range(len(ade_palette()))``
    that occurs in ``semantic_map``. Masks are inverted: ``0`` marks the pixels
    of that class and ``1`` marks every other pixel.

    Args:
        semantic_map (torch.Tensor): 2-D map of per-pixel class ids, indexed as
            (height, width). A NumPy array is accepted too; tensors on any
            device are copied to the CPU first.

    Returns:
        Tuple[List[np.array], List[int], List[str]]: ``(masks, labels,
        obj_names)`` in ascending label order. Each mask is a ``uint8`` array
        with the map's shape, and each name is ``ADE_LABELS[str(label)]``.
    """
    if isinstance(semantic_map, torch.Tensor):
        # NumPy cannot compare/index with CUDA tensors; work on a host copy.
        semantic_map = semantic_map.detach().cpu().numpy()
    seg = np.asarray(semantic_map)

    # Only build masks for classes that actually occur, instead of allocating
    # and comparing a full-size mask for every palette entry.
    present = set(np.unique(seg).tolist())

    masks, labels, obj_names = [], [], []
    for label in range(len(ade_palette())):
        if label not in present:
            continue
        masks.append((seg != label).astype(np.uint8))
        labels.append(label)
        obj_names.append(ADE_LABELS[str(label)])
    return masks, labels, obj_names
def get_mask_from_coordinates(
    segmentation_maps: List[np.array], coordinates: List[Tuple[int, int]]
):
    """
    Combines every mask that contains at least one of the given coordinates.

    Masks use the inverted convention: ``0`` marks the segment and ``1`` marks
    everything else. A mask is selected when any coordinate lands on one of
    its ``0`` pixels. The selected masks are multiplied together, so for 0/1
    masks the result is ``0`` on the union of the selected segments and ``1``
    elsewhere.

    Args:
        segmentation_maps (List[np.array]): Masks to choose from.
        coordinates (List[Tuple[int, int]]): Points used to pick masks. Each
            point is used directly as an array index, i.e. ``(row, column)``.
            Two-element lists such as ``[row, column]`` are accepted too.

    Returns:
        np.array: Combined mask.

    Raises:
        ValueError: If no coordinate falls inside any of the masks.
    """
    # tuple() so that a [row, col] list indexes one pixel instead of
    # fancy-indexing two whole rows.
    points = [tuple(point) for point in coordinates]

    # Each mask is kept at most once, even if several points hit its segment.
    selected = [
        seg_map
        for seg_map in segmentation_maps
        if any(seg_map[point] == 0 for point in points)
    ]
    if not selected:
        raise ValueError(
            "None of the given coordinates falls inside any segmentation mask."
        )
    return reduce(np.multiply, selected)
def get_masked_images(
    control_image: Image.Image,
    semantic_map: torch.Tensor,
    coordinates: List[Tuple[int, int]],
    return_tensors: bool = False,
) -> Tuple[Union[torch.Tensor, Image.Image], Image.Image]:
    """
    Builds a selection mask and a masked control image from chosen segments.

    The segments hit by ``coordinates`` are looked up in ``semantic_map`` via
    ``get_masks_from_segmentation_map`` and ``get_mask_from_coordinates``.

    Args:
        control_image (Image.Image): Control image. PIL images are converted
            to RGB so single-channel, palette and RGBA inputs all line up with
            the 3-channel mask.
        semantic_map (torch.Tensor): Map of per-pixel class ids.
        coordinates (List[Tuple[int, int]]): ``(row, column)`` points that
            select segments.
        return_tensors (bool, optional): If True, ``mask_image`` is returned as
            a float tensor of shape (3, height, width) instead of a PIL image.
            ``masked_control_image`` is always a PIL image. Defaults to False.

    Returns:
        Tuple[Union[torch.Tensor, Image.Image], Image.Image]:
        ``(mask_image, masked_control_image)``. ``mask_image`` is 1 (white) on
        the selected segments and 0 elsewhere. ``masked_control_image`` is the
        control image with the selected segments set to zero (black).

    Raises:
        ValueError: If no coordinate falls inside a segment (raised by
            ``get_mask_from_coordinates``).
    """
    masks, _, _ = get_masks_from_segmentation_map(semantic_map)
    # 0 on the selected segments, 1 everywhere else.
    keep = get_mask_from_coordinates(masks, coordinates)

    # Float tensors of shape (3, height, width), matching ToTensor() RGB output.
    mask_image = torch.Tensor(np.logical_not(keep).astype(np.float32)).repeat(3, 1, 1)
    keep_tensor = torch.Tensor(keep.astype(np.float32)).repeat(3, 1, 1)

    if isinstance(control_image, Image.Image):
        control_image = control_image.convert("RGB")
    control_tensor = transforms.ToTensor()(control_image)
    masked_control_image = transforms.ToPILImage()(keep_tensor * control_tensor)

    if not return_tensors:
        mask_image = to_pil_image(mask_image)
    return mask_image, masked_control_image
def _load_ade_labels(timeout: float = 30.0) -> dict:
    """
    Downloads the ADE20K id-to-label JSON from the Hugging Face Hub.

    Args:
        timeout (float): Seconds to wait for the server before giving up.

    Returns:
        dict: The decoded JSON. The functions in this module look names up
        with string class ids (``ADE_LABELS[str(label)]``).

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(
        "https://huggingface.co/datasets/huggingface/label-files/raw/main/ade20k-id2label.json",
        # Without a timeout, requests can block forever on a stalled connection.
        timeout=timeout,
    )
    # Surface 404/5xx as an HTTP error rather than an obscure JSON decode error.
    response.raise_for_status()
    return response.json()


# NOTE: this performs a network request when the module is imported.
ADE_LABELS = _load_ade_labels()
def ade_palette():
    """
    Returns the ADE20K color palette.

    Entry ``i`` is the ``[R, G, B]`` color (0-255 per channel) for class id
    ``i``. There are 150 entries (class ids 0-149). A fresh list of fresh lists
    is built on every call, so callers may modify the result freely.
    """
    rgb_by_class = (
        # 0-9
        (120, 120, 120), (180, 120, 120), (6, 230, 230), (80, 50, 50), (4, 200, 3),
        (120, 120, 80), (140, 140, 140), (204, 5, 255), (230, 230, 230), (4, 250, 7),
        # 10-19
        (224, 5, 255), (235, 255, 7), (150, 5, 61), (120, 120, 70), (8, 255, 51),
        (255, 6, 82), (143, 255, 140), (204, 255, 4), (255, 51, 7), (204, 70, 3),
        # 20-29
        (0, 102, 200), (61, 230, 250), (255, 6, 51), (11, 102, 255), (255, 7, 71),
        (255, 9, 224), (9, 7, 230), (220, 220, 220), (255, 9, 92), (112, 9, 255),
        # 30-39
        (8, 255, 214), (7, 255, 224), (255, 184, 6), (10, 255, 71), (255, 41, 10),
        (7, 255, 255), (224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7),
        # 40-49
        (255, 122, 8), (0, 255, 20), (255, 8, 41), (255, 5, 153), (6, 51, 255),
        (235, 12, 255), (160, 150, 20), (0, 163, 255), (140, 140, 140), (250, 10, 15),
        # 50-59
        (20, 255, 0), (31, 255, 0), (255, 31, 0), (255, 224, 0), (153, 255, 0),
        (0, 0, 255), (255, 71, 0), (0, 235, 255), (0, 173, 255), (31, 0, 255),
        # 60-69
        (11, 200, 200), (255, 82, 0), (0, 255, 245), (0, 61, 255), (0, 255, 112),
        (0, 255, 133), (255, 0, 0), (255, 163, 0), (255, 102, 0), (194, 255, 0),
        # 70-79
        (0, 143, 255), (51, 255, 0), (0, 82, 255), (0, 255, 41), (0, 255, 173),
        (10, 0, 255), (173, 255, 0), (0, 255, 153), (255, 92, 0), (255, 0, 255),
        # 80-89
        (255, 0, 245), (255, 0, 102), (255, 173, 0), (255, 0, 20), (255, 184, 184),
        (0, 31, 255), (0, 255, 61), (0, 71, 255), (255, 0, 204), (0, 255, 194),
        # 90-99
        (0, 255, 82), (0, 10, 255), (0, 112, 255), (51, 0, 255), (0, 194, 255),
        (0, 122, 255), (0, 255, 163), (255, 153, 0), (0, 255, 10), (255, 112, 0),
        # 100-109
        (143, 255, 0), (82, 0, 255), (163, 255, 0), (255, 235, 0), (8, 184, 170),
        (133, 0, 255), (0, 255, 92), (184, 0, 255), (255, 0, 31), (0, 184, 255),
        # 110-119
        (0, 214, 255), (255, 0, 112), (92, 255, 0), (0, 224, 255), (112, 224, 255),
        (70, 184, 160), (163, 0, 255), (153, 0, 255), (71, 255, 0), (255, 0, 163),
        # 120-129
        (255, 204, 0), (255, 0, 143), (0, 255, 235), (133, 255, 0), (255, 0, 235),
        (245, 0, 255), (255, 0, 122), (255, 245, 0), (10, 190, 212), (214, 255, 0),
        # 130-139
        (0, 204, 255), (20, 0, 255), (255, 255, 0), (0, 153, 255), (0, 41, 255),
        (0, 255, 204), (41, 0, 255), (41, 255, 0), (173, 0, 255), (0, 245, 255),
        # 140-149
        (71, 0, 255), (122, 0, 255), (0, 255, 184), (0, 92, 255), (184, 255, 0),
        (0, 133, 255), (255, 214, 0), (25, 194, 194), (102, 255, 0), (92, 0, 255),
    )
    return [list(rgb) for rgb in rgb_by_class]