bielas194 committed
Commit 63621b6 · verified · 1 Parent(s): 90a9f10

Create app.py

Files changed (1)
  1. app.py +165 -0
app.py ADDED
@@ -0,0 +1,165 @@
+ import os
+ import tempfile
+ import traceback  # Module-level import so every exception handler below can use it
+
+ import torch
+ import gradio as gr
+
+ from diffusers import StableDiffusionPipeline
+
+ # LangChain imports
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+ from langchain_core.tools import tool
+ from langchain_community.tools import DuckDuckGoSearchRun
+ from langchain_community.llms import HuggingFaceHub
+ from langchain.agents import AgentExecutor, create_react_agent
+ from langchain.schema import HumanMessage, AIMessage
+
+ # --- 1. Load Stable Diffusion Pipeline (happens once at startup) ---
+ HF_TOKEN = os.environ.get("HF_TOKEN")  # Same token the HuggingFaceHub LLM uses below
+
+ # Model ID for image generation
+ IMAGE_GEN_MODEL_ID = "segmind/tiny-sd"  # Smaller model that loads reliably
+ # IMAGE_GEN_MODEL_ID = "runwayml/stable-diffusion-v1-5"  # Larger alternative once the basics are proven
+
+ print(f"Loading Stable Diffusion Pipeline directly on GPU: {IMAGE_GEN_MODEL_ID}...")
+ try:
+     pipe = StableDiffusionPipeline.from_pretrained(
+         IMAGE_GEN_MODEL_ID,
+         torch_dtype=torch.float16,
+         use_safetensors=False,  # tiny-sd does not ship safetensors weights
+         token=HF_TOKEN,  # Authenticated downloads can be faster and avoid rate limits
+     )
+     pipe.to("cuda")  # Move the model to the GPU
+     print(f"Stable Diffusion Pipeline ({IMAGE_GEN_MODEL_ID}) loaded successfully on GPU.")
+ except Exception:
+     print("❌ Error loading Stable Diffusion Pipeline:")
+     traceback.print_exc()
+     pipe = None  # Signal the failure to the rest of the app
+
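+ # Note: the block above assumes a GPU Space. A CPU-only fallback (slower; a
+ # sketch, not part of this commit) would load float32 weights instead:
+ #   pipe = StableDiffusionPipeline.from_pretrained(IMAGE_GEN_MODEL_ID, torch_dtype=torch.float32)
+ #   pipe.to("cpu")
+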
+ # --- 2. Define Custom Image Generation Tool for LangChain ---
+ # The @tool decorator turns a plain function into a LangChain tool
+ @tool
+ def image_generator(prompt: str) -> str:
+     """
+     Generates an image from a detailed text prompt using a Stable Diffusion pipeline.
+     The input MUST be a detailed text description of the image to generate.
+     """
+     if pipe is None:
+         return "Error: Image generation pipeline failed to load. Please check the Space startup logs."
+
+     print(f"\n--- Agent is calling image_generator with prompt: '{prompt}' ---")
+     try:
+         with torch.no_grad():
+             pil_image = pipe(prompt, guidance_scale=7.5, height=512, width=512).images[0]
+
+         # Save the PIL image to a temporary file; Gradio can display the path.
+         # NOTE: LangChain tools typically return strings, so we return the
+         # path and let the Gradio UI decide how to render it.
+         temp_dir = tempfile.mkdtemp()
+         image_path = os.path.join(temp_dir, "generated_image.png")
+         pil_image.save(image_path)
+
+         print(f"Image saved to temporary path: {image_path}")
+         # Return a special string prefix so the UI layer knows it's an image path
+         return f"__IMAGE_PATH__:{image_path}"
+     except Exception as e:
+         print("Error in image_generator tool execution:")
+         traceback.print_exc()
+         return f"Error generating image: {str(e)}"
+
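+ # Quick sanity check outside the agent loop (hypothetical, not part of this
+ # commit): LangChain tools expose .invoke(), so the tool can be exercised
+ # directly, e.g.
+ #   print(image_generator.invoke("a watercolor lighthouse at dawn, soft light"))
+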
+ # --- 3. Define other Tools for LangChain ---
+ search = DuckDuckGoSearchRun()
+
+ # --- 4. Define the LangChain Agent ---
+ # Ensure the image model loaded successfully before proceeding
+ if pipe is None:
+     raise RuntimeError("Cannot start agent: image generation pipeline failed to load. Check logs.")
+
+ # Instantiate the LLM for the agent.
+ # HuggingFaceHub connects to the zephyr-7b-beta model on the HF Inference API;
+ # HF_TOKEN must be set as a Space Secret.
+ llm = HuggingFaceHub(
+     repo_id="HuggingFaceH4/zephyr-7b-beta",
+     huggingfacehub_api_token=HF_TOKEN,
+     model_kwargs={"temperature": 0.5, "max_new_tokens": 512},
+ )
+
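+ # Note: newer LangChain releases deprecate HuggingFaceHub in favour of
+ # HuggingFaceEndpoint; a roughly equivalent sketch (an assumption, not part
+ # of this commit):
+ #   from langchain_community.llms import HuggingFaceEndpoint
+ #   llm = HuggingFaceEndpoint(repo_id="HuggingFaceH4/zephyr-7b-beta",
+ #                             huggingfacehub_api_token=HF_TOKEN,
+ #                             temperature=0.5, max_new_tokens=512)
+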
+ # Create the tools list
+ tools = [image_generator, search]
+
+ # Define the agent prompt.
+ # create_react_agent requires the {tools}, {tool_names}, and {agent_scratchpad}
+ # variables, and its output parser expects the ReAct Thought/Action format,
+ # so both are spelled out here.
+ prompt_template = ChatPromptTemplate.from_messages(
+     [
+         ("system", """You are a powerful AI assistant that can generate images and search the web.
+ You have access to the following tools:
+ {tools}
+
+ When you need to generate an image, use the `image_generator` tool. Its input must be a very detailed, descriptive text string; if the request lacks detail, search first or ask for more.
+ When you need factual information or context, use the `duckduckgo_search` tool.
+
+ Use the following format:
+
+ Thought: think step-by-step about what to do next
+ Action: the tool to use, one of [{tool_names}]
+ Action Input: the input to the tool
+ Observation: the tool's result
+ ... (Thought/Action/Action Input/Observation can repeat)
+ Thought: I now know the final answer
+ Final Answer: your answer to the user. If you generated an image, include the tool's output verbatim.
+ """),
+         MessagesPlaceholder(variable_name="chat_history"),
+         # The ReAct agent renders its scratchpad as a plain string, so it is
+         # inlined here instead of using a MessagesPlaceholder
+         ("human", "{input}\n\n{agent_scratchpad}"),
+     ]
+ )
+
+ # Create the agent
+ agent = create_react_agent(llm, tools, prompt_template)
+
+ # Create the agent executor
+ agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
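+
+ # Alternative (an assumption; requires the langchainhub package and network
+ # access, neither used in this commit): pull the stock ReAct chat prompt
+ # instead of hand-writing the format above:
+ #   from langchain import hub
+ #   prompt_template = hub.pull("hwchase17/react-chat")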
+
+ # --- 5. Gradio UI Integration ---
+
+ # Run the agent for one chat turn and return text or an image
+ def run_agent_in_gradio(message, history):
+     # Convert Gradio (user, assistant) history pairs into LangChain messages
+     chat_history = []
+     for human_msg, ai_msg in history:
+         chat_history.append(HumanMessage(content=human_msg))
+         chat_history.append(AIMessage(content=ai_msg))
+
+     try:
+         # AgentExecutor doesn't stream token by token from a plain call; for
+         # streaming you'd use .stream() or a callback handler (see the sketch
+         # after this function). For simplicity in gr.ChatInterface, run once.
+         response = agent_executor.invoke({"input": message, "chat_history": chat_history})
+         agent_output = response["output"]
+
+         # The LLM writes the final answer, so the image-path marker may sit
+         # anywhere in the text, not only at the start
+         if "__IMAGE_PATH__:" in agent_output:
+             image_path = agent_output.split("__IMAGE_PATH__:", 1)[1].strip()
+             # Return a Gradio Image component directly
+             return gr.Image(value=image_path, label="Generated Image")
+         # Otherwise return regular text
+         return agent_output
+
+     except Exception as e:
+         print(f"Error running agent: {e}")
+         traceback.print_exc()
+         return f"❌ Agent encountered an error: {str(e)}"
+
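+ # Streaming sketch (an untested assumption, not wired into the UI below):
+ # AgentExecutor.stream() yields incremental chunks keyed by "actions",
+ # "steps", or "output", so a generator-style handler could look like:
+ #   def run_agent_streaming(message, history):
+ #       for chunk in agent_executor.stream({"input": message, "chat_history": []}):
+ #           if "output" in chunk:
+ #               yield chunk["output"]
+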
+ # Gradio ChatInterface setup
+ demo = gr.ChatInterface(
+     fn=run_agent_in_gradio,
+     chatbot=gr.Chatbot(label="AI Agent"),
+     textbox=gr.Textbox(placeholder="Ask me to generate an image or search the web...", container=False, scale=7),
+     title="Intelligent Image Generator & Web Search Agent (LangChain)",
+     description="This agent can generate images from prompts, searching the web first for context when needed.",
+ )
+
+ if __name__ == "__main__":
+     demo.launch()