Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import gradio as gr
|
4 |
+
from PIL import Image
|
5 |
+
import tempfile
|
6 |
+
import shutil
|
7 |
+
from functools import partial # To create a callable for our custom tool
|
8 |
+
|
9 |
+
from diffusers import StableDiffusionPipeline
|
10 |
+
from huggingface_hub import InferenceClient
|
11 |
+
|
12 |
+
# LangChain imports
|
13 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
14 |
+
from langchain_core.tools import tool
|
15 |
+
from langchain_community.tools import DuckDuckGoSearchRun
|
16 |
+
from langchain_community.llms import HuggingFaceHub
|
17 |
+
from langchain.agents import AgentExecutor, create_react_agent
|
18 |
+
from langchain.schema import HumanMessage, AIMessage
|
19 |
+
|
20 |
+
# --- 1. Load Stable Diffusion Pipeline (happens once at startup) ---
|
21 |
+
HF_TOKEN = os.environ.get("HF_TOKEN") # Using HF_TOKEN for consistency with HuggingFaceHub LLM
|
22 |
+
|
23 |
+
# Define the model ID for image generation
|
24 |
+
IMAGE_GEN_MODEL_ID = "segmind/tiny-sd" # Using the smaller model as it loaded successfully
|
25 |
+
# IMAGE_GEN_MODEL_ID = "runwayml/stable-diffusion-v1-5" # You can try this again after proving basic functionality
|
26 |
+
|
27 |
+
print(f"Loading Stable Diffusion Pipeline directly on GPU: {IMAGE_GEN_MODEL_ID}...")
|
28 |
+
try:
|
29 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
30 |
+
IMAGE_GEN_MODEL_ID,
|
31 |
+
torch_dtype=torch.float16,
|
32 |
+
use_safetensors=False, # Set to False for models that don't have safetensors (like tiny-sd)
|
33 |
+
token=HF_TOKEN # Pass token for potential faster model download
|
34 |
+
)
|
35 |
+
pipe.to("cuda") # Move the model to the GPU
|
36 |
+
print(f"Stable Diffusion Pipeline ({IMAGE_GEN_MODEL_ID}) loaded successfully on GPU.")
|
37 |
+
except Exception as e:
|
38 |
+
print("β Error loading Stable Diffusion Pipeline:")
|
39 |
+
import traceback
|
40 |
+
traceback.print_exc()
|
41 |
+
pipe = None # Indicate failure to load
|
42 |
+
|
43 |
+
# --- 2. Define Custom Image Generation Tool for LangChain ---
|
44 |
+
# Use @tool decorator to make a function a LangChain tool
|
45 |
+
@tool
|
46 |
+
def image_generator(prompt: str) -> str:
|
47 |
+
"""
|
48 |
+
Generates an image from a detailed text prompt using a Stable Diffusion pipeline.
|
49 |
+
The input MUST be a detailed text description for the image to generate.
|
50 |
+
"""
|
51 |
+
if pipe is None:
|
52 |
+
return "Error: Image generation pipeline failed to load. Please check Space logs during startup."
|
53 |
+
|
54 |
+
print(f"\n--- Agent is calling image_generator with prompt: '{prompt}' ---")
|
55 |
+
try:
|
56 |
+
with torch.no_grad():
|
57 |
+
pil_image = pipe(prompt, guidance_scale=7.5, height=512, width=512).images[0]
|
58 |
+
|
59 |
+
# Save the PIL image to a temporary file, Gradio will handle displaying this path
|
60 |
+
# NOTE: LangChain tools typically return strings. For image display, we'll return
|
61 |
+
# the path, and handle its display in the Gradio UI directly based on content.
|
62 |
+
temp_dir = tempfile.mkdtemp()
|
63 |
+
image_path = os.path.join(temp_dir, "generated_image.png")
|
64 |
+
pil_image.save(image_path)
|
65 |
+
|
66 |
+
print(f"Image saved to temporary path: {image_path}")
|
67 |
+
# Return a special string prefix so Gradio knows it's an image path
|
68 |
+
return f"__IMAGE_PATH__:{image_path}"
|
69 |
+
except Exception as e:
|
70 |
+
print("Error in image_generator tool execution:")
|
71 |
+
traceback.print_exc()
|
72 |
+
return f"Error generating image: {str(e)}"
|
73 |
+
|
74 |
+
# --- 3. Define other Tools for LangChain ---
|
75 |
+
search = DuckDuckGoSearchRun()
|
76 |
+
|
77 |
+
# --- 4. Define the LangChain Agent ---
|
78 |
+
# Ensure models are loaded successfully before proceeding
|
79 |
+
if pipe is None:
|
80 |
+
raise RuntimeError("Cannot start agent as image generation pipeline failed to load. Check logs.")
|
81 |
+
|
82 |
+
# Instantiate the LLM for the agent
|
83 |
+
# Using HuggingFaceHub to connect to Zephyr-7b-beta model on HF Inference API
|
84 |
+
# Ensure HF_TOKEN is set as a Space Secret
|
85 |
+
llm = HuggingFaceHub(
|
86 |
+
repo_id="HuggingFaceH4/zephyr-7b-beta",
|
87 |
+
huggingfacehub_api_token=HF_TOKEN, # Use HF_TOKEN directly as required by HuggingFaceHub LLM
|
88 |
+
model_kwargs={"temperature": 0.5, "max_new_tokens": 512}
|
89 |
+
)
|
90 |
+
|
91 |
+
# Create the tools list
|
92 |
+
tools = [image_generator, search]
|
93 |
+
|
94 |
+
# Define the agent prompt
|
95 |
+
# This prompt guides the LLM on how to use the tools
|
96 |
+
prompt_template = ChatPromptTemplate.from_messages(
|
97 |
+
[
|
98 |
+
("system", """You are a powerful AI assistant that can generate images and search the web.
|
99 |
+
You have access to the following tools: {tools}
|
100 |
+
|
101 |
+
When you need to generate an image, use the `image_generator` tool. Its input must be a very detailed, descriptive text string.
|
102 |
+
When you need factual information or context, use the `search` tool.
|
103 |
+
|
104 |
+
Always follow these steps:
|
105 |
+
1. Think step-by-step: Analyze the user's request and determine if you need to search or generate an image.
|
106 |
+
2. If you need to search, use the `search` tool.
|
107 |
+
3. If you need to generate an image, ensure you have enough detail. If not, ask for more or use search.
|
108 |
+
4. When you have enough information, use the `image_generator` tool.
|
109 |
+
5. Provide your final answer. If you generated an image, include the image in your final answer.
|
110 |
+
"""),
|
111 |
+
MessagesPlaceholder(variable_name="chat_history"),
|
112 |
+
("human", "{input}"),
|
113 |
+
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
114 |
+
]
|
115 |
+
)
|
116 |
+
|
117 |
+
# Create the agent
|
118 |
+
agent = create_react_agent(llm, tools, prompt_template)
|
119 |
+
|
120 |
+
# Create the agent executor
|
121 |
+
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
|
122 |
+
|
123 |
+
# --- 5. Gradio UI Integration ---
|
124 |
+
|
125 |
+
# Function to run the agent and display output
|
126 |
+
def run_agent_in_gradio(message, history):
|
127 |
+
# Convert Gradio history to LangChain chat_history format
|
128 |
+
chat_history = []
|
129 |
+
for human_msg, ai_msg in history:
|
130 |
+
chat_history.append(HumanMessage(content=human_msg))
|
131 |
+
chat_history.append(AIMessage(content=ai_msg))
|
132 |
+
|
133 |
+
try:
|
134 |
+
# Stream output from the agent
|
135 |
+
# LangChain AgentExecutor doesn't directly stream token by token in a simple loop
|
136 |
+
# For streaming, you'd typically use .stream() or a custom callback handler.
|
137 |
+
# For simplicity in Gradio ChatInterface, we'll run it once.
|
138 |
+
response = agent_executor.invoke({"input": message, "chat_history": chat_history})
|
139 |
+
agent_output = response["output"]
|
140 |
+
|
141 |
+
# Check if the output is an image path from our custom tool
|
142 |
+
if agent_output.startswith("__IMAGE_PATH__:") :
|
143 |
+
image_path = agent_output.replace("__IMAGE_PATH__:", "")
|
144 |
+
# Return the Gradio Image component directly
|
145 |
+
return gr.Image(value=image_path, label="Generated Image")
|
146 |
+
else:
|
147 |
+
# Return regular text
|
148 |
+
return agent_output
|
149 |
+
|
150 |
+
except Exception as e:
|
151 |
+
print(f"Error running agent: {e}")
|
152 |
+
traceback.print_exc()
|
153 |
+
return f"β Agent encountered an error: {str(e)}"
|
154 |
+
|
155 |
+
# Gradio ChatInterface setup
|
156 |
+
demo = gr.ChatInterface(
|
157 |
+
fn=run_agent_in_gradio,
|
158 |
+
chatbot=gr.Chatbot(label="AI Agent"),
|
159 |
+
textbox=gr.Textbox(placeholder="Ask me to generate an image or search the web...", container=False, scale=7),
|
160 |
+
title="Intelligent Image Generator & Web Search Agent (LangChain)",
|
161 |
+
description="This agent can generate images based on prompts or search the web for information first."
|
162 |
+
)
|
163 |
+
|
164 |
+
if __name__ == "__main__":
|
165 |
+
demo.launch()
|