import gc

import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry

# Try to import HF SAM support
try:
  from app_3rd.sam_utils.hf_sam_predictor import get_hf_sam_predictor, HFSamPredictor
  HF_AVAILABLE = True
except ImportError:
  HF_AVAILABLE = False

models = {
  'vit_b': 'app_3rd/sam_utils/checkpoints/sam_vit_b_01ec64.pth',
  'vit_l': 'app_3rd/sam_utils/checkpoints/sam_vit_l_0b3195.pth',
  'vit_h': 'app_3rd/sam_utils/checkpoints/sam_vit_h_4b8939.pth'
}
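# These are the official SAM release checkpoint files; when use_hf=False they
# are assumed to have been downloaded to the paths above ahead of time.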


def get_sam_predictor(model_type='vit_b', device=None, image=None, use_hf=True):
  """
  Get a SAM predictor, optionally backed by the HuggingFace implementation

  Args:
      model_type: Model type ('vit_b', 'vit_l', 'vit_h')
      device: Device to run on; auto-detected if None
      image: Optional image to set immediately
      use_hf: Whether to use HuggingFace SAM instead of the original SAM
  """
  if use_hf:
    if not HF_AVAILABLE:
      raise ImportError("HuggingFace SAM not available. Install transformers and huggingface_hub.")
    return get_hf_sam_predictor(model_type, device, image)
  
  # Original SAM: auto-detect the device when none is given
  if device is None:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

  # Build SAM from the local checkpoint and move it to the target device
  sam = sam_model_registry[model_type](checkpoint=models[model_type])
  sam = sam.to(device)

  predictor = SamPredictor(sam)
  if image is not None:
    predictor.set_image(image)
  return predictor
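
# Usage sketch (illustrative, assuming the vit_b checkpoint above exists and
# `image` is an H x W x 3 uint8 RGB array):
#
#   predictor = get_sam_predictor('vit_b', device='cpu', use_hf=False)
#   predictor.set_image(image)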


def run_inference(predictor, input_x, selected_points, multi_object: bool = False):
  """
  Run inference with either the original SAM or the HF SAM predictor

  Args:
      predictor: SamPredictor or HFSamPredictor instance
      input_x: Input image
      selected_points: List of (point, label) tuples
      multi_object: Whether to handle multiple objects (currently unused)
  """
  if len(selected_points) == 0:
    return []
  
  # Dispatch on predictor type. Guard with HF_AVAILABLE so the isinstance
  # check cannot raise a NameError when the optional HF import failed.
  if HF_AVAILABLE and isinstance(predictor, HFSamPredictor):
    return _run_hf_inference(predictor, input_x, selected_points, multi_object)
  return _run_original_inference(predictor, input_x, selected_points, multi_object)


def _run_original_inference(predictor: SamPredictor, input_x, selected_points, multi_object: bool = False):
  """Run inference with the original SAM predictor."""
  # Stack clicked points and labels into (N, 1, 2) and (N, 1) tensors
  points = torch.Tensor(
      [p for p, _ in selected_points]
  ).to(predictor.device).unsqueeze(1)

  labels = torch.Tensor(
      [int(l) for _, l in selected_points]
  ).to(predictor.device).unsqueeze(1)

  # Map point coordinates into the model's resized input frame
  transformed_points = predictor.transform.apply_coords_torch(
      points, input_x.shape[:2])

  # All N points are passed as a single prompt batch of shape (1, N, 2)
  masks, scores, logits = predictor.predict_torch(
    point_coords=transformed_points[:, 0][None],
    point_labels=labels[:, 0][None],
    multimask_output=False,
  )
  masks = masks[0].cpu().numpy()  # (1, H, W): one mask for the combined points

  gc.collect()
  torch.cuda.empty_cache()

  return [(masks, 'final_mask')]


def _run_hf_inference(predictor: HFSamPredictor, input_x, selected_points, multi_object: bool = False):
  """Run inference with HF SAM."""
  # Wrap points and labels in a leading batch dimension, as the HF
  # processor expects nested lists
  select_pts = [[list(p) for p, _ in selected_points]]
  select_lbls = [[int(l) for _, l in selected_points]]
  
  # Preprocess inputs
  inputs = predictor.preprocess(input_x, select_pts, select_lbls)

  # Run inference
  with torch.no_grad():
    outputs = predictor.model(**inputs)
  
  # Upscale the predicted masks back to the original image resolution
  masks = predictor.processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
  )
  masks = masks[0][:, :1, ...].cpu().numpy()  # keep the first mask per prompt

  gc.collect()
  torch.cuda.empty_cache()

  return [(masks, 'final_mask')]
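

if __name__ == '__main__':
  # Minimal smoke test (sketch): segment a random image with one foreground
  # click. This loads real model weights (a local checkpoint, or a download
  # when the HF path is used), so it is illustrative rather than a unit test.
  dummy = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
  predictor = get_sam_predictor('vit_b', device='cpu', image=dummy,
                                use_hf=HF_AVAILABLE)
  results = run_inference(predictor, dummy, [((128, 128), 1)])
  for masks, name in results:
    print(name, masks.shape)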