import gradio as gr
import os
import base64
import pandas as pd
from PIL import Image
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, VisitWebpageTool, OpenAIServerModel, tool
from typing import Optional
import requests
from io import BytesIO
import re
from pathlib import Path
import openai
from openai import OpenAI
import pdfplumber


def is_image_extension(filename: str) -> bool:
    """Return True if the filename has a common image file extension."""
    IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp', '.svg'}
    ext = os.path.splitext(filename)[1].lower()
    return ext in IMAGE_EXTS


def load_file(path: str) -> list | dict:
    """Based on the file extension, load the file into a suitable object."""
    image = None
    text = None
    ext = Path(path).suffix.lower()

    if ext in (".png", ".jpg", ".jpeg"):
        image = Image.open(path).convert("RGB")
    elif ext in (".xlsx", ".xls"):
        text = pd.read_excel(path)
    elif ext == ".csv":
        text = pd.read_csv(path)
    elif ext == ".pdf":
        with pdfplumber.open(path) as pdf:
            text = "\n".join(page.extract_text() for page in pdf.pages if page.extract_text())
    elif ext in (".py", ".txt"):
        with open(path, 'r') as f:
            text = f.read()

    if image is not None:
        return [image]
    elif ext in (".mp3", ".wav"):
        # Audio is not decoded here; the transcribe_audio tool works from the path.
        return {"raw document text": text, "audio path": path}
    else:
        return {"raw document text": text, "file path": path}

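# Illustrative usage of load_file (filenames are placeholders):
#   load_file("photo.png")  -> [<PIL.Image.Image>]
#   load_file("data.csv")   -> {"raw document text": <DataFrame>, "file path": "data.csv"}
#   load_file("clip.mp3")   -> {"raw document text": None, "audio path": "clip.mp3"}
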
def check_format(answer: str | list, *args, **kwargs) -> list:
    """Check that the final answer is a flat (non-nested) list; wrap bare strings."""
    print("Checking format of the answer:", answer)
    if isinstance(answer, list):
        for item in answer:
            if isinstance(item, list):
                print("list detected")
                raise TypeError("Nested lists are not allowed in the final answer.")
        return answer
    elif isinstance(answer, str):
        return [answer]
    elif isinstance(answer, dict):
        raise TypeError("Final answer must be a list, not a dict. Please check the answer format.")

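# Illustrative behaviour of check_format:
#   check_format("hello")       -> ["hello"]
#   check_format(["a", "b"])    -> ["a", "b"]
#   check_format([["nested"]])  -> raises TypeError
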
@tool
def download_images(image_urls: str) -> list:
    """
    Download web images from the given comma-separated URLs and return them in a list of PIL Images.

    Args:
        image_urls: comma-separated list of URLs to download

    Returns:
        List of PIL.Image.Image objects
    """
    urls = [u.strip() for u in image_urls.split(",") if u.strip()]
    images = []
    for url in urls:
        try:
            resp = requests.get(url, timeout=10)
            resp.raise_for_status()
            img = Image.open(BytesIO(resp.content)).convert("RGB")
            images.append(img)
        except Exception as e:
            print(f"Failed to download from {url}: {e}")
    return images

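# Illustrative usage (URLs are placeholders):
#   imgs = download_images("https://example.com/a.png, https://example.com/b.jpg")
#   # Failed downloads are logged and skipped, so imgs may be shorter than the URL list.
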
@tool
def transcribe_audio(audio_path: str) -> str:
    """
    Transcribe an audio file using the OpenAI Whisper API.

    Args:
        audio_path: path to the audio file to be transcribed.

    Returns:
        str: Transcription of the audio.
    """
    client = openai.Client(api_key=os.getenv("OPENAI_API_KEY"))
    try:
        # The API call itself is what can fail, so it lives inside the try block.
        with open(audio_path, "rb") as audio:
            transcript = client.audio.transcriptions.create(
                file=audio,
                model="whisper-1",
                response_format="text",
            )
        print(transcript)
        return transcript
    except Exception as e:
        print(f"Error transcribing audio: {e}")

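# Illustrative usage (requires OPENAI_API_KEY; the path is a placeholder):
#   text = transcribe_audio("interview.mp3")
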
@tool
def generate_image(prompt: str, neg_prompt: str) -> Image.Image:
    """
    Generate an image based on a text prompt using Flux Dev.

    Args:
        prompt: The text prompt to generate the image from.
        neg_prompt: The negative prompt to avoid certain elements in the image.

    Returns:
        Image.Image: The generated image as a PIL Image object.
    """
    client = OpenAI(
        base_url="https://api.studio.nebius.com/v1",
        api_key=os.environ.get("NEBIUS_API_KEY"),
    )

    completion = client.images.generate(
        model="black-forest-labs/flux-dev",
        prompt=prompt,
        response_format="b64_json",
        extra_body={
            "response_extension": "png",
            "width": 1024,
            "height": 1024,
            "num_inference_steps": 30,
            "seed": -1,
            "negative_prompt": neg_prompt,
        },
    )

    image_data = base64.b64decode(completion.to_dict()['data'][0]['b64_json'])
    image = Image.open(BytesIO(image_data)).convert("RGB")

    # Return the PIL image itself rather than a gr.Image component, as the
    # signature and docstring promise; the agent handles plain Python objects.
    return image

"""@tool |
|
def generate_audio(prompt: str) -> object: |
|
space = smolagents.load_tool( |
|
|
|
)""" |
|
|
|
|
|
|
|
|
|
|
|
class Agent:
    def __init__(self):
        client = HfApiModel("google/gemma-3-27b-it", provider="nebius", api_key=os.getenv("NEBIUS_API_KEY"))
        self.agent = CodeAgent(
            model=client,
            tools=[DuckDuckGoSearchTool(max_results=5), VisitWebpageTool(max_output_length=20000), generate_image, download_images, transcribe_audio],
            additional_authorized_imports=["pandas", "PIL", "io"],
            planning_interval=1,
            max_steps=5,
            stream_outputs=False,
            final_answer_checks=[check_format],
        )
        with open("system_prompt.txt", "r") as f:
            system_prompt = f.read()
        self.agent.prompt_templates["system_prompt"] = system_prompt

    def __call__(self, message: str,
                 images: Optional[list[Image.Image]] = None,
                 files: Optional[dict] = None,
                 conversation_history: Optional[list] = None) -> str:
        answer = self.agent.run(message, images=images, additional_args={"files": files, "conversation_history": conversation_history})
        return answer

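# Illustrative usage (paths are placeholders):
#   agent = Agent()
#   reply = agent("Summarize this file.", files=load_file("notes.txt"))
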
def respond(message: str, history: list, web_search: bool = False):
    print("history:", history)
    text = message.get("text", "")
    if not message.get("files") and not web_search:
        print("No files received.")
        response = agent(text + "\nADDITIONAL CONSTRAINT: Don't use web search", conversation_history=history)
    elif not message.get("files") and web_search:
        print("No files received + web search enabled.")
        response = agent(text, conversation_history=history)
    else:
        # Only the first attached file is handled per message.
        files = message.get("files", [])
        print(f"files received: {files}")
        if is_image_extension(files[0]):
            image = load_file(files[0])
            response = agent(text, images=image, conversation_history=history)
        else:
            file = load_file(files[0])
            response = agent(text, files=file, conversation_history=history)

    print("Agent response:", response)
    return response


def initialize_agent():
    agent = Agent()
    print("Agent initialized.")
    return agent


with gr.Blocks() as demo:
    # Assigned at module scope, so respond() can reference it as a global.
    agent = initialize_agent()
    gr.ChatInterface(
        fn=respond,
        type='messages',
        multimodal=True,
        title='MultiAgent System for Screenplay Creation and Editing',
        show_progress='full',
        fill_height=True,
        fill_width=False,
        save_history=True,
        additional_inputs=[
            gr.Checkbox(value=False, label="Web Search",
                        info="Enable web search to find information online. If disabled, the agent will only use the provided files and images.",
                        render=False),
        ],
    )


if __name__ == "__main__":
    demo.launch()