"""Gradio chat app backed by a LangChain ReAct agent.

The agent has two tools: a Stable Diffusion image generator (loaded once at
startup) and a DuckDuckGo web search.  Generated images are handed back to
the UI through a marker string (``__IMAGE_PATH__:<path>``) that
``run_agent_in_gradio`` turns into a ``gr.Image`` component.
"""

import os
import re
import shutil
import tempfile
import traceback
from functools import partial

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from huggingface_hub import InferenceClient
from PIL import Image

# LangChain imports
from langchain.agents import AgentExecutor, create_react_agent
from langchain.schema import AIMessage, HumanMessage
from langchain_community.llms import HuggingFaceHub
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool

# --- 1. Load Stable Diffusion Pipeline (happens once at startup) ---

# Using HF_TOKEN for consistency with the HuggingFaceHub LLM below.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model ID for image generation (the smaller model, as it loaded successfully).
IMAGE_GEN_MODEL_ID = "segmind/tiny-sd"

# Marker the image tool prepends to the file path it returns, and a pattern
# that finds "<marker><path>.png" anywhere inside a larger piece of text.
IMAGE_PATH_PREFIX = "__IMAGE_PATH__:"
_IMAGE_PATH_PATTERN = re.compile(re.escape(IMAGE_PATH_PREFIX) + r"(\S+?\.png)")

print(f"Loading Stable Diffusion Pipeline directly on GPU: {IMAGE_GEN_MODEL_ID}...")
try:
    pipe = StableDiffusionPipeline.from_pretrained(
        IMAGE_GEN_MODEL_ID,
        torch_dtype=torch.float16,  # float16 for less VRAM usage on T4
        use_safetensors=False,  # tiny-sd does not ship safetensors weights
        token=HF_TOKEN,  # token for potentially faster model download
    )
    pipe.to("cuda")  # Move the model to the GPU
    print(f"Stable Diffusion Pipeline ({IMAGE_GEN_MODEL_ID}) loaded successfully on GPU.")
except Exception:
    print("❌ Error loading Stable Diffusion Pipeline:")
    traceback.print_exc()
    pipe = None  # Indicate failure to load


# --- 2. Define Custom Image Generation Tool for LangChain ---

# NOTE: the docstring below is the tool description shown to the LLM.
@tool
def image_generator(prompt: str) -> str:
    """
    Generates an image from a detailed text prompt using a Stable Diffusion pipeline.
    The input MUST be a detailed text description for the image to generate.
    """
    if pipe is None:
        return "Error: Image generation pipeline failed to load. Please check Space logs during startup."

    print(f"\n--- Agent is calling image_generator with prompt: '{prompt}' ---")
    try:
        with torch.no_grad():
            pil_image = pipe(prompt, guidance_scale=7.5, height=512, width=512).images[0]

        # LangChain tools return strings, so the image is written to a
        # temporary file and its path is returned behind a marker prefix;
        # the Gradio callback recognises the marker and shows the image.
        temp_dir = tempfile.mkdtemp()
        image_path = os.path.join(temp_dir, "generated_image.png")
        pil_image.save(image_path)
        print(f"Image saved to temporary path: {image_path}")
        return f"{IMAGE_PATH_PREFIX}{image_path}"
    except Exception as e:
        print("Error in image_generator tool execution:")
        traceback.print_exc()
        return f"Error generating image: {str(e)}"


# --- 3. Define other Tools for LangChain ---
search = DuckDuckGoSearchRun()

# --- 4. Define the LangChain Agent ---

# Ensure models are loaded successfully before proceeding
if pipe is None:
    raise RuntimeError("Cannot start agent as image generation pipeline failed to load. Check logs.")

# Instantiate the LLM for the agent
llm = HuggingFaceHub(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    huggingfacehub_api_token=HF_TOKEN,
    model_kwargs={"temperature": 0.5, "max_new_tokens": 512},
)

# Create the tools list
tools = [image_generator, search]

