naderasadi's picture
Initial commit
5b2ab1c
raw
history blame contribute delete
471 Bytes
from .maskformer import MaskFormer, Mask2Former
SEGMENTATION_MODEL_DICT = {
"maskformer": MaskFormer,
"mask2former": Mask2Former,
}
def create_segmentation_model(segmentation_model_name: str, **kwargs):
assert (
segmentation_model_name in SEGMENTATION_MODEL_DICT.keys()
), "Segmentation model name must be one of " + ", ".join(
SEGMENTATION_MODEL_DICT.keys()
)
return SEGMENTATION_MODEL_DICT[segmentation_model_name](**kwargs)