import os
import torch
import gradio as gr
from PIL import Image
import tempfile
import traceback

from diffusers import StableDiffusionPipeline
from huggingface_hub import InferenceClient

# LangChain imports
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.llms import HuggingFaceHub
from langchain.agents import AgentExecutor, create_react_agent
from langchain_core.messages import HumanMessage, AIMessage

# --- 1. Load Stable Diffusion Pipeline (happens once at startup) ---
HF_TOKEN = os.environ.get("HF_TOKEN") # Using HF_TOKEN for consistency with HuggingFaceHub LLM

# Define the model ID for image generation
IMAGE_GEN_MODEL_ID = "segmind/tiny-sd" # Using the smaller model as it loaded successfully

print(f"Loading Stable Diffusion Pipeline directly on GPU: {IMAGE_GEN_MODEL_ID}...")
try:
    pipe = StableDiffusionPipeline.from_pretrained(
        IMAGE_GEN_MODEL_ID,
        torch_dtype=torch.float16, # Use float16 for less VRAM usage on T4
        use_safetensors=False, # Set to False for models that don't have safetensors (like tiny-sd)
        token=HF_TOKEN # Authenticate the download (required for gated or private models)
    )
    pipe.to("cuda") # Move the model to the GPU
    print(f"Stable Diffusion Pipeline ({IMAGE_GEN_MODEL_ID}) loaded successfully on GPU.")
except Exception as e:
    print("❌ Error loading Stable Diffusion Pipeline:")
    traceback.print_exc()
    pipe = None # Indicate failure to load
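# If no GPU is available (e.g. a CPU-only Space), a slower CPU fallback is
# possible. A hypothetical sketch, not part of the original setup:
#   pipe = StableDiffusionPipeline.from_pretrained(
#       IMAGE_GEN_MODEL_ID, torch_dtype=torch.float32, use_safetensors=False)
#   pipe.to("cpu")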

# --- 2. Define Custom Image Generation Tool for LangChain ---
# The @tool decorator turns a plain function into a LangChain tool
@tool
def image_generator(prompt: str) -> str:
    """
    Generates an image from a detailed text prompt using a Stable Diffusion pipeline.
    The input MUST be a detailed text description for the image to generate.
    """
    if pipe is None:
        return "Error: Image generation pipeline failed to load. Please check Space logs during startup."
    
    print(f"\n--- Agent is calling image_generator with prompt: '{prompt}' ---")
    try:
        with torch.no_grad():
            pil_image = pipe(prompt, guidance_scale=7.5, height=512, width=512).images[0]
        
        # Save the PIL image to a temporary file, Gradio will handle displaying this path
        # NOTE: LangChain tools typically return strings. For image display, we'll return
        # the path, and handle its display in the Gradio UI directly based on content.
        temp_dir = tempfile.mkdtemp()
        image_path = os.path.join(temp_dir, "generated_image.png")
        pil_image.save(image_path)
        
        print(f"Image saved to temporary path: {image_path}")
        # Return a special string prefix so Gradio knows it's an image path
        return f"__IMAGE_PATH__:{image_path}" 
    except Exception as e:
        print("Error in image_generator tool execution:")
        traceback.print_exc()
        return f"Error generating image: {str(e)}"

# --- 3. Define other Tools for LangChain ---
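# NOTE: DuckDuckGoSearchRun registers itself under the tool name
# "duckduckgo_search"; that is the name the agent must use in its Action lines.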
search = DuckDuckGoSearchRun()

# --- 4. Define the LangChain Agent ---
# Ensure models are loaded successfully before proceeding
if pipe is None:
    raise RuntimeError("Cannot start agent as image generation pipeline failed to load. Check logs.")

# Instantiate the LLM for the agent
llm = HuggingFaceHub(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    huggingfacehub_api_token=HF_TOKEN, # Use HF_TOKEN directly as required by HuggingFaceHub LLM
    model_kwargs={"temperature": 0.5, "max_new_tokens": 512}
)

# Create the tools list
tools = [image_generator, search]

# Define the agent prompt.
# create_react_agent requires the {tools}, {tool_names}, and {agent_scratchpad}
# variables, and the LLM's replies must follow the ReAct format below so the
# output parser can extract tool calls.
prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", """You are a powerful AI assistant that can generate images and search the web.

You have access to the following tools:
{tools}

When you need to generate an image, use the `image_generator` tool; its input must be a very detailed, descriptive text string. If the request lacks detail, ask for more or search first.
When you need factual information or context, use the `duckduckgo_search` tool.

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original question; if you generated an image, include the tool's output in your final answer
"""),
        MessagesPlaceholder(variable_name="chat_history"),
        # create_react_agent formats the scratchpad as a plain string, so it
        # belongs inside a message template rather than a MessagesPlaceholder.
        ("human", "{input}\n\n{agent_scratchpad}"),
    ]
)

# Create the agent
agent = create_react_agent(llm, tools, prompt_template)

# Create the agent executor
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
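# Optional hardening (a sketch, not in the original): cap the ReAct loop so a
# confused model cannot iterate forever.
#   agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True,
#                                  handle_parsing_errors=True, max_iterations=5)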

# --- 5. Gradio UI Integration ---

# Function to run the agent and display output
def run_agent_in_gradio(message, history):
    # Convert Gradio history to LangChain chat_history format
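    # NOTE: this assumes the classic tuple-based history ([(user, bot), ...]);
    # newer Gradio releases prefer a "messages"-style dict history instead.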
    chat_history = []
    for human_msg, ai_msg in history:
        chat_history.append(HumanMessage(content=human_msg))
        chat_history.append(AIMessage(content=ai_msg))

    try:
        # AgentExecutor builds agent_scratchpad itself from intermediate steps,
        # so only the input and chat history are passed here.
        response = agent_executor.invoke(
            {"input": message, "chat_history": chat_history}
        )
        agent_output = response["output"]

        # Check whether the output contains an image path from our custom tool.
        # The marker may appear anywhere in the final answer, since the LLM
        # composes its answer around the tool's output rather than echoing it verbatim.
        if "__IMAGE_PATH__:" in agent_output:
            image_path = agent_output.split("__IMAGE_PATH__:", 1)[1].strip()
            # Return a Gradio Image component so the chat displays the file
            return gr.Image(value=image_path, label="Generated Image")
        else:
            # Return regular text
            return agent_output

    except Exception as e:
        print(f"Error running agent: {e}")
        traceback.print_exc() 
        return f"❌ Agent encountered an error: {str(e)}"
        
# Gradio ChatInterface setup
demo = gr.ChatInterface(
    fn=run_agent_in_gradio,
    chatbot=gr.Chatbot(label="AI Agent"),
    textbox=gr.Textbox(placeholder="Ask me to generate an image or search the web...", container=False, scale=7),
    title="Intelligent Image Generator & Web Search Agent (LangChain)",
    description="This agent can generate images based on prompts or search the web for information first."
)

if __name__ == "__main__":
    demo.launch()