thorfine committed on
Commit 2152d81 · verified · 1 Parent(s): 1106bc1

Update app.py

Files changed (1): app.py +35 -14
app.py CHANGED
@@ -1,30 +1,51 @@
 import gradio as gr
-from transformers import AutoProcessor, Blip2ForConditionalGeneration, BitsAndBytesConfig,Blip2Processor
+from transformers import Blip2Processor, Blip2ForConditionalGeneration, BitsAndBytesConfig
 from gtts import gTTS
 from tempfile import NamedTemporaryFile
 from PIL import Image
 import torch
-import os
-import torchaudio
 import whisper
 
-# Load BLIP-2 model
+# Set device
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-quant_config = BitsAndBytesConfig(load_in_8bit=True)
+# Load BLIP-2 model
 processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
-model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl", device_map="auto")
+model = Blip2ForConditionalGeneration.from_pretrained(
+    "Salesforce/blip2-flan-t5-xl", device_map="auto"
+).to(device)
 
-# Load Whisper model (turbo version)
+# Load Whisper model
 whisper_model = whisper.load_model("small")
 
-def transcribe(audio):
-    # Use Whisper for transcription
-    result = whisper_model.transcribe(audio)
+# Transcribe function
+def transcribe(audio_path):
+    result = whisper_model.transcribe(audio_path)
     return result["text"]
 
-from PIL import Image
-import torch
-from gtts import gTTS
-from tempfile import NamedTemporaryFile
+# Main function
+def ask_image(image, audio):
+    question = transcribe(audio)
+    inputs = processor(images=image, text=question, return_tensors="pt").to(device)
+    generated_ids = model.generate(**inputs)
+    answer = processor.decode(generated_ids[0], skip_special_tokens=True)
+
+    tts = gTTS(answer)
+    with NamedTemporaryFile(delete=False, suffix=".mp3") as f:
+        tts.save(f.name)
+        audio_out = f.name
+
+    return answer, audio_out
+
+# Gradio UI
+with gr.Blocks() as demo:
+    gr.Markdown("## 🎤🖼️ Ask-the-Image: Ask questions about an image using your voice")
+    image_input = gr.Image(type="pil", label="Upload an Image")
+    audio_input = gr.Audio(type="filepath", label="Ask a Question (voice)", microphone=True)
+    text_output = gr.Textbox(label="Answer")
+    audio_output = gr.Audio(label="Answer in Speech")
+
+    btn = gr.Button("Ask")
+    btn.click(fn=ask_image, inputs=[image_input, audio_input], outputs=[text_output, audio_output])
 
+demo.launch()
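
The rewrite keeps the BitsAndBytesConfig import even though the old quant_config line was dropped. A minimal sketch, not part of this commit, of how the 8-bit loading path could be wired back in through from_pretrained's quantization_config parameter:

# Sketch only (not in this commit): restore the old 8-bit loading,
# reusing the BitsAndBytesConfig import the new file still carries.
from transformers import Blip2ForConditionalGeneration, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-flan-t5-xl",
    quantization_config=quant_config,  # 8-bit weights via bitsandbytes
    device_map="auto",                 # needs the accelerate package
)
# With device_map="auto" the model is already placed on devices, so the
# extra .to(device) used in the new app.py is unnecessary for this path.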
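
For context, a minimal sketch of exercising the new ask_image pipeline without the Gradio UI. It assumes the definitions from the updated app.py are already in scope (importing app.py as-is would also call demo.launch(), which runs at top level); the file names here are hypothetical.

# Sketch only: drive the commit's ask_image() directly.
# "example.jpg" and "question.wav" are hypothetical test files.
from PIL import Image

image = Image.open("example.jpg")                      # image to ask about
answer, audio_path = ask_image(image, "question.wav")  # spoken question
print(answer)      # BLIP-2's text answer
print(audio_path)  # temp .mp3 holding the gTTS rendering of the answer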