# System prompt for the ReAct agent.  {tools} and {tool_names} are filled in
# by create_react_agent; {image_tool} and {search_tool} are filled in below
# from the tool objects so the names in the text always match the real tools.
# The "Use the following format" section is what the ReAct output parser
# expects the model to follow.
SYSTEM_PROMPT = """You are a powerful AI assistant that can generate images and search the web.
You have access to the following tools:

{tools}

When you need to generate an image, use the `{image_tool}` tool. Its input must be a very detailed, descriptive text string.
When you need factual information or context, use the `{search_tool}` tool.

Always follow these steps:
1. Think step-by-step: Analyze the user's request and determine if you need to search or generate an image.
2. If you need to search, use the `{search_tool}` tool.
3. If you need to generate an image, ensure you have enough detail. If not, ask for more or use search.
4. When you have enough information, use the `{image_tool}` tool.
5. Provide your final answer. If you generated an image, include the image in your final answer.

Use the following format:

Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the user's request
"""

# create_react_agent renders the scratchpad as a *string* (the running
# Thought/Action/Observation log), so it must be a string template slot and
# not a MessagesPlaceholder, which only accepts a list of messages.
prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
        ("ai", "{agent_scratchpad}"),
    ]
).partial(image_tool=image_generator.name, search_tool=search.name)

# Create the agent
agent = create_react_agent(llm, tools, prompt_template)

# Create the agent executor.  Intermediate steps are returned so the UI can
# find a generated image even when the model's final answer omits the marker.
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True,
    handle_parsing_errors=True,
    return_intermediate_steps=True,
)


# --- 5. Gradio UI Integration ---

def _as_text(content):
    """Return chat content as text, or None when there is nothing to record."""
    if content is None:
        return None
    if isinstance(content, str):
        return content
    # Files/components (e.g. a previously generated image) have no text form.
    return "[non-text content]"


def _history_to_messages(history):
    """Convert Gradio chat history into a list of LangChain messages.

    Accepts both history layouts Gradio may pass: (user, assistant) pairs and
    ``{"role": ..., "content": ...}`` dicts.  Empty turns are skipped.
    """
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            turns = [(entry.get("role"), entry.get("content"))]
        else:
            human_msg, ai_msg = entry
            turns = [("user", human_msg), ("assistant", ai_msg)]

        for role, content in turns:
            text = _as_text(content)
            if text is None:
                continue
            message_cls = HumanMessage if role == "user" else AIMessage
            messages.append(message_cls(content=text))
    return messages


def _find_image_path(response):
    """Return the path of an image produced during this run, or None.

    Looks for the image marker in the final answer first, then in the tool
    observations (most recent first).  Only paths that exist on disk count,
    so a path invented by the model is ignored.
    """
    texts = [str(response.get("output", ""))]
    steps = response.get("intermediate_steps") or []
    texts.extend(str(observation) for _, observation in reversed(steps))

    for text in texts:
        for candidate in _IMAGE_PATH_PATTERN.findall(text):
            if os.path.isfile(candidate):
                return candidate
    return None


def run_agent_in_gradio(message, history):
    """Run the agent on one user message and return what the chat should show.

    Args:
        message: The user's latest message.
        history: Prior conversation as provided by ``gr.ChatInterface``.

    Returns:
        A ``gr.Image`` when the agent generated an image, otherwise the
        agent's text answer.  On failure, an error string (the traceback is
        printed to the logs).
    """
    chat_history = _history_to_messages(history)

    try:
        # The executor builds agent_scratchpad itself from intermediate steps.
        response = agent_executor.invoke(
            {"input": message, "chat_history": chat_history}
        )

        image_path = _find_image_path(response)
        if image_path is not None:
            return gr.Image(value=image_path, label="Generated Image")
        return response["output"]
    except Exception as e:
        print(f"Error running agent: {e}")
        traceback.print_exc()
        return f"❌ Agent encountered an error: {str(e)}"


# Gradio ChatInterface setup
demo = gr.ChatInterface(
    fn=run_agent_in_gradio,
    chatbot=gr.Chatbot(label="AI Agent"),
    textbox=gr.Textbox(
        placeholder="Ask me to generate an image or search the web...",
        container=False,
        scale=7,
    ),
    title="Intelligent Image Generator & Web Search Agent (LangChain)",
    description="This agent can generate images based on prompts or search the web for information first.",
)

if __name__ == "__main__":
    demo.launch()