Spaces:
Build error
Build error
from .maskformer import MaskFormer, Mask2Former | |
# Registry of the segmentation models this module can build: maps the
# public model name (a string key) to the class imported from .maskformer.
SEGMENTATION_MODEL_DICT = {
    "maskformer": MaskFormer,
    "mask2former": Mask2Former,
}
def create_segmentation_model(segmentation_model_name: str, **kwargs):
    """Instantiate the segmentation model registered under the given name.

    Args:
        segmentation_model_name: Key into ``SEGMENTATION_MODEL_DICT``.
        **kwargs: Forwarded unchanged to the registered entry when it is
            called.

    Returns:
        The result of calling the registered entry with ``**kwargs``.

    Raises:
        AssertionError: If ``segmentation_model_name`` is not a key of
            ``SEGMENTATION_MODEL_DICT``.
    """
    # Raise explicitly rather than via an ``assert`` statement: asserts are
    # removed when Python runs with -O, which would let an unknown name fall
    # through to an uninformative KeyError. AssertionError and the message
    # text are kept so existing callers observe the same failure.
    if segmentation_model_name not in SEGMENTATION_MODEL_DICT:
        raise AssertionError(
            "Segmentation model name must be one of "
            + ", ".join(SEGMENTATION_MODEL_DICT)
        )
    return SEGMENTATION_MODEL_DICT[segmentation_model_name](**kwargs)