cygon24 commited on
Commit
66be0d9
·
verified ·
1 Parent(s): 39ddfd0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import scipy.io.wavfile as wavfile


# Use a pipeline as a high-level helper
from transformers import pipeline

# used in local
# model_path = "../models/models--facebook--detr-resnet-50/snapshots/1d5f47bd3bdd2c4bbfa585418ffe6da5028b4c0b"

# tts_model_path = ("../Models/models--kakao-enterprise--vits-ljs/snapshots"
#                   "/3bcb8321394f671bd948ebf0d086d694dda95464")


# Text-to-speech pipeline (VITS, LJ Speech voice). The model is downloaded
# from the Hugging Face Hub on first run, so module import can take a while.
narrator = pipeline("text-to-speech",
                    model="kakao-enterprise/vits-ljs")

# Object-detection pipeline (DETR with a ResNet-50 backbone), also fetched
# from the Hub on first run.
object_detector = pipeline("object-detection",
                           model="facebook/detr-resnet-50")

# Local-snapshot alternatives to the Hub models above (kept for offline use).
# object_detector = pipeline("object-detection",
#                            model=model_path)
#
# narrator = pipeline("text-to-speech",
#                     model=tts_model_path)
28
# Define the function to generate audio from text
def generate_audio(text):
    """Synthesize *text* to speech and return the path of the WAV file.

    Uses the module-level ``narrator`` text-to-speech pipeline. The audio
    is written to a unique temporary file rather than a fixed ``output.wav``
    so that concurrent requests (e.g. several simultaneous Gradio users)
    do not overwrite each other's output.

    :param text: The sentence to narrate.
    :return: Filesystem path of the generated ``.wav`` file.
    """
    import tempfile

    # Run the TTS pipeline; the result dict carries the raw waveform
    # under "audio" and its rate under "sampling_rate".
    narrated_text = narrator(text)

    # Create a unique, persistent temp file (delete=False so Gradio can
    # read it after this function returns); close the handle before
    # writing so this also works on platforms with exclusive file locks.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()

    # narrated_text["audio"] is batched; take the first (only) waveform.
    wavfile.write(tmp.name, rate=narrated_text["sampling_rate"],
                  data=narrated_text["audio"][0])

    # Return the path to the saved audio file
    return tmp.name
39
+
40
+
41
def read_objects(detection_objects):
    """Build a natural-language sentence describing detected objects.

    :param detection_objects: Iterable of detection dicts, each with at
        least a ``'label'`` key (as produced by the object-detection
        pipeline).
    :return: A sentence such as ``"This picture contains 2 cats and 1 dog."``.
    """
    from collections import Counter

    # Count occurrences of each label. Counter preserves first-seen
    # (insertion) order, matching the original manual counting loop.
    object_counts = Counter(det['label'] for det in detection_objects)

    # Robustness: an empty detection list would otherwise produce the
    # broken sentence "This picture contains.".
    if not object_counts:
        return "This picture contains no recognizable objects."

    # Generate the response string
    response = "This picture contains"
    labels = list(object_counts)
    for i, label in enumerate(labels):
        count = object_counts[label]
        response += f" {count} {label}"
        # Naive pluralization: append "s" for counts above one
        # (e.g. "2 cats"); irregular plurals are not handled.
        if count > 1:
            response += "s"
        # Separate items with commas, with "and" before the final item.
        if i < len(labels) - 2:
            response += ","
        elif i == len(labels) - 2:
            response += " and"

    return response + "."
68
+
69
+
70
+
71
def draw_bounding_boxes(image, detections, font_path=None, font_size=20):
    """
    Render detection results onto a copy of *image*.

    :param image: PIL.Image object to annotate (the original is untouched).
    :param detections: List of detection dicts with 'score', 'label' and
                       'box' keys; 'box' maps 'xmin'/'ymin'/'xmax'/'ymax'
                       to pixel coordinates.
    :param font_path: Optional path to a TrueType font file for the labels.
    :param font_size: Point size used when *font_path* is given.
    :return: Annotated copy of the input image.
    """
    annotated = image.copy()
    canvas = ImageDraw.Draw(annotated)

    # A TTF font honours font_size; PIL's built-in default font has a
    # fixed size and ignores it (supply a TTF path to control sizing).
    if font_path:
        label_font = ImageFont.truetype(font_path, font_size)
    else:
        label_font = ImageFont.load_default()

    for det in detections:
        corners = det['box']
        left, top = corners['xmin'], corners['ymin']
        right, bottom = corners['xmax'], corners['ymax']

        # Bounding box around the detected object.
        canvas.rectangle([(left, top), (right, bottom)], outline="red", width=3)

        # Caption: label plus confidence score, e.g. "cat 0.98".
        caption = f"{det['label']} {det['score']:.2f}"

        # Measure the caption so a filled backdrop can be drawn behind it
        # for legibility against the image.
        if font_path:
            bbox = canvas.textbbox((left, top), caption, font=label_font)
        else:
            bbox = canvas.textbbox((left, top), caption)

        canvas.rectangle([(bbox[0], bbox[1]), (bbox[2], bbox[3])], fill="red")
        canvas.text((left, top), caption, fill="white", font=label_font)

    return annotated
120
+
121
+
122
def detect_object(image):
    """Run the full pipeline on one image: detect objects, draw their
    bounding boxes, and narrate a spoken description of what was found.

    :param image: Input PIL.Image.
    :return: Tuple of (annotated image, path to the generated WAV file).
    """
    detections = object_detector(image)
    annotated_image = draw_bounding_boxes(image, detections)
    description = read_objects(detections)
    audio_path = generate_audio(description)
    return annotated_image, audio_path
129
+
130
+
131
# Wire the pipeline into a Gradio UI: one image input, an annotated
# image plus the narrated audio as outputs.
demo = gr.Interface(fn=detect_object,
                    inputs=[gr.Image(label="Select Image",type="pil")],
                    outputs=[gr.Image(label="Processed Image", type="pil"), gr.Audio(label="Generated Audio")],
                    title="@cygon: Object Detector with Audio",
                    description="THIS APPLICATION WILL BE USED TO HIGHLIGHT OBJECTS AND GIVES AUDIO DESCRIPTION FOR THE PROVIDED INPUT IMAGE.")
# Start the web server (blocks until the app is stopped).
demo.launch()

# print(output)