SameerArz committed
Commit 3afc8c8 · verified · 1 Parent(s): 60ed045

Update app.py

Files changed (1)
  1. app.py +47 -22
app.py CHANGED
@@ -1,28 +1,29 @@
 import gradio as gr
-from groq import Groq
 import os
 import threading
 import base64
 from io import BytesIO
+from groq import Groq

-# Initialize Groq client (No need for Mistral API)
-client = Groq(api_key=os.environ["GROQ_API_KEY"])
+# 🔹 Initialize Groq API Client (FREE)
+groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

-# Load Text-to-Image Models
+# 🔹 Load Text-to-Image Models (Restoring Multi-Image Generation)
 model1 = gr.load("models/prithivMLmods/SD3.5-Turbo-Realism-2.0-LoRA")
 model2 = gr.load("models/Purz/face-projection")
+model3 = gr.load("models/stablediffusion/stable-diffusion-xl")

-# Stop event for threading (image generation)
+# 🔹 Stop Event for Threading
 stop_event = threading.Event()

-# Convert PIL image to Base64
-def pil_to_base64(pil_image, image_format='jpeg'):
+# 🔹 Convert PIL image to Base64
+def pil_to_base64(pil_image, image_format="jpeg"):
     buffered = BytesIO()
     pil_image.save(buffered, format=image_format)
-    base64_string = base64.b64encode(buffered.getvalue()).decode('utf-8')
+    base64_string = base64.b64encode(buffered.getvalue()).decode("utf-8")
     return base64_string, image_format

-# Function for Visual Question Answering (Groq)
+# 🔹 Function for Visual Question Answering (VQA) with Mixtral-8x7B
 def answer_question(text, image, temperature=0.0, max_tokens=1024):
     base64_string, file_format = pil_to_base64(image)

@@ -36,8 +37,8 @@ def answer_question(text, image, temperature=0.0, max_tokens=1024):
         }
     ]

-    chat_response = client.chat.completions.create(
-        model="gemma2-9b-it",  # Groq model for vision tasks
+    chat_response = groq_client.chat.completions.create(
+        model="mixtral-8x7b-32768",
         messages=messages,
         temperature=temperature,
         max_tokens=max_tokens
@@ -45,18 +46,24 @@ def answer_question(text, image, temperature=0.0, max_tokens=1024):

     return chat_response.choices[0].message.content

+# 🔹 Function to Generate Three Images (Multi-Output)
+def generate_images(prompt):
+    stop_event.clear()
+    img1 = model1.predict(prompt)
+    img2 = model2.predict(prompt)
+    img3 = model3.predict(prompt)
+    return img1, img2, img3

-# Clear all fields
+# 🔹 Clear All Fields
 def clear_all():
-    return "", None, ""
-
+    return "", None, "", None, None, None

-# Set up the Gradio interface
+# 🔹 Set up Gradio Interface
 with gr.Blocks() as demo:
-    gr.Markdown("# 🎓 AI Tutor & Visual Learning Assistant")
+    gr.Markdown("# 🎓 AI Tutor, VQA & Image Generation")

-    # Section 3: Visual Question Answering (Groq)
-    gr.Markdown("## 🖼️ Visual Question Answering (Groq)")
+    # 🔹 Section 1: Visual Question Answering (Groq)
+    gr.Markdown("## 🖼️ Visual Question Answering (Mixtral-8x7B)")
     with gr.Row():
         with gr.Column(scale=2):
             question = gr.Textbox(placeholder="Ask about the image...", lines=2)
@@ -66,24 +73,42 @@ with gr.Blocks() as demo:
             max_tokens = gr.Slider(label="Max Tokens", minimum=128, maximum=2048, value=1024, step=128)

         with gr.Column(scale=3):
-            output_text = gr.Textbox(lines=10, label="Groq VQA Response")
+            output_text = gr.Textbox(lines=10, label="Mixtral VQA Response")

     with gr.Row():
         clear_btn = gr.Button("Clear", variant="secondary")
         submit_btn_vqa = gr.Button("Submit", variant="primary")

-    # VQA Processing
+    # 🔹 Section 2: Image Generation (3 Outputs)
+    gr.Markdown("## 🎨 AI-Generated Images (3 Variations)")
+    with gr.Row():
+        prompt = gr.Textbox(placeholder="Describe the image you want...", lines=2)
+        generate_btn = gr.Button("Generate Images", variant="primary")
+
+    with gr.Row():
+        image1 = gr.Image(label="Image 1")
+        image2 = gr.Image(label="Image 2")
+        image3 = gr.Image(label="Image 3")
+
+    # 🔹 VQA Processing
     submit_btn_vqa.click(
         fn=answer_question,
         inputs=[question, image, temperature, max_tokens],
         outputs=[output_text]
     )

-    # Clear VQA Inputs
+    # 🔹 Image Generation Processing
+    generate_btn.click(
+        fn=generate_images,
+        inputs=[prompt],
+        outputs=[image1, image2, image3]
+    )
+
+    # 🔹 Clear All Inputs
     clear_btn.click(
         fn=clear_all,
         inputs=[],
-        outputs=[question, image, output_text]
+        outputs=[question, image, output_text, image1, image2, image3]
     )

 if __name__ == "__main__":
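
The `pil_to_base64` helper is unchanged apart from quoting style, so its behaviour can be checked in isolation. A minimal round-trip sketch, assuming Pillow is installed; the solid-colour test image is illustrative only:

```python
# Minimal round-trip check of the pil_to_base64 helper shown in the diff above.
# Assumes Pillow is installed; the 64x64 solid-colour image is illustrative only.
import base64
from io import BytesIO
from PIL import Image

def pil_to_base64(pil_image, image_format="jpeg"):
    buffered = BytesIO()
    pil_image.save(buffered, format=image_format)
    base64_string = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return base64_string, image_format

img = Image.new("RGB", (64, 64), color="red")
b64, fmt = pil_to_base64(img)

# Decoding the string back should yield the same readable JPEG.
roundtrip = Image.open(BytesIO(base64.b64decode(b64)))
print(fmt, roundtrip.format, roundtrip.size)  # jpeg JPEG (64, 64)
```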
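The hunks do not show how `messages` is assembled (the lines between the first two hunks are outside the diff context), so the sketch below only illustrates the call shape after this commit: a question text paired with the base64 image as a data URL in the common OpenAI-compatible layout. The content structure is an assumption, not the file's actual code, and image parts like this are only accepted by vision-capable models; `mixtral-8x7b-32768` itself is text-only.

```python
# Hypothetical sketch of the request answer_question would send after this commit.
# The image content part (a data: URL) follows the common OpenAI-compatible layout
# and is an assumption -- the real messages construction in app.py is not shown in
# the diff. Image parts require a vision-capable model; mixtral-8x7b-32768 is text-only.
import json

base64_string, file_format = "<output of pil_to_base64>", "jpeg"

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this image?"},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/{file_format};base64,{base64_string}"},
            },
        ],
    }
]

# Inspect the assumed payload without calling the API.
print(json.dumps(messages, indent=2))

# The call itself, as updated in the diff:
# chat_response = groq_client.chat.completions.create(
#     model="mixtral-8x7b-32768",
#     messages=messages,
#     temperature=0.0,
#     max_tokens=1024,
# )
# print(chat_response.choices[0].message.content)
```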