mubashirhussaindev commited on
Commit
f31b870
·
verified ·
1 Parent(s): 54f75bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -13
app.py CHANGED
@@ -5,13 +5,19 @@ from transformers import AutoModelForImageSegmentation
5
  import torch
6
  from torchvision import transforms
7
 
8
- torch.set_float32_matmul_precision(["high", "highest"][0])
 
9
 
 
 
 
 
10
  birefnet = AutoModelForImageSegmentation.from_pretrained(
11
  "ZhengPeng7/BiRefNet", trust_remote_code=True
12
  )
13
- birefnet.to("cpu")
14
 
 
15
  transform_image = transforms.Compose(
16
  [
17
  transforms.Resize((1024, 1024)),
@@ -20,18 +26,10 @@ transform_image = transforms.Compose(
20
  ]
21
  )
22
 
23
- def fn(image):
24
- im = load_img(image, output_type="pil")
25
- im = im.convert("RGB")
26
- origin = im.copy()
27
- processed_image = process(im)
28
- return (processed_image, origin)
29
-
30
- @spaces.GPU
31
  def process(image):
32
  image_size = image.size
33
- input_images = transform_image(image).unsqueeze(0).to("cuda")
34
- # Prediction
35
  with torch.no_grad():
36
  preds = birefnet(input_images)[-1].sigmoid().cpu()
37
  pred = preds[0].squeeze()
@@ -40,6 +38,15 @@ def process(image):
40
  image.putalpha(mask)
41
  return image
42
 
 
 
 
 
 
 
 
 
 
43
  def process_file(f):
44
  name_path = f.rsplit(".", 1)[0] + ".png"
45
  im = load_img(f, output_type="pil")
@@ -48,6 +55,7 @@ def process_file(f):
48
  transparent.save(name_path)
49
  return name_path
50
 
 
51
  slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
52
  slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
53
  image_upload = gr.Image(label="Upload an image")
@@ -59,13 +67,15 @@ output_file = gr.File(label="Output PNG File")
59
  chameleon = load_img("butterfly.jpg", output_type="pil")
60
  url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
61
 
 
62
  tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
63
  tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
64
  tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
65
 
 
66
  demo = gr.TabbedInterface(
67
  [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
68
  )
69
 
70
  if __name__ == "__main__":
71
- demo.launch(show_error=True)
 
5
  import torch
6
  from torchvision import transforms
7
 
8
+ # Set float32 matmul precision (used for performance tuning)
9
+ torch.set_float32_matmul_precision("high")
10
 
11
+ # Detect device (GPU if available, otherwise CPU)
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+ # Load model and send to appropriate device
15
  birefnet = AutoModelForImageSegmentation.from_pretrained(
16
  "ZhengPeng7/BiRefNet", trust_remote_code=True
17
  )
18
+ birefnet.to(device)
19
 
20
+ # Image transformation pipeline
21
  transform_image = transforms.Compose(
22
  [
23
  transforms.Resize((1024, 1024)),
 
26
  ]
27
  )
28
 
29
+ # Background removal pipeline
 
 
 
 
 
 
 
30
  def process(image):
31
  image_size = image.size
32
+ input_images = transform_image(image).unsqueeze(0).to(device)
 
33
  with torch.no_grad():
34
  preds = birefnet(input_images)[-1].sigmoid().cpu()
35
  pred = preds[0].squeeze()
 
38
  image.putalpha(mask)
39
  return image
40
 
41
+ # Gradio interface function
42
+ def fn(image):
43
+ im = load_img(image, output_type="pil")
44
+ im = im.convert("RGB")
45
+ origin = im.copy()
46
+ processed_image = process(im)
47
+ return (processed_image, origin)
48
+
49
+ # Process uploaded file and return PNG with alpha channel
50
  def process_file(f):
51
  name_path = f.rsplit(".", 1)[0] + ".png"
52
  im = load_img(f, output_type="pil")
 
55
  transparent.save(name_path)
56
  return name_path
57
 
58
+ # Gradio components
59
  slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
60
  slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
61
  image_upload = gr.Image(label="Upload an image")
 
67
  chameleon = load_img("butterfly.jpg", output_type="pil")
68
  url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
69
 
70
+ # Gradio interface tabs
71
  tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
72
  tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
73
  tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
74
 
75
+ # Launch app
76
  demo = gr.TabbedInterface(
77
  [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
78
  )
79
 
80
  if __name__ == "__main__":
81
+ demo.launch(show_error=True)