shamik commited on
Commit
5ea5bac
·
unverified ·
1 Parent(s): b4fb438

feat: updated the multiagent framework to include arxiv agent.

Browse files
Files changed (1) hide show
  1. src/agent_hackathon/multiagent.py +38 -36
src/agent_hackathon/multiagent.py CHANGED
@@ -1,7 +1,9 @@
1
- import asyncio
2
  from datetime import date
3
 
 
4
  from llama_index.core.agent.workflow import AgentWorkflow, ReActAgent
 
5
  from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
6
  from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
7
 
@@ -11,6 +13,8 @@ from src.agent_hackathon.consts import PROJECT_ROOT_DIR
11
  from src.agent_hackathon.generate_arxiv_responses import ArxivResponseGenerator
12
  from src.agent_hackathon.logger import get_logger
13
 
 
 
14
  # _ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=False), override=True)
15
 
16
  logger = get_logger(log_name="multiagent", log_dir=PROJECT_ROOT_DIR / "logs")
@@ -28,32 +32,32 @@ class MultiAgentWorkflow:
28
  # provider="nebius",
29
  temperature=0.1,
30
  top_p=0.95,
31
- max_tokens=8192,
32
  # api_key=os.getenv(key="NEBIUS_API_KEY"),
33
  # base_url="https://api.studio.nebius.com/v1/",
 
34
  )
35
  self._generator = ArxivResponseGenerator(
36
  vector_store_path=PROJECT_ROOT_DIR / "db/arxiv_docs.db"
37
  )
38
- # self._arxiv_rag_tool = FunctionTool.from_defaults(
39
- # fn=self._arxiv_rag,
40
- # name="arxiv_rag",
41
- # description="Retrieves arxiv research papers.",
42
- # return_direct=True,
43
- # )
44
  self._duckduckgo_search_tool = [
45
  tool
46
  for tool in DuckDuckGoSearchToolSpec().to_tool_list()
47
  if tool.metadata.name == "duckduckgo_full_search"
48
  ]
49
- # self._arxiv_agent = ReActAgent(
50
- # name="arxiv_agent",
51
- # description="Retrieves information about arxiv research papers",
52
- # system_prompt="You are arxiv research paper agent, who retrieves information "
53
- # "about arxiv research papers.",
54
- # tools=[self._arxiv_rag_tool],
55
- # llm=self.llm,
56
- # )
57
  self._websearch_agent = ReActAgent(
58
  name="web_search",
59
  description="Searches the web",
@@ -63,8 +67,8 @@ class MultiAgentWorkflow:
63
  )
64
 
65
  self._workflow = AgentWorkflow(
66
- agents=[self._websearch_agent],
67
- root_agent="web_search",
68
  timeout=180,
69
  )
70
  # AgentWorkflow.from_tools_or_functions(
@@ -116,29 +120,27 @@ class MultiAgentWorkflow:
116
  """
117
  logger.info("Running multi-agent workflow.")
118
  try:
119
- research_papers = self._arxiv_rag(query=user_query)
120
  user_msg = (
121
- f"search with the web search agent to find any relevant events related to: {user_query}.\n"
122
- f" The web search results relevant to the current year: {date.today().year}. \n"
123
- )
124
- web_search_results = await self._workflow.run(user_msg=user_msg)
125
- final_res = (
126
- research_papers + "\n\n" + web_search_results.response.blocks[0].text
127
  )
 
128
  logger.info("Workflow run completed successfully.")
129
- return final_res
130
  except Exception as err:
131
  logger.error(f"Workflow run failed: {err}")
132
  raise
133
 
134
 
135
- if __name__ == "__main__":
136
- USER_QUERY = "i want to learn more about nlp"
137
- workflow = MultiAgentWorkflow()
138
- logger.info("Starting workflow for user query.")
139
- try:
140
- result = asyncio.run(workflow.run(user_query=USER_QUERY))
141
- logger.info("Workflow finished. Output below:")
142
- print(result)
143
- except Exception as err:
144
- logger.error(f"Error during workflow execution: {err}")
 
1
+ # import asyncio
2
  from datetime import date
3
 
4
+ import nest_asyncio
5
  from llama_index.core.agent.workflow import AgentWorkflow, ReActAgent
6
+ from llama_index.core.tools import FunctionTool
7
  from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
8
  from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
9
 
 
13
  from src.agent_hackathon.generate_arxiv_responses import ArxivResponseGenerator
14
  from src.agent_hackathon.logger import get_logger
15
 
16
+ nest_asyncio.apply()
17
+
18
  # _ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=False), override=True)
19
 
20
  logger = get_logger(log_name="multiagent", log_dir=PROJECT_ROOT_DIR / "logs")
 
32
  # provider="nebius",
33
  temperature=0.1,
34
  top_p=0.95,
 
35
  # api_key=os.getenv(key="NEBIUS_API_KEY"),
36
  # base_url="https://api.studio.nebius.com/v1/",
37
+ system_prompt="Don't just plan, but execute the plan until failure.",
38
  )
39
  self._generator = ArxivResponseGenerator(
40
  vector_store_path=PROJECT_ROOT_DIR / "db/arxiv_docs.db"
41
  )
42
+ self._arxiv_rag_tool = FunctionTool.from_defaults(
43
+ fn=self._arxiv_rag,
44
+ name="arxiv_rag",
45
+ description="Retrieves arxiv research papers.",
46
+ return_direct=False,
47
+ )
48
  self._duckduckgo_search_tool = [
49
  tool
50
  for tool in DuckDuckGoSearchToolSpec().to_tool_list()
51
  if tool.metadata.name == "duckduckgo_full_search"
52
  ]
53
+ self._arxiv_agent = ReActAgent(
54
+ name="arxiv_agent",
55
+ description="Retrieves information about arxiv research papers",
56
+ system_prompt="You are arxiv research paper agent, who retrieves information "
57
+ "about arxiv research papers.",
58
+ tools=[self._arxiv_rag_tool],
59
+ llm=self.llm,
60
+ )
61
  self._websearch_agent = ReActAgent(
62
  name="web_search",
63
  description="Searches the web",
 
67
  )
68
 
69
  self._workflow = AgentWorkflow(
70
+ agents=[self._arxiv_agent, self._websearch_agent],
71
+ root_agent="arxiv_agent",
72
  timeout=180,
73
  )
74
  # AgentWorkflow.from_tools_or_functions(
 
120
  """
121
  logger.info("Running multi-agent workflow.")
122
  try:
 
123
  user_msg = (
124
+ f"First, give me arxiv research papers about: {user_query}."
125
+ f"Then search with web search agent for any events related to : {user_query}.\n"
126
+ f"The web search results should be relevant to the current year: {date.today().year}."
127
+ "Return all the content from all the agents."
 
 
128
  )
129
+ results = await self._workflow.run(user_msg=user_msg)
130
  logger.info("Workflow run completed successfully.")
131
+ return results
132
  except Exception as err:
133
  logger.error(f"Workflow run failed: {err}")
134
  raise
135
 
136
 
137
+ # if __name__ == "__main__":
138
+ # USER_QUERY = "i want to learn more about nlp"
139
+ # workflow = MultiAgentWorkflow()
140
+ # logger.info("Starting workflow for user query.")
141
+ # try:
142
+ # result = asyncio.run(workflow.run(user_query=USER_QUERY))
143
+ # logger.info("Workflow finished. Output below:")
144
+ # print(result)
145
+ # except Exception as err:
146
+ # logger.error(f"Error during workflow execution: {err}")