import os
import torch
import gradio as gr
from PIL import Image
import tempfile
import traceback  # For printing full exception tracebacks
from diffusers import StableDiffusionPipeline
# LangChain imports
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.llms import HuggingFaceHub
from langchain.agents import AgentExecutor, create_react_agent
from langchain.schema import HumanMessage, AIMessage
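# NOTE: HuggingFaceHub (langchain_community) and langchain.schema are deprecated in
# newer LangChain releases (langchain-huggingface and langchain_core.messages are the
# modern replacements); this script assumes a version where they still work.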
# --- 1. Load Stable Diffusion Pipeline (happens once at startup) ---
HF_TOKEN = os.environ.get("HF_TOKEN") # Using HF_TOKEN for consistency with HuggingFaceHub LLM
# Define the model ID for image generation
IMAGE_GEN_MODEL_ID = "segmind/tiny-sd" # Using the smaller model as it loaded successfully
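# segmind/tiny-sd is a distilled Stable Diffusion variant, small enough to fit
# comfortably in a T4's VRAM at float16.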
print(f"Loading Stable Diffusion Pipeline directly on GPU: {IMAGE_GEN_MODEL_ID}...")
try:
    pipe = StableDiffusionPipeline.from_pretrained(
        IMAGE_GEN_MODEL_ID,
        torch_dtype=torch.float16,  # float16 halves VRAM usage, important on a T4
        use_safetensors=False,      # tiny-sd does not ship safetensors weights
        token=HF_TOKEN,             # token can speed up model downloads
    )
    pipe.to("cuda")  # Move the model to the GPU
    print(f"Stable Diffusion Pipeline ({IMAGE_GEN_MODEL_ID}) loaded successfully on GPU.")
except Exception as e:
    print("Error loading Stable Diffusion Pipeline:")
    traceback.print_exc()
    pipe = None  # Signal failure; checked by the tool and the startup guard below
# --- 2. Define Custom Image Generation Tool for LangChain ---
# Use @tool decorator to make a function a LangChain tool
@tool
def image_generator(prompt: str) -> str:
    """
    Generates an image from a detailed text prompt using a Stable Diffusion pipeline.
    The input MUST be a detailed text description of the image to generate.
    """
    if pipe is None:
        return "Error: Image generation pipeline failed to load. Please check the Space startup logs."
    print(f"\n--- Agent is calling image_generator with prompt: '{prompt}' ---")
    try:
        with torch.no_grad():
            pil_image = pipe(prompt, guidance_scale=7.5, height=512, width=512).images[0]
        # LangChain tools return strings, so save the PIL image to a temporary
        # file and return its path; the Gradio handler decides how to display it.
        temp_dir = tempfile.mkdtemp()
        image_path = os.path.join(temp_dir, "generated_image.png")
        pil_image.save(image_path)
        print(f"Image saved to temporary path: {image_path}")
        # The __IMAGE_PATH__ prefix is a sentinel the Gradio handler checks for.
        return f"__IMAGE_PATH__:{image_path}"
    except Exception as e:
        print("Error in image_generator tool execution:")
        traceback.print_exc()
        return f"Error generating image: {str(e)}"
# --- 3. Define other Tools for LangChain ---
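# DuckDuckGoSearchRun needs the duckduckgo-search package installed
# (add it to requirements.txt on a Space).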
search = DuckDuckGoSearchRun()
# --- 4. Define the LangChain Agent ---
# Ensure models are loaded successfully before proceeding
if pipe is None:
    raise RuntimeError("Cannot start agent: the image generation pipeline failed to load. Check logs.")
# Instantiate the LLM for the agent
llm = HuggingFaceHub(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    huggingfacehub_api_token=HF_TOKEN,  # HuggingFaceHub takes the token directly
    model_kwargs={"temperature": 0.5, "max_new_tokens": 512},
)
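# zephyr-7b-beta is served through the hosted Hugging Face Inference API here,
# so the agent's LLM consumes no GPU memory on the Space itself.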
# Create the tools list
tools = [image_generator, search]
# Define the agent prompt. create_react_agent requires the {tools}, {tool_names},
# and {agent_scratchpad} variables, and its output parser expects the ReAct
# Thought/Action/Action Input/Observation format to be spelled out.
prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", """You are a powerful AI assistant that can generate images and search the web.

You have access to the following tools:
{tools}

Use the following format:

Thought: think about what to do next
Action: the tool to use, one of [{tool_names}]
Action Input: the input to the tool
Observation: the result of the tool
... (Thought/Action/Action Input/Observation can repeat)
Thought: I now know the final answer
Final Answer: the final answer to the user's request

Always follow these steps:
1. Analyze the user's request and decide whether you need to search or generate an image.
2. When you need factual information or context, use the `search` tool.
3. When you need to generate an image, use the `image_generator` tool; its input must be a very detailed, descriptive text string. If you lack detail, ask for more or search first.
4. Provide your final answer. If you generated an image, your Final Answer must contain the exact string returned by `image_generator` (it starts with __IMAGE_PATH__:) so the UI can display it.
"""),
        MessagesPlaceholder(variable_name="chat_history"),
        # create_react_agent fills {agent_scratchpad} with a plain string, so it
        # goes inside a message template rather than a MessagesPlaceholder.
        ("human", "{input}\n\n{agent_scratchpad}"),
    ]
)
# Create the agent
agent = create_react_agent(llm, tools, prompt_template)
# Create the agent executor
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
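# handle_parsing_errors=True feeds malformed LLM output back to the model as an
# observation instead of aborting the run.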
# --- 5. Gradio UI Integration ---
# Function to run the agent and display output
def run_agent_in_gradio(message, history):
    # Convert Gradio's (user, bot) tuple history into LangChain chat messages
    chat_history = []
    for human_msg, ai_msg in history:
        chat_history.append(HumanMessage(content=human_msg))
        chat_history.append(AIMessage(content=ai_msg))
    try:
        # AgentExecutor manages agent_scratchpad itself; only input and
        # chat_history need to be supplied here.
        response = agent_executor.invoke(
            {"input": message, "chat_history": chat_history}
        )
        agent_output = response["output"]
        # Check whether the output is an image path from the custom tool
        if agent_output.startswith("__IMAGE_PATH__:"):
            image_path = agent_output.replace("__IMAGE_PATH__:", "")
            # Return a Gradio Image component so the chat renders the picture
            return gr.Image(value=image_path, label="Generated Image")
        else:
            # Return regular text
            return agent_output
    except Exception as e:
        print(f"Error running agent: {e}")
        traceback.print_exc()
        return f"Agent encountered an error: {str(e)}"
# Gradio ChatInterface setup
demo = gr.ChatInterface(
    fn=run_agent_in_gradio,
    chatbot=gr.Chatbot(label="AI Agent"),
    textbox=gr.Textbox(placeholder="Ask me to generate an image or search the web...", container=False, scale=7),
    title="Intelligent Image Generator & Web Search Agent (LangChain)",
    description="This agent generates images from prompts, optionally searching the web for context first.",
)
if __name__ == "__main__":
    demo.launch()
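# To run this outside a Space (a sketch; the filename and exact package set are
# assumptions, not from the original):
#   pip install torch diffusers transformers gradio langchain langchain-community duckduckgo-search
#   HF_TOKEN=<your token> python app.py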