File size: 471 Bytes
5b2ab1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from .maskformer import MaskFormer, Mask2Former

# Registry of the segmentation models that ``create_segmentation_model`` can
# build: each key is the name callers pass in, each value is the class that
# gets instantiated for it.
SEGMENTATION_MODEL_DICT = dict(
    maskformer=MaskFormer,
    mask2former=Mask2Former,
)


class UnknownSegmentationModelError(ValueError, AssertionError):
    """Raised when ``create_segmentation_model`` receives an unregistered name.

    Inherits from ``ValueError`` (the conventional type for a bad argument
    value) and from ``AssertionError`` (the type this factory raised when it
    validated with an ``assert``), so existing ``except`` clauses keep working.
    """


def create_segmentation_model(segmentation_model_name: str, **kwargs):
    """Instantiate a registered segmentation model by name.

    Args:
        segmentation_model_name: Key into ``SEGMENTATION_MODEL_DICT``
            (currently ``"maskformer"`` or ``"mask2former"``).
        **kwargs: Forwarded unchanged to the selected model class's
            constructor.

    Returns:
        A new instance of the class registered under
        ``segmentation_model_name``.

    Raises:
        UnknownSegmentationModelError: If the name is not registered. This is
            raised even when Python runs with ``-O``, where an ``assert``
            would have been stripped.
    """
    # Only the registry lookup is guarded, so a KeyError raised inside a
    # model's constructor is not misreported as an unknown model name.
    try:
        model_cls = SEGMENTATION_MODEL_DICT[segmentation_model_name]
    except KeyError:
        raise UnknownSegmentationModelError(
            "Segmentation model name must be one of "
            + ", ".join(SEGMENTATION_MODEL_DICT)
            + f", got {segmentation_model_name!r}"
        ) from None
    return model_cls(**kwargs)