Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,641 Bytes
a51c6d2 4138a21 a51c6d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import gc
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry
# Try to import HF SAM support. The dependency is optional: when it is
# missing, placeholders are defined so this module still imports and the
# original-SAM code path keeps working.
try:
    from app_3rd.sam_utils.hf_sam_predictor import get_hf_sam_predictor, HFSamPredictor
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False
    # Placeholders so module-level references to these names (function
    # annotations evaluated at def time, isinstance checks) do not raise
    # NameError. get_sam_predictor checks HF_AVAILABLE before calling
    # get_hf_sam_predictor, so the None is never called.
    get_hf_sam_predictor = None

    class HFSamPredictor:
        """Stand-in used only when the HF SAM predictor cannot be imported.

        Nothing in this module instantiates it; it exists so that
        ``isinstance(x, HFSamPredictor)`` is simply False for every predictor.
        """
# Local checkpoint path for each original-SAM backbone, keyed by the
# model_type accepted by get_sam_predictor. Paths are relative, so they
# resolve against the process's current working directory.
models = dict(
    vit_b='app_3rd/sam_utils/checkpoints/sam_vit_b_01ec64.pth',
    vit_l='app_3rd/sam_utils/checkpoints/sam_vit_l_0b3195.pth',
    vit_h='app_3rd/sam_utils/checkpoints/sam_vit_h_4b8939.pth',
)
def get_sam_predictor(model_type='vit_b', device=None, image=None, use_hf=True):
    """Build a SAM predictor, either the HuggingFace port or the original.

    Args:
        model_type: Backbone name ('vit_b', 'vit_l', 'vit_h'); a key of
            ``models`` for the original-SAM path.
        device: Target device; defaults to 'cuda' when available, else 'cpu'
            (on the original-SAM path; the HF path receives it unchanged).
        image: If given, passed to ``set_image`` before returning.
        use_hf: Use the HuggingFace SAM predictor instead of original SAM.

    Returns:
        The predictor returned by ``get_hf_sam_predictor`` (HF path) or a
        ``SamPredictor`` wrapping the loaded checkpoint.

    Raises:
        ImportError: ``use_hf`` is True but the HF predictor failed to import.
    """
    if use_hf:
        if HF_AVAILABLE:
            return get_hf_sam_predictor(model_type, device, image)
        raise ImportError("HuggingFace SAM not available. Install transformers and huggingface_hub.")

    # Original SAM path: resolve the device lazily, only when not supplied.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    build_sam = sam_model_registry[model_type]
    sam_model = build_sam(checkpoint=models[model_type]).to(device)
    predictor = SamPredictor(sam_model)

    if image is not None:
        predictor.set_image(image)
    return predictor
def run_inference(predictor, input_x, selected_points, multi_object: bool = False):
    """Run point-prompted inference with either an original or an HF SAM predictor.

    Args:
        predictor: SamPredictor or HFSamPredictor instance.
        input_x: Input image; ``input_x.shape[:2]`` is used as its size.
        selected_points: List of ``(point, label)`` tuples.
        multi_object: Passed through to the backend helpers (currently unused
            by both of them).

    Returns:
        ``[]`` when no points are selected, otherwise the list returned by
        the backend-specific helper.
    """
    if len(selected_points) == 0:
        return []
    # HFSamPredictor is only a real class when the optional HF import
    # succeeded; checking HF_AVAILABLE first avoids a NameError otherwise.
    if HF_AVAILABLE and isinstance(predictor, HFSamPredictor):
        return _run_hf_inference(predictor, input_x, selected_points, multi_object)
    return _run_original_inference(predictor, input_x, selected_points, multi_object)
def _run_original_inference(predictor: SamPredictor, input_x, selected_points, multi_object: bool = False):
    """Run inference with original SAM.

    Every selected point goes into a single prompt (batch of one), and
    ``multimask_output=False`` requests one mask for that prompt.

    Returns:
        ``[(masks, 'final_mask')]`` where ``masks`` is a NumPy array taken
        from the first prompt of ``predict_torch``'s output.
    """
    target_device = predictor.device

    # (N, 2) point coordinates / (N,) labels, each given a singleton axis 1.
    point_tensor = torch.Tensor(
        [point for point, _ in selected_points]
    ).to(target_device).unsqueeze(1)
    label_tensor = torch.Tensor(
        [int(label) for _, label in selected_points]
    ).to(target_device).unsqueeze(1)

    image_hw = input_x.shape[:2]
    scaled_points = predictor.transform.apply_coords_torch(point_tensor, image_hw)

    # Drop the singleton axis again and add a leading batch axis of 1, so all
    # N points form one prompt.
    prompt_coords = scaled_points[:, 0].unsqueeze(0)
    prompt_labels = label_tensor[:, 0].unsqueeze(0)

    mask_batch, _scores, _logits = predictor.predict_torch(
        point_coords=prompt_coords,
        point_labels=prompt_labels,
        multimask_output=False,
    )
    # NOTE(review): assuming predict_torch's documented (B, C, H, W) output,
    # [0] selects the single prompt, leaving (1, H, W) here -- not
    # "N 1 H W" as an earlier comment claimed. Confirm against callers.
    final_masks = mask_batch[0].cpu().numpy()

    gc.collect()
    torch.cuda.empty_cache()
    return [(final_masks, 'final_mask')]
def _run_hf_inference(predictor: HFSamPredictor, input_x, selected_points, multi_object: bool = False):
"""Run inference with HF SAM"""
# Prepare points and labels for HF SAM
select_pts = [[list(p) for p, _ in selected_points]]
select_lbls = [[int(l) for _, l in selected_points]]
# Preprocess inputs
inputs = predictor.preprocess(input_x, select_pts, select_lbls)
# Run inference
with torch.no_grad():
outputs = predictor.model(**inputs)
# Post-process masks
masks = predictor.processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu(),
)
masks = masks[0][:,:1,...].cpu().numpy()
gc.collect()
torch.cuda.empty_cache()
return [(masks, 'final_mask')] |