PengWeixuanSZU commited on
Commit
2c6d33e
·
verified ·
1 Parent(s): c098df2

Should work.....

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -177,7 +177,7 @@ def preprocess_for_removal(images, masks):
177
  out_masks.append(msk_resized)
178
  arr_images = np.stack(out_images)
179
  arr_masks = np.stack(out_masks)
180
- return torch.from_numpy(arr_images).half().to(device), torch.from_numpy(arr_masks).half().to(device)
181
 
182
  @spaces.GPU(duration=200)
183
  def inference_and_return_video(dilation_iterations, num_inference_steps, video_state):
@@ -189,7 +189,8 @@ def inference_and_return_video(dilation_iterations, num_inference_steps, video_s
189
  images = np.array(images)
190
  masks = np.array(masks)
191
  img_tensor, mask_tensor = preprocess_for_removal(images, masks)
192
- mask_tensor = mask_tensor[:,:,:,:1]
 
193
 
194
  if mask_tensor.shape[1] < mask_tensor.shape[2]:
195
  height = 480
 
177
  out_masks.append(msk_resized)
178
  arr_images = np.stack(out_images)
179
  arr_masks = np.stack(out_masks)
180
+ return torch.from_numpy(arr_images).half(), torch.from_numpy(arr_masks).half()
181
 
182
  @spaces.GPU(duration=200)
183
  def inference_and_return_video(dilation_iterations, num_inference_steps, video_state):
 
189
  images = np.array(images)
190
  masks = np.array(masks)
191
  img_tensor, mask_tensor = preprocess_for_removal(images, masks)
192
+ img_tensor=img_tensor.to("cuda")
193
+ mask_tensor = mask_tensor[:,:,:,:1].to("cuda")
194
 
195
  if mask_tensor.shape[1] < mask_tensor.shape[2]:
196
  height = 480