import time

import torch
import cv2
from PIL import Image, ImageDraw, ImageOps
import numpy as np
from typing import Union
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import matplotlib.pyplot as plt
import PIL

from .mask_painter import mask_painter


class BaseSegmenter:
    def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
        """
        SAM_checkpoint: path to the SAM checkpoint
        model_type: one of vit_b, vit_l, vit_h
        device: model device
        """
        print(f"Initializing BaseSegmenter to {device}")
        assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'

        self.device = device
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32

        # Build the SAM architecture with empty (meta) weights, then load the fp16
        # weights from the local safetensors file via mmgp's offload helper.
        from accelerate import init_empty_weights
        with init_empty_weights():
            self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
        from mmgp import offload
        offload.load_model_data(self.model, "ckpts/mask/sam_vit_h_4b8939_fp16.safetensors")
        self.model.to(torch.float32)
        self.model.to(device=self.device)
        self.predictor = SamPredictor(self.model)
        self.embedded = False

    @torch.no_grad()
    def set_image(self, image: np.ndarray):
        # image: RGB numpy array of shape (h, w, 3)
        self.original_image = image
        if self.embedded:
            print('Image already embedded; call reset_image() before setting a new one.')
            return
        self.predictor.set_image(image)
        self.embedded = True
        return

    @torch.no_grad()
    def reset_image(self):
        # Clear the cached image embedding so a new image can be set.
        self.predictor.reset_image()
        self.embedded = False

    def predict(self, prompts, mode, multimask=True):
        """
        prompts: dictionary with up to 3 keys: 'point_coords', 'point_labels', 'mask_input'
            prompts['point_coords']: numpy array [N, 2]
            prompts['point_labels']: numpy array [1, N]
            prompts['mask_input']: numpy array [1, 256, 256]
        mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
        multimask: True (return 3 masks), False (return 1 mask only)
        when multimask=True, a follow-up call can use mask_input = logits[np.argmax(scores), :, :][None, :, :]
        """
        assert self.embedded, 'prediction is called before set_image (feature embedding).'
        assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both'

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            if mode == 'point':
                masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
                                                               point_labels=prompts['point_labels'],
                                                               multimask_output=multimask)
            elif mode == 'mask':
                masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'],
                                                               multimask_output=multimask)
            elif mode == 'both':  # both point and mask prompts
                masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
                                                               point_labels=prompts['point_labels'],
                                                               mask_input=prompts['mask_input'],
                                                               multimask_output=multimask)
            else:
                raise NotImplementedError(f"mode '{mode}' is not implemented")

        return masks, scores, logits

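
# A minimal convenience sketch (hypothetical helper, not part of the original class):
# the demo below repeatedly picks the highest-scoring mask with
# masks[np.argmax(scores)]; wrapping that pattern keeps the calls readable.
def select_best_mask(masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
    """Return the highest-scoring mask as uint8 (assumes predict() output shapes)."""
    return masks[np.argmax(scores)].astype('uint8')
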
if __name__ == "__main__":
    # Load a test image and convert OpenCV's BGR to the RGB layout SAM expects.
    image = cv2.imread('/hhd3/gaoshang/truck.jpg')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
    model_type = 'vit_h'
    device = "cuda:4"
    base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
    base_segmenter.set_image(image)

    # 1) point prompts only
    mode = 'point'
    prompts = {
        'point_coords': np.array([[500, 375], [1125, 625]]),
        'point_labels': np.array([1, 1]),
    }
    masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False)
    painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
    painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # back to BGR for cv2.imwrite
    cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)

    # 2) point + mask prompts: reuse the best logits from the previous call as mask_input
    mode = 'both'
    mask_input = logits[np.argmax(scores), :, :]
    prompts = {
        'point_coords': np.array([[500, 375], [1125, 625]]),
        'point_labels': np.array([1, 0]),
        'mask_input': mask_input[None, :, :]
    }
    masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True)
    painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
    painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)
    cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)

    # 3) mask prompt only
    mode = 'mask'
    mask_input = logits[np.argmax(scores), :, :]
    prompts = {'mask_input': mask_input[None, :, :]}
    masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True)
    painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
    painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)
    cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)