File size: 4,576 Bytes
5b2ab1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from typing import Any, List, Optional, Tuple, Union
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms
from transformers import (
    AutoImageProcessor,
    Mask2FormerForUniversalSegmentation,
    MaskFormerImageProcessor,
    MaskFormerForInstanceSegmentation,
)


class MaskFormer:
    """MaskFormer semantic segmentation model.

    Args:
        model_size (str, optional):
            Size of the MaskFormer model. Defaults to "large".
    """

    def __init__(self, model_size: Optional[str] = "large") -> None:
        assert model_size in [
            "tiny",
            "base",
            "large",
        ], "Model size must be one of 'tiny', 'base', or 'large'"

        self.processor = MaskFormerImageProcessor.from_pretrained(
            f"facebook/maskformer-swin-{model_size}-ade"
        )
        self.model = MaskFormerForInstanceSegmentation.from_pretrained(
            f"facebook/maskformer-swin-{model_size}-ade"
        )

    def process(self, images: List[Image.Image]):
        inputs = self.processor(images=images, return_tensors="pt")
        outputs = self.model(**inputs)
        # model predicts class_queries_logits of shape `(batch_size, num_queries)`
        # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
        class_queries_logits = outputs.class_queries_logits
        masks_queries_logits = outputs.masks_queries_logits

        # you can pass them to processor for postprocessing
        # we refer to the demo notebooks for visualization (see "Resources" section in the MaskFormer docs)
        predicted_semantic_maps = self.processor.post_process_semantic_segmentation(
            outputs, target_sizes=[images[0].size[::-1] * len(images)]
        )

        return predicted_semantic_maps


class Mask2Former(MaskFormer):
    """Mask2Former semantic segmentation model.

    Args:
        model_size (str, optional):
            Size of the Mask2Former model. Defaults to "large".
    """

    def __init__(self, model_size: Optional[str] = "large") -> None:
        assert model_size in [
            "tiny",
            "base",
            "large",
        ], "Model size must be one of 'tiny', 'base', or 'large'"
        self.processor = AutoImageProcessor.from_pretrained(
            f"facebook/mask2former-swin-{model_size}-ade-semantic"
        )
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(
            f"facebook/mask2former-swin-{model_size}-ade-semantic"
        )


# class ADESegmentation:
#     def __init__(self, model_name: str):
#         self.processor = MODEL_DICT[model_name]["processor"].from_pretrained(
#             MODEL_DICT[model_name]["name"]
#         )
#         self.model = MODEL_DICT[model_name]["model"].from_pretrained(
#             MODEL_DICT[model_name]["name"]
#         )

#     def predict(self, image: Image.Image):
#         inputs = processor(images=image, return_tensors="pt")
#         outputs = model(**inputs)
#         # model predicts class_queries_logits of shape `(batch_size, num_queries)`
#         # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
#         class_queries_logits = outputs.class_queries_logits
#         masks_queries_logits = outputs.masks_queries_logits

#         # you can pass them to processor for postprocessing
#         # we refer to the demo notebooks for visualization (see "Resources" section in the MaskFormer docs)
#         predicted_semantic_maps = processor.post_process_semantic_segmentation(
#             outputs, target_sizes=[image.size[::-1]]
#         )

#         return predicted_semantic_maps

#     def get_mask(self, predicted_semantic_maps, class_id: int):
#         masks, labels, obj_names = get_masks_from_segmentation_map(
#             predicted_semantic_maps[0]
#         )

#         mask = masks[labels.index(ID)]
#         object_mask = np.logical_not(mask).astype(int)

#         mask = torch.Tensor(mask).repeat(3, 1, 1)
#         object_mask = torch.Tensor(object_mask).repeat(3, 1, 1)

#         return mask, object_mask

#     def get_PIL_mask(self, predicted_semantic_maps, class_id: int):
#         mask, object_mask = self.get_mask(predicted_semantic_maps[0], class_id=class_id)

#         mask = transforms.ToPILImage()(mask)
#         object_mask = transforms.ToPILImage()(object_mask)

#         return mask, object_mask

#     def get_PIL_segmentation_map(self, predicted_semantic_maps):
#         return visualize_segmentation_map(predicted_semantic_maps[0])