Spaces:
Build error
Build error
File size: 471 Bytes
5b2ab1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from .maskformer import MaskFormer, Mask2Former
# Registry of constructible segmentation models, keyed by the public name
# accepted by create_segmentation_model().
SEGMENTATION_MODEL_DICT = dict(
    maskformer=MaskFormer,
    mask2former=Mask2Former,
)
class UnknownSegmentationModelError(ValueError, AssertionError):
    """Raised by ``create_segmentation_model`` for an unregistered model name.

    Inherits from ``ValueError`` (the conventional type for a bad argument)
    and from ``AssertionError`` (what the earlier ``assert``-based check
    raised), so existing ``except AssertionError`` handlers keep working.
    """


def create_segmentation_model(segmentation_model_name: str, **kwargs):
    """Instantiate a registered segmentation model by name.

    Args:
        segmentation_model_name: A key of ``SEGMENTATION_MODEL_DICT``
            (e.g. ``"maskformer"`` or ``"mask2former"``).
        **kwargs: Forwarded unchanged to the selected model class's
            constructor.

    Returns:
        A new instance of the class registered under
        ``segmentation_model_name``.

    Raises:
        UnknownSegmentationModelError: If the name is not registered. The
            check is an explicit ``raise`` rather than ``assert`` so it
            still runs when Python is started with ``-O``.
    """
    try:
        model_cls = SEGMENTATION_MODEL_DICT[segmentation_model_name]
    except KeyError:
        # ``from None`` suppresses the internal KeyError context; the message
        # already states the valid choices and the rejected value.
        raise UnknownSegmentationModelError(
            "Segmentation model name must be one of "
            + ", ".join(SEGMENTATION_MODEL_DICT)
            + f" (got {segmentation_model_name!r})"
        ) from None
    # Constructed outside the ``try`` so a KeyError raised inside the model's
    # own __init__ is not misreported as an unknown model name.
    return model_cls(**kwargs)
|