luke9705 committed
Commit c1164ec · 1 Parent(s): b107f85

Refactor load_file function to return paths for images and audio; add caption_image tool for generating image descriptions using Gemma3

Files changed (1):
  1. app.py +40 -16
app.py CHANGED
@@ -24,12 +24,11 @@ def is_image_extension(filename: str) -> bool:
 def load_file(path: str) -> list | dict:
     """Based on the file extension, load the file into a suitable object."""
 
-    image = None
     text = None
     ext = Path(path).suffix.lower() # same as os.path.splitext(filename)[1].lower()
 
     if ext.endswith(".png") or ext.endswith(".jpg") or ext.endswith(".jpeg"):
-        image = Image.open(path).convert("RGB") # pillow object
+        return {"image path": path}
     elif ext.endswith(".xlsx") or ext.endswith(".xls"):
         text = pd.read_excel(path) # DataFrame
     elif ext.endswith(".csv"):
@@ -39,10 +38,7 @@ def load_file(path: str) -> list | dict:
         text = "\n".join(page.extract_text() for page in pdf.pages if page.extract_text())
     elif ext.endswith(".py") or ext.endswith(".txt"):
         with open(path, 'r') as f:
-            text = f.read() # plain text str
-
-    if image is not None:
-        return [image]
+            text = f.read() # plain text str
     elif ext.endswith(".mp3") or ext.endswith(".wav"):
         return {"audio path": path}
     else:
@@ -197,6 +193,38 @@ def generate_audio_from_sample(prompt: str, duration: int, sample_path: str = No
     sound = client(prompt, duration, sample_path)
 
     return gr.Audio(value=sound)
+
+@tool
+def caption_image(img_path: str, prompt: str) -> str:
+    """
+    Generate a caption for an image at the given path using Gemma3.
+    Args:
+        img_path: The file path to the image to be captioned.
+        prompt: A text prompt describing what you want the model to focus on or ask about the image.
+    Returns:
+        str: A description of the image.
+    """
+    client_2 = HfApiModel("google/gemma-3-27b-it",
+                          provider="nebius",
+                          api_key=os.getenv("NEBIUS_API_KEY"))
+
+    with open(img_path, "rb") as f:
+        encoded = base64.b64encode(f.read()).decode("utf-8")
+    data_uri = f"data:image/jpeg;base64,{encoded}"
+    messages = [{"role": "user", "content": [
+        {
+            "type": "text",
+            "text": prompt,
+        },
+        {
+            "type": "image_url",
+            "image_url": {
+                "url": data_uri
+            }
+        }
+    ]}]
+    resp = client_2(messages)
+    return resp.content
 
 
 ## agent definition
@@ -204,6 +232,7 @@ class Agent:
     def __init__(self, ):
         #client = HfApiModel("deepseek-ai/DeepSeek-R1-0528", provider="nebius", api_key=os.getenv("NEBIUS_API_KEY"))
         client = HfApiModel("Qwen/Qwen3-32B", provider="nebius", api_key=os.getenv("NEBIUS_API_KEY"))
+
         """client = OpenAIServerModel(
             model_id="claude-opus-4-20250514",
             api_base="https://api.anthropic.com/v1/",
@@ -216,6 +245,7 @@ class Agent:
                 generate_image,
                 generate_audio_from_sample,
                 generate_audio,
+                caption_image,
                 download_images,
                 transcribe_audio],
             additional_authorized_imports=["pandas", "PIL", "io"],
@@ -237,6 +267,7 @@ class Agent:
         answer = self.agent.run(message, images = images, additional_args={"files": files, "conversation_history": conversation_history})
         return answer
 
+
 ## gradio functions
 def respond(message: str, history : dict, web_search: bool = False):
 
@@ -251,14 +282,7 @@ def respond(message: str, history : dict, web_search: bool = False):
         message = agent(text, conversation_history=history)
     else:
         files = message.get("files", [])
-        print(f"files received: {files}")
-        if is_image_extension(files[0]) and not web_search:
-            image = load_file(files[0]) # assuming only one file is uploaded at a time (gradio default behavior)
-            message = agent(text + "\nADDITIONAL CONTRAINT: Don't use web search", images=image, conversation_history=history)
-        elif is_image_extension(files[0]) and web_search:
-            image = load_file(files[0])
-            message = agent(text, images=image, conversation_history=history)
-        elif not web_search:
+        if not web_search:
             file = load_file(files[0])
             message = agent(text + "\nADDITIONAL CONTRAINT: Don't use web search", files=file, conversation_history=history)
         else:
@@ -276,7 +300,6 @@ def initialize_agent():
     return agent
 
 ## gradio interface
-
 global agent
 agent = initialize_agent()
 demo = gr.ChatInterface(
@@ -289,13 +312,14 @@ demo = gr.ChatInterface(
     fill_width=True,
     save_history=True,
     autoscroll=True,
+    #css = css_snippet,
     additional_inputs=[
         gr.Checkbox(value=False, label="Web Search",
                     info="Enable web search to find information online. If disabled, the agent will only use the provided files and images.",
                     render=False),
     ],
     additional_inputs_accordion=gr.Accordion(label="Tools available: ", open=True, render=False)
-)
+).queue()
 
 
 if __name__ == "__main__":
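
For context on the load_file half of the commit: images are no longer decoded into PIL objects inside load_file; like audio, they now come back as a path wrapped in a small dict, and decoding is deferred to whichever tool actually needs the bytes. A minimal sketch of the resulting contract (describe_payload is a hypothetical helper written for illustration, not part of app.py):

def describe_payload(payload) -> str:
    # Hypothetical helper illustrating load_file's post-commit return shapes.
    if isinstance(payload, dict) and "image path" in payload:
        return f"image file at {payload['image path']}"  # ready for caption_image
    if isinstance(payload, dict) and "audio path" in payload:
        return f"audio file at {payload['audio path']}"  # ready for transcribe_audio
    return "inline textual payload (the xlsx/csv/pdf/py/txt branches)"

# e.g. describe_payload(load_file("photo.png")) -> "image file at photo.png"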
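And a usage sketch for the new caption_image tool, assuming NEBIUS_API_KEY is set in the Space's environment and that smolagents @tool objects remain directly callable (they forward to the wrapped function); "photo.jpg" is a placeholder path:

# Hypothetical direct call, outside the agent loop.
caption = caption_image(
    img_path="photo.jpg",  # placeholder; any local image the Space can read
    prompt="Describe the main subject and setting in two sentences.",
)
print(caption)  # Gemma3's description, i.e. resp.content from the tool

One detail worth noting: the tool labels every upload as image/jpeg in the data URI regardless of the actual extension. Most vision backends ignore the declared subtype and sniff the bytes, but strictly speaking PNGs are mislabeled; that simplification is left in place by this commit.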