Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from src.model import ClipSegMultiClassModel | |
| from src.config import ClipSegMultiClassConfig | |
# === Load model ===
# Class names; list position i is treated as the name of class id i.
class_labels = ["background", "Pig", "Horse", "Sheep"]

# RGB color (0-255) for each class id, used to paint the predicted mask.
label2color = dict(
    enumerate(
        [
            [0, 0, 0],      # background -> black
            [255, 0, 0],    # Pig -> red
            [0, 255, 0],    # Horse -> green
            [0, 0, 255],    # Sheep -> blue
        ]
    )
)

# NOTE(review): this config object is built but never passed to
# from_pretrained() below, so it does not appear to influence the loaded
# model. Confirm whether it is still needed.
config = ClipSegMultiClassConfig(
    class_labels=class_labels,
    label2color=label2color,
    model="CIDAS/clipseg-rd64-refined",
)

# Load the fine-tuned checkpoint (presumably a Hugging Face Hub repo id) and
# switch it to evaluation mode before serving predictions.
model = ClipSegMultiClassModel.from_pretrained("BioMike/clipsegmulticlass_v1")
model.eval()
def colorize_mask(mask_tensor, label2color):
    """Convert an integer class-id mask into an RGB image.

    Args:
        mask_tensor: Mask of integer class ids, given as a ``torch.Tensor``
            (on any device) or anything ``np.asarray`` accepts. Size-1 axes,
            such as a batch or channel axis, are dropped. What remains must
            be 2-D ``(H, W)``.
        label2color: Mapping from class id to an ``[R, G, B]`` color in 0-255.

    Returns:
        ``np.ndarray`` of shape ``(H, W, 3)`` and dtype ``uint8``. Pixels whose
        class id is not a key of ``label2color`` stay black.

    Raises:
        ValueError: if the mask cannot be reduced to two dimensions.
    """
    if isinstance(mask_tensor, torch.Tensor):
        # detach() lets tensors that still track gradients be converted.
        mask = mask_tensor.detach().cpu().numpy()
    else:
        mask = np.asarray(mask_tensor)

    squeezed = np.squeeze(mask)
    if squeezed.ndim != 2:
        # np.squeeze also collapses a genuine height or width of 1. Fall back
        # to dropping only leading size-1 axes, such as a batch dimension.
        squeezed = mask
        while squeezed.ndim > 2 and squeezed.shape[0] == 1:
            squeezed = squeezed[0]
    if squeezed.ndim != 2:
        raise ValueError(
            f"expected a 2-D class-id mask, got shape {np.shape(mask)}"
        )

    color_mask = np.zeros((*squeezed.shape, 3), dtype=np.uint8)
    for class_id, color in label2color.items():
        color_mask[squeezed == class_id] = color
    return color_mask
def segment_with_legend(input_img):
    """Segment an image and return a matplotlib figure showing the result.

    The predicted class mask is colorized with ``label2color``, blended 50/50
    over the input image, and drawn with a legend naming every entry of
    ``class_labels``.

    Args:
        input_img: The image to segment: a file path, an ``np.ndarray``
            accepted by ``Image.fromarray``, or a ``PIL.Image.Image``. Any of
            these is converted to RGB before prediction.

    Returns:
        A ``matplotlib.figure.Figure`` containing the overlay and legend.

    Raises:
        gr.Error: if no image was provided (``input_img`` is ``None``).
    """
    # Build the figure without pyplot. Figures created with plt.subplots() stay
    # in pyplot's global registry until explicitly closed, which leaks one
    # figure per request in a long-running server.
    from matplotlib.figure import Figure

    if input_img is None:
        raise gr.Error("Please provide an image to segment.")
    if isinstance(input_img, str):
        # Context manager so the underlying file handle is released.
        with Image.open(input_img) as opened:
            input_img = opened.convert("RGB")
    elif isinstance(input_img, np.ndarray):
        input_img = Image.fromarray(input_img)
    # Normalize every input, including PIL images in RGBA/L/P mode, to RGB.
    # Image.blend below requires both images to have the same mode.
    input_img = input_img.convert("RGB")

    pred_mask = model.predict(input_img)
    color_mask = colorize_mask(pred_mask, label2color)
    mask_h, mask_w = color_mask.shape[:2]
    overlay = Image.blend(
        input_img.resize((mask_w, mask_h)),
        Image.fromarray(color_mask),
        alpha=0.5,
    )

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.imshow(overlay)
    ax.axis("off")

    # Look colors up by class id instead of relying on label2color's insertion
    # order matching class_labels. Ids without a color are shown black, which
    # matches how they appear in the mask.
    legend_handles = [
        plt.Line2D(
            [0], [0],
            marker="o",
            color="w",
            label=label,
            markerfacecolor=np.array(label2color.get(class_id, (0, 0, 0))) / 255.0,
            markersize=10,
        )
        for class_id, label in enumerate(class_labels)
    ]
    ax.legend(handles=legend_handles, loc="lower right", framealpha=0.8)
    return fig