luke9705 committed on
Commit 5090fe0 · 1 Parent(s): 5a4500e

lack image generation

Files changed (1)
  1. app.py +57 -19
app.py CHANGED
@@ -1,5 +1,6 @@
import gradio as gr
import os
+ import base64
import pandas as pd
from PIL import Image
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, VisitWebpageTool, OpenAIServerModel, tool
@@ -11,12 +12,12 @@ from pathlib import Path
import openai

## utilty functions
- def is_image_extension(filename: str) -> bool: # not used in the code, but useful to have
+ def is_image_extension(filename: str) -> bool:
    IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp', '.svg'}
    ext = os.path.splitext(filename)[1].lower() # os.path.splitext(path) returns (root, ext)
    return ext in IMAGE_EXTS

- def load_file(path: list) -> dict:
+ def load_file(path: str) -> list | dict:
    """Based on the file extension, load the file into a suitable object."""

    image = None
@@ -24,7 +25,6 @@ def load_file(path: list) -> dict:
    csv = None
    text = None
    ext = Path(path).suffix.lower() # same as os.path.splitext(filename)[1].lower()
-     print(f"ext: {ext}")

    if ext.endswith(".png") or ext.endswith(".jpg") or ext.endswith(".jpeg"):
        image = Image.open(path).convert("RGB") # pillow object
@@ -35,11 +35,11 @@ def load_file(path: list) -> dict:
    elif ext.endswith(".py") or ext.endswith(".txt"):
        with open(path, 'r') as f:
            text = f.read() # plain text str
-     elif ext.endswith(".mp3") or ext.endswith(".wav"):
-         with open(path, 'wb') as f:
-             f.write("output.mp3") # binary data (leave it hardcoded for now)
-
-     return {"image" : image, "excel": excel, "csv": csv, "raw text": text}
+
+     if image is not None:
+         return [image]
+     else:
+         return {"excel": excel, "csv": csv, "raw text": text, "audio path": path}


## tools definition
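
With this change load_file returns a list of PIL images for image files and a dict for everything else, so callers can branch on the return type. A minimal sketch of that contract in use (file names are illustrative):

    result = load_file("storyboard.png")        # -> [<PIL.Image.Image>] for images
    if isinstance(result, list):
        answer = agent("describe this frame", images=result)
    else:                                        # e.g. "script.txt", "scenes.csv", "take1.mp3"
        answer = agent("summarize this file", files=result)  # {"excel": ..., "csv": ..., "raw text": ..., "audio path": ...}
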
@@ -69,15 +69,16 @@ def download_images(image_urls: str) -> list:
    return images

@tool # since they gave us OpenAI API credits, we can keep using it
- def transcribe_audio() -> str:
+ def transcribe_audio(audio_path: str) -> str:
    """
    Transcribe audio file using OpenAI Whisper API.
-     The path to the audio file is hardcoded as "output.mp3". Don't need to pass it as an argument.
+     Args:
+         audio_path (str): path to the audio file to be transcribed.
    Returns:
-         str: Transcription of the audio.
+         str : Transcription of the audio.
    """
    client = openai.Client(api_key=os.getenv("OPEN_AI_API_KEY"))
-     with open("output.mp3", "rb") as audio: # to modify path because it is arriving from gradio
+     with open(audio_path, "rb") as audio: # to modify path because it is arriving from gradio
        transcript = client.audio.transcriptions.create(
            file=audio,
            model="whisper-1",
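
For reference, a self-contained sketch of the Whisper call that transcribe_audio wraps, assuming OPEN_AI_API_KEY is set; the try/except mirrors the error handling visible in the next hunk:

    import os
    import openai

    def transcribe(audio_path: str) -> str | None:
        client = openai.Client(api_key=os.getenv("OPEN_AI_API_KEY"))
        try:
            with open(audio_path, "rb") as audio:
                transcript = client.audio.transcriptions.create(
                    file=audio,
                    model="whisper-1",
                )
            return transcript.text           # plain transcription text
        except Exception as e:
            print(f"Error transcribing audio: {e}")
            return None
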
@@ -89,6 +90,39 @@
    except Exception as e:
        print(f"Error transcribing audio: {e}")

+ @tool
+ def generate_image(prompt: str, neg_prompt: str) -> Image.Image:
+     """
+     Generate an image based on a text prompt using Flux Dev.
+     Args:
+         prompt (str): The text prompt to generate the image from.
+         neg_prompt (str): The negative prompt to avoid certain elements in the image.
+     Returns:
+         Image.Image: The generated image as a PIL Image object.
+     """
+     client = OpenAI(base_url="https://api.studio.nebius.com/v1",
+                     api_key=os.environ.get("NEBIUS_API_KEY"),
+                     )
+
+     completion = client.images.generate(
+         model="black-forest-labs/flux-dev",
+         prompt=prompt,
+         response_format="b64_json",
+         extra_body={
+             "response_extension": "png",
+             "width": 1024,
+             "height": 1024,
+             "num_inference_steps": 30,
+             "seed": -1,
+             "negative_prompt": neg_prompt,
+         }
+     )
+
+     image_data = base64.b64decode(completion.to_dict()['data'][0]['b64_json'])
+     image = Image.open(BytesIO(image_data))
+     return image
+
+

## agent definition
class Agent:
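
The new tool builds the request with the OpenAI client pointed at Nebius AI Studio and decodes the base64 payload with BytesIO; a self-contained sketch with those imports spelled out (assuming NEBIUS_API_KEY is set and black-forest-labs/flux-dev is enabled for the account; flux_image is a hypothetical wrapper name):

    import base64
    import os
    from io import BytesIO

    from openai import OpenAI
    from PIL import Image

    def flux_image(prompt: str, neg_prompt: str = "") -> Image.Image:
        client = OpenAI(base_url="https://api.studio.nebius.com/v1",
                        api_key=os.environ.get("NEBIUS_API_KEY"))
        result = client.images.generate(
            model="black-forest-labs/flux-dev",
            prompt=prompt,
            response_format="b64_json",
            extra_body={                     # provider-specific parameters passed through as-is
                "width": 1024,
                "height": 1024,
                "num_inference_steps": 30,
                "negative_prompt": neg_prompt,
            },
        )
        raw = base64.b64decode(result.data[0].b64_json)
        return Image.open(BytesIO(raw))

    # flux_image("rain-soaked neon alley, cinematic lighting").save("frame.png")
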
@@ -96,7 +130,7 @@ class Agent:
        client = HfApiModel("google/gemma-3-27b-it", provider="nebius", api_key=os.getenv("NEBIUS_API_KEY"))
        self.agent = CodeAgent(
            model=client,
-             tools=[DuckDuckGoSearchTool(max_results=5), VisitWebpageTool(max_output_length=20000), download_images, transcribe_audio],
+             tools=[DuckDuckGoSearchTool(max_results=5), VisitWebpageTool(max_output_length=20000), generate_image, download_images, transcribe_audio],
            additional_authorized_imports=["pandas", "PIL", "io"],
            planning_interval=1,
            max_steps=5,
@@ -105,21 +139,25 @@ class Agent:
        #print("System prompt:", self.agent.prompt_templates["system_prompt"])

    def __call__(self, message: str, images: Optional[list[Image.Image]] = None, files: Optional[str] = None) -> str:
-         answer = self.agent.run(message, additional_args={"images": images ,"files": files})
+         answer = self.agent.run(message, images = images, additional_args={"files": files})
        return answer

## gradio functions
def respond(message, history):

    text = message.get("text", "")
-     if not message.get("files"):
+     if not message.get("files"): # no files uploaded
        print("No files received.")
        message = agent(text)
    else:
        files = message.get("files", [])
        print(f"files received: {files}")
-         file = load_file(files[0])
-         message = agent(text, files=file)
+         if is_image_extension(files[0]):
+             image = load_file(files[0]) # assuming only one file is uploaded at a time (gradio default behavior)
+             message = agent(text, images=image)
+         else:
+             file = load_file(files[0])
+             message = agent(text, files=file)

    return message

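
For context, gr.ChatInterface(multimodal=True) hands respond() a dict of the typed text plus local paths of the uploaded files, which is what the new branch keys off; a sketch of the shape involved (paths are illustrative):

    message = {
        "text": "describe this storyboard frame",
        "files": ["/tmp/gradio/abc123/frame.png"],   # local copies of the uploads
    }
    # is_image_extension(files[0]) -> agent(text, images=load_file(path))
    # otherwise                    -> agent(text, files=load_file(path))
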
@@ -128,7 +166,7 @@ def initialize_agent():
    print("Agent initialized.")
    return agent

-
+ ## gradio interface
with gr.Blocks() as demo:
    global agent
    agent = initialize_agent()
@@ -136,7 +174,7 @@ with gr.Blocks() as demo:
        fn=respond,
        type='messages',
        multimodal=True,
-         title='MultiAgent_System_for_Screenplay_Creation_and_Editing',
+         title='MultiAgent System for Screenplay Creation and Editing',
        show_progress='full'
    )
 