luke9705 committed on
Commit
f1e2fa3
·
1 Parent(s): 56f7f57

Add tool for generating audio from a sample; update agent initialization and Gradio interface structure; swap back to Gemma for fast testing

Browse files
Files changed (1) hide show
  1. app.py +33 -25
app.py CHANGED
@@ -44,7 +44,7 @@ def load_file(path: str) -> list | dict:
44
  if image is not None:
45
  return [image]
46
  elif ext.endswith(".mp3") or ext.endswith(".wav"):
47
- return {"audio": text, "audio path": path}
48
  else:
49
  return {"raw document text": text, "file path": path}
50
 
@@ -157,7 +157,6 @@ def generate_audio(prompt: str, duration: int) -> gr.Component:
157
  Args:
158
  prompt: The text prompt to generate the audio from.
159
  duration: Duration of the generated audio in seconds. Max 30 seconds.
160
-
161
  Returns:
162
  gr.Component: The generated audio as a Gradio Audio component.
163
  """
@@ -167,18 +166,21 @@ def generate_audio(prompt: str, duration: int) -> gr.Component:
167
  name="Sound_Generator",
168
  description="Generate music or sound effects from a text prompt using MusicGen."
169
  )
170
- sound = client(prompt, duration)
 
 
 
171
 
172
  return gr.Audio(value=sound)
173
 
174
  @tool
175
- def generate_audio_from_sample(prompt: str, duration: int, sample: list[int, np.ndarray] = None) -> gr.Component:
176
  """
177
  Generate audio from a text prompt + audio sample using MusicGen.
178
  Args:
179
  prompt: The text prompt to generate the audio from.
180
  duration: Duration of the generated audio in seconds. Max 30 seconds.
181
- sample: Optional audio sample to guide generation.
182
 
183
  Returns:
184
  gr.Component: The generated audio as a Gradio Audio component.
@@ -189,21 +191,24 @@ def generate_audio_from_sample(prompt: str, duration: int, sample: list[int, np.
189
  name="Sound_Generator",
190
  description="Generate music or sound effects from a text prompt using MusicGen."
191
  )
192
- sound = client(prompt, duration, sample)
 
 
 
193
 
194
  return gr.Audio(value=sound)
195
 
196
 
197
-
198
  ## agent definition
199
  class Agent:
200
  def __init__(self, ):
201
  #client = HfApiModel("deepseek-ai/DeepSeek-R1-0528", provider="nebius", api_key=os.getenv("NEBIUS_API_KEY"))
202
- client = OpenAIServerModel(
 
203
  model_id="claude-opus-4-20250514",
204
  api_base="https://api.anthropic.com/v1/",
205
  api_key=os.environ["ANTHROPIC_API_KEY"],
206
- )
207
  self.agent = CodeAgent(
208
  model=client,
209
  tools=[DuckDuckGoSearchTool(max_results=5),
@@ -271,23 +276,26 @@ def initialize_agent():
271
  return agent
272
 
273
  ## gradio interface
274
- with gr.Blocks() as demo:
275
- global agent
276
- agent = initialize_agent()
277
- gr.ChatInterface(
278
- fn=respond,
279
- type='messages',
280
- multimodal=True,
281
- title='MultiAgent System for Screenplay Creation and Editing',
282
- show_progress='full',
283
- fill_height=True,
284
- fill_width=True,
285
- save_history=True,
286
- additional_inputs=[
 
287
  gr.Checkbox(value=False, label="Web Search",
288
- info="Enable web search to find information online. If disabled, the agent will only use the provided files and images.",
289
- render=False),
290
- ])
 
 
291
 
292
 
293
  if __name__ == "__main__":
 
44
  if image is not None:
45
  return [image]
46
  elif ext.endswith(".mp3") or ext.endswith(".wav"):
47
+ return {"audio path": path}
48
  else:
49
  return {"raw document text": text, "file path": path}
50
 
 
157
  Args:
158
  prompt: The text prompt to generate the audio from.
159
  duration: Duration of the generated audio in seconds. Max 30 seconds.
 
160
  Returns:
161
  gr.Component: The generated audio as a Gradio Audio component.
162
  """
 
166
  name="Sound_Generator",
167
  description="Generate music or sound effects from a text prompt using MusicGen."
168
  )
169
+ if duration > 30:
170
+ sound = client(prompt, 30)
171
+ else:
172
+ sound = client(prompt, duration)
173
 
174
  return gr.Audio(value=sound)
175
 
176
  @tool
177
+ def generate_audio_from_sample(prompt: str, duration: int, sample_path: str = None) -> gr.Component:
178
  """
179
  Generate audio from a text prompt + audio sample using MusicGen.
180
  Args:
181
  prompt: The text prompt to generate the audio from.
182
  duration: Duration of the generated audio in seconds. Max 30 seconds.
183
+ sample_path: audio sample path to guide generation.
184
 
185
  Returns:
186
  gr.Component: The generated audio as a Gradio Audio component.
 
191
  name="Sound_Generator",
192
  description="Generate music or sound effects from a text prompt using MusicGen."
193
  )
194
+ if duration > 30:
195
+ sound = client(prompt, 30, sample_path)
196
+ else:
197
+ sound = client(prompt, duration, sample_path)
198
 
199
  return gr.Audio(value=sound)
200
 
201
 
 
202
  ## agent definition
203
  class Agent:
204
  def __init__(self, ):
205
  #client = HfApiModel("deepseek-ai/DeepSeek-R1-0528", provider="nebius", api_key=os.getenv("NEBIUS_API_KEY"))
206
+ client = HfApiModel("google/gemma-3-27b-it", provider="nebius", api_key=os.getenv("NEBIUS_API_KEY"))
207
+ """client = OpenAIServerModel(
208
  model_id="claude-opus-4-20250514",
209
  api_base="https://api.anthropic.com/v1/",
210
  api_key=os.environ["ANTHROPIC_API_KEY"],
211
+ )"""
212
  self.agent = CodeAgent(
213
  model=client,
214
  tools=[DuckDuckGoSearchTool(max_results=5),
 
276
  return agent
277
 
278
  ## gradio interface
279
+
280
+ global agent
281
+ agent = initialize_agent()
282
+ demo = gr.ChatInterface(
283
+ fn=respond,
284
+ type='messages',
285
+ multimodal=True,
286
+ title='MultiAgent System for Screenplay Creation and Editing',
287
+ show_progress='full',
288
+ fill_height=True,
289
+ fill_width=True,
290
+ save_history=True,
291
+ autoscroll=True,
292
+ additional_inputs=[
293
  gr.Checkbox(value=False, label="Web Search",
294
+ info="Enable web search to find information online. If disabled, the agent will only use the provided files and images.",
295
+ render=False),
296
+ ],
297
+ additional_inputs_accordion=gr.Accordion(label="Tools available: ", open=True, render=False)
298
+ )
299
 
300
 
301
  if __name__ == "__main__":