pratikshahp committed on
Commit f9d3b76 · verified
1 Parent(s): c8bf552

Update app.py

Files changed (1)
  1. app.py +36 -53
app.py CHANGED
@@ -1,60 +1,43 @@
- import os
- import io
- from PIL import Image,ImageDraw
- from transformers import AutoImageProcessor, AutoModelForObjectDetection
- import streamlit as st
  import torch
  import requests

- def input_image_setup(uploaded_file):
-     if uploaded_file is not None:
-         bytes_data = uploaded_file.getvalue()
-         image = Image.open(io.BytesIO(bytes_data))  # Convert bytes data to PIL image
-         return image
-     else:
-         raise FileNotFoundError("No file uploaded")
-
- #Streamlit App
- st.set_page_config(page_title="Image Detection")
- st.header("Object Detection Application")
- #Select your model
- models = ["facebook/detr-resnet-50","ciasimbaya/ObjectDetection","hustvl/yolos-tiny","microsoft/table-transformer-detection","valentinafeve/yolos-fashionpedia"]  # List of supported models
- model_name = st.selectbox("Select model", models)
- processor = AutoImageProcessor.from_pretrained(model_name)
- model = AutoModelForObjectDetection.from_pretrained(model_name)
- #Upload an image
- uploaded_file = st.file_uploader("choose an image...", type=["jpg","jpeg","png"])
- image=""
- if uploaded_file is not None:
-     image = Image.open(uploaded_file)
-     st.image(image, caption="Uploaded Image.", use_column_width=True)
- submit = st.button("Detect Objects ")
- if submit:
-     image_data = input_image_setup(uploaded_file)
-     st.subheader("The response is..")
-     inputs = processor(images=image, return_tensors="pt")
      outputs = model(**inputs)

-     logits = outputs.logits
-     bboxes = outputs.pred_boxes
-
-     target_sizes = torch.tensor([image.size[::-1]])
-     results = processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[0]

-     # Draw bounding boxes on the image
-     drawn_image = image.copy()
-     draw = ImageDraw.Draw(drawn_image)
-     for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
-         box = [int(i) for i in box.tolist()]
-         draw.rectangle(box, outline="red", width=2)
-         label_text = f"{model.config.id2label[label.item()]} ({round(score.item(), 2)})"
-         draw.text((box[0], box[1]), label_text, fill="red")
-
-     st.image(drawn_image, caption="Detected Objects", use_column_width=True)
-     st.subheader("List of Objects:")
      for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
-         box = [round(i, 2) for i in box.tolist()]
-         st.write(
-             f"Detected :orange[{model.config.id2label[label.item()]}] with confidence "
-             f":green[{round(score.item(), 3)}] at location :violet[{box}]"
-         )
+ from transformers import DetrImageProcessor, DetrForObjectDetection
+ from PIL import Image, ImageDraw
  import torch
+ import gradio as gr
  import requests
+ from io import BytesIO

+ # Load pre-trained DETR model
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
+
+ # COCO class index for "person" = 1 (used as proxy for face detection)
+ FACE_CLASS_INDEX = 1
+
+ def detect_faces(img: Image.Image):
+     # Prepare input for the model
+     inputs = processor(images=img, return_tensors="pt")
      outputs = model(**inputs)

+     # Get outputs
+     target_sizes = torch.tensor([img.size[::-1]])
+     results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]

+     # Draw bounding boxes
+     draw = ImageDraw.Draw(img)
      for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+         if label.item() == FACE_CLASS_INDEX:  # 'person'
+             box = [round(i, 2) for i in box.tolist()]
+             draw.rectangle(box, outline="green", width=3)
+             draw.text((box[0], box[1]), f"{score:.2f}", fill="green")
+
+     return img
+
+ # Gradio interface
+ iface = gr.Interface(
+     fn=detect_faces,
+     inputs=gr.Image(type="pil"),
+     outputs="image",
+     title="Face Detection App (Hugging Face + Gradio)",
+     description="Upload an image and detect faces using facebook/detr-resnet-50 model."
+ )
+
+ iface.launch()
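
Note: the new version imports requests and BytesIO but never calls them. A minimal sketch of how they could drive detect_faces outside the Gradio UI, assuming the updated app.py has been loaded so detect_faces is defined (the image URL is a placeholder, not part of the commit):

# Hypothetical smoke test; run after loading the new app.py.
from io import BytesIO
import requests
from PIL import Image

url = "https://example.com/sample.jpg"  # placeholder URL
img = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
annotated = detect_faces(img)  # draws green boxes for COCO class 1 ('person')
annotated.save("annotated.jpg")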