TinySuitStarfish commited on
Commit
82f8462
·
verified ·
1 Parent(s): f04a479

Update agentic_agent.py

Browse files
Files changed (1) hide show
  1. agentic_agent.py +46 -29
agentic_agent.py CHANGED
@@ -5,6 +5,8 @@ from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.prebuilt import tools_condition
6
  from langgraph.prebuilt import ToolNode
7
  from langchain_google_genai import ChatGoogleGenerativeAI
 
 
8
  from langchain_groq import ChatGroq
9
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
@@ -105,15 +107,18 @@ def arvix_search(query: str) -> str:
105
  return {"arvix_results": formatted_search_docs}
106
 
107
 
 
 
108
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
109
  system_prompt = f.read()
110
 
 
111
  sys_msg = SystemMessage(content=system_prompt)
112
 
113
  # build a retriever
114
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
115
  supabase: Client = create_client(
116
- os.environ.get("SUPABASE_URL"),
117
  os.environ.get("SUPABASE_SERVICE_KEY"))
118
  vector_store = SupabaseVectorStore(
119
  client=supabase,
@@ -127,14 +132,31 @@ create_retriever_tool = create_retriever_tool(
127
  description="A tool to retrieve similar questions from a vector store.",
128
  )
129
 
130
- tools = [add, subtract, multiply, divide, modulus, web_search, wiki_search, arvix_search]
131
 
132
- def build_graph(provider: str = "google"):
 
 
 
 
 
 
 
 
 
 
 
 
133
  """Build the graph"""
 
134
  if provider == "google":
 
135
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
136
  elif provider == "groq":
137
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
 
 
 
 
138
  elif provider == "huggingface":
139
  llm = ChatHuggingFace(
140
  llm=HuggingFaceEndpoint(
@@ -144,38 +166,33 @@ def build_graph(provider: str = "google"):
144
  )
145
  else:
146
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
 
147
  llm_with_tools = llm.bind_tools(tools)
148
 
149
  # Node
150
  def assistant(state: MessagesState):
151
  """Assistant node"""
152
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
153
-
154
- from langchain_core.messages import AIMessage
155
- def retriever(state: MessagesState):
156
- query = state["messages"][-1].content
157
- similar_doc = vector_store.similarity_search(query, k=1)[0]
158
 
159
- content = similar_doc.page_content
160
- if "Final answer :" in content:
161
- answer = content.split("Final answer :")[-1].strip()
162
- else:
163
- answer = content.strip()
164
- return {"messages": [AIMessage(content=answer)]}
 
165
 
166
  builder = StateGraph(MessagesState)
167
  builder.add_node("retriever", retriever)
168
-
169
- # Retriever ist Start und Endpunkt
170
- builder.set_entry_point("retriever")
171
- builder.set_finish_point("retriever")
172
-
173
- return builder.compile()
174
-
175
-
176
-
177
-
178
-
179
-
180
-
181
-
 
5
  from langgraph.prebuilt import tools_condition
6
  from langgraph.prebuilt import ToolNode
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain.agents import initialize_agent, Tool
10
  from langchain_groq import ChatGroq
11
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
12
  from langchain_community.tools.tavily_search import TavilySearchResults
 
107
  return {"arvix_results": formatted_search_docs}
108
 
109
 
110
+
111
+ # load the system prompt from the file
112
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
113
  system_prompt = f.read()
114
 
115
+ # System message
116
  sys_msg = SystemMessage(content=system_prompt)
117
 
118
  # build a retriever
119
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
120
  supabase: Client = create_client(
121
+ os.environ.get("SUPABASE_URL"),
122
  os.environ.get("SUPABASE_SERVICE_KEY"))
123
  vector_store = SupabaseVectorStore(
124
  client=supabase,
 
132
  description="A tool to retrieve similar questions from a vector store.",
133
  )
134
 
 
135
 
136
+
137
+ tools = [
138
+ multiply,
139
+ add,
140
+ subtract,
141
+ divide,
142
+ modulus,
143
+ wiki_search,
144
+ web_search,
145
+ arvix_search,
146
+ ]
147
+
148
+ def build_graph(provider: str = "groq"):
149
  """Build the graph"""
150
+ # Load environment variables from .env file
151
  if provider == "google":
152
+ # Google Gemini
153
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
154
  elif provider == "groq":
155
+ # Groq https://console.groq.com/docs/models
156
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
157
+ elif provider == "openai":
158
+ # OpenAI
159
+ llm = ChatOpenAI(model="gpt-4", temperature=0)
160
  elif provider == "huggingface":
161
  llm = ChatHuggingFace(
162
  llm=HuggingFaceEndpoint(
 
166
  )
167
  else:
168
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
169
+ # Bind tools to LLM
170
  llm_with_tools = llm.bind_tools(tools)
171
 
172
  # Node
173
  def assistant(state: MessagesState):
174
  """Assistant node"""
175
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
 
 
176
 
177
+ def retriever(state: MessagesState):
178
+ """Retriever node"""
179
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
180
+ example_msg = HumanMessage(
181
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
182
+ )
183
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
184
 
185
  builder = StateGraph(MessagesState)
186
  builder.add_node("retriever", retriever)
187
+ builder.add_node("assistant", assistant)
188
+ builder.add_node("tools", ToolNode(tools))
189
+ builder.add_edge(START, "retriever")
190
+ builder.add_edge("retriever", "assistant")
191
+ builder.add_conditional_edges(
192
+ "assistant",
193
+ tools_condition,
194
+ )
195
+ builder.add_edge("tools", "assistant")
196
+
197
+ # Compile graph
198
+ return builder.compile()