from typing import List

from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms
from transformers import (
    AutoImageProcessor,
    Mask2FormerForUniversalSegmentation,
    MaskFormerImageProcessor,
    MaskFormerForInstanceSegmentation,
)
class MaskFormer:
    """MaskFormer semantic segmentation model.

    Args:
        model_size (str, optional):
            Size of the MaskFormer model. Defaults to "large".
    """

    def __init__(self, model_size: str = "large") -> None:
        assert model_size in [
            "tiny",
            "base",
            "large",
        ], "Model size must be one of 'tiny', 'base', or 'large'"
        self.processor = MaskFormerImageProcessor.from_pretrained(
            f"facebook/maskformer-swin-{model_size}-ade"
        )
        self.model = MaskFormerForInstanceSegmentation.from_pretrained(
            f"facebook/maskformer-swin-{model_size}-ade"
        )
    def process(self, images: List[Image.Image]) -> List[torch.Tensor]:
        """Segment a batch of images, returning one semantic map per image."""
        inputs = self.processor(images=images, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        # The model predicts class_queries_logits of shape
        # `(batch_size, num_queries, num_labels + 1)` and masks_queries_logits of
        # shape `(batch_size, num_queries, height, width)`; the processor merges
        # them into one `(height, width)` map of ADE20K class ids per image.
        # Note: PIL's `size` is (width, height), so it is reversed here, and one
        # target size is passed per input image.
        predicted_semantic_maps = self.processor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.size[::-1] for image in images]
        )
        return predicted_semantic_maps
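
# Example usage (a minimal sketch; the file names are hypothetical):
#
#   segmenter = MaskFormer(model_size="base")
#   images = [Image.open("room.jpg"), Image.open("street.jpg")]
#   maps = segmenter.process(images)  # one (height, width) map per image
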
class Mask2Former(MaskFormer):
    """Mask2Former semantic segmentation model.

    Args:
        model_size (str, optional):
            Size of the Mask2Former model. Defaults to "large".
    """

    def __init__(self, model_size: str = "large") -> None:
        assert model_size in [
            "tiny",
            "base",
            "large",
        ], "Model size must be one of 'tiny', 'base', or 'large'"
        self.processor = AutoImageProcessor.from_pretrained(
            f"facebook/mask2former-swin-{model_size}-ade-semantic"
        )
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(
            f"facebook/mask2former-swin-{model_size}-ade-semantic"
        )
# Disabled alternative implementation. It depends on `MODEL_DICT`,
# `get_masks_from_segmentation_map`, and `visualize_segmentation_map`, which
# are not defined in this file.
# class ADESegmentation:
#     def __init__(self, model_name: str):
#         self.processor = MODEL_DICT[model_name]["processor"].from_pretrained(
#             MODEL_DICT[model_name]["name"]
#         )
#         self.model = MODEL_DICT[model_name]["model"].from_pretrained(
#             MODEL_DICT[model_name]["name"]
#         )
#     def predict(self, image: Image.Image):
#         inputs = self.processor(images=image, return_tensors="pt")
#         outputs = self.model(**inputs)
#         # The processor turns the predicted class and mask logits into one
#         # semantic map per image.
#         predicted_semantic_maps = self.processor.post_process_semantic_segmentation(
#             outputs, target_sizes=[image.size[::-1]]
#         )
#         return predicted_semantic_maps
#     def get_mask(self, predicted_semantic_maps, class_id: int):
#         masks, labels, obj_names = get_masks_from_segmentation_map(
#             predicted_semantic_maps[0]
#         )
#         mask = masks[labels.index(class_id)]
#         object_mask = np.logical_not(mask).astype(int)
#         mask = torch.Tensor(mask).repeat(3, 1, 1)
#         object_mask = torch.Tensor(object_mask).repeat(3, 1, 1)
#         return mask, object_mask
#     def get_PIL_mask(self, predicted_semantic_maps, class_id: int):
#         mask, object_mask = self.get_mask(predicted_semantic_maps, class_id=class_id)
#         mask = transforms.ToPILImage()(mask)
#         object_mask = transforms.ToPILImage()(object_mask)
#         return mask, object_mask
#     def get_PIL_segmentation_map(self, predicted_semantic_maps):
#         return visualize_segmentation_map(predicted_semantic_maps[0])
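
# A minimal smoke test (an assumed sketch; "room.jpg" is a hypothetical local
# image path, not part of this repository):
if __name__ == "__main__":
    image = Image.open("room.jpg").convert("RGB")
    segmenter = Mask2Former(model_size="large")
    # One (height, width) tensor of ADE20K class ids per input image.
    semantic_map = segmenter.process([image])[0]
    print(semantic_map.shape, semantic_map.unique())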