from typing import List
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms
from transformers import (
AutoImageProcessor,
Mask2FormerForUniversalSegmentation,
MaskFormerImageProcessor,
MaskFormerForInstanceSegmentation,
)


class MaskFormer:
    """MaskFormer semantic segmentation model.

    Args:
        model_size (str, optional):
            Size of the MaskFormer model. Defaults to "large".
    """
    def __init__(self, model_size: str = "large") -> None:
assert model_size in [
"tiny",
"base",
"large",
], "Model size must be one of 'tiny', 'base', or 'large'"
self.processor = MaskFormerImageProcessor.from_pretrained(
f"facebook/maskformer-swin-{model_size}-ade"
)
self.model = MaskFormerForInstanceSegmentation.from_pretrained(
f"facebook/maskformer-swin-{model_size}-ade"
)
    def process(self, images: List[Image.Image]) -> List[torch.Tensor]:
        inputs = self.processor(images=images, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        # The model predicts class_queries_logits of shape
        # `(batch_size, num_queries, num_labels + 1)` and masks_queries_logits
        # of shape `(batch_size, num_queries, height, width)`; the processor
        # fuses them into one (height, width) semantic map per input image.
        # target_sizes needs one (height, width) pair per image; PIL's
        # `size` is (width, height), hence the reversal.
        predicted_semantic_maps = self.processor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.size[::-1] for image in images]
        )
        return predicted_semantic_maps


class Mask2Former(MaskFormer):
    """Mask2Former semantic segmentation model.

    Args:
        model_size (str, optional):
            Size of the Mask2Former model. Defaults to "large".
    """
    def __init__(self, model_size: str = "large") -> None:
assert model_size in [
"tiny",
"base",
"large",
], "Model size must be one of 'tiny', 'base', or 'large'"
self.processor = AutoImageProcessor.from_pretrained(
f"facebook/mask2former-swin-{model_size}-ade-semantic"
)
self.model = Mask2FormerForUniversalSegmentation.from_pretrained(
f"facebook/mask2former-swin-{model_size}-ade-semantic"
)
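

# A minimal helper sketch, not part of the original classes: given one of the
# (height, width) semantic maps returned by `process`, build a boolean mask
# selecting the pixels predicted as a single class. The name
# `semantic_map_to_mask` and the caller-supplied `class_id` are illustrative
# assumptions, not an established API.
def semantic_map_to_mask(semantic_map: torch.Tensor, class_id: int) -> torch.Tensor:
    """Return a (height, width) boolean tensor that is True wherever the
    predicted class equals `class_id`."""
    return semantic_map == class_id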
# class ADESegmentation:
# def __init__(self, model_name: str):
# self.processor = MODEL_DICT[model_name]["processor"].from_pretrained(
# MODEL_DICT[model_name]["name"]
# )
# self.model = MODEL_DICT[model_name]["model"].from_pretrained(
# MODEL_DICT[model_name]["name"]
# )
# def predict(self, image: Image.Image):
#         inputs = self.processor(images=image, return_tensors="pt")
#         outputs = self.model(**inputs)
# # model predicts class_queries_logits of shape `(batch_size, num_queries)`
# # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
# class_queries_logits = outputs.class_queries_logits
# masks_queries_logits = outputs.masks_queries_logits
# # you can pass them to processor for postprocessing
# # we refer to the demo notebooks for visualization (see "Resources" section in the MaskFormer docs)
#         predicted_semantic_maps = self.processor.post_process_semantic_segmentation(
#             outputs, target_sizes=[image.size[::-1]]
#         )
# return predicted_semantic_maps
# def get_mask(self, predicted_semantic_maps, class_id: int):
# masks, labels, obj_names = get_masks_from_segmentation_map(
# predicted_semantic_maps[0]
# )
#         mask = masks[labels.index(class_id)]
# object_mask = np.logical_not(mask).astype(int)
# mask = torch.Tensor(mask).repeat(3, 1, 1)
# object_mask = torch.Tensor(object_mask).repeat(3, 1, 1)
# return mask, object_mask
# def get_PIL_mask(self, predicted_semantic_maps, class_id: int):
#         mask, object_mask = self.get_mask(predicted_semantic_maps, class_id=class_id)
# mask = transforms.ToPILImage()(mask)
# object_mask = transforms.ToPILImage()(object_mask)
# return mask, object_mask
# def get_PIL_segmentation_map(self, predicted_semantic_maps):
# return visualize_segmentation_map(predicted_semantic_maps[0])
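

# Hedged usage sketch: "example.jpg" is a hypothetical local path, and "tiny"
# is chosen only to keep the checkpoint download small; any supported size
# works. `process` returns one (height, width) tensor of per-pixel class ids
# per input image.
if __name__ == "__main__":
    image = Image.open("example.jpg").convert("RGB")
    segmenter = Mask2Former(model_size="tiny")
    semantic_maps = segmenter.process([image])
    print(semantic_maps[0].shape)  # matches the (height, width) of the input
    # Map class ids to names with the checkpoint's label map, e.g.
    # segmenter.model.config.id2label; class id 0 here is only an example.
    example_mask = semantic_map_to_mask(semantic_maps[0], class_id=0)
    print(example_mask.float().mean())  # fraction of pixels assigned that class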