shamik committed on
Commit
2091d19
·
unverified ·
1 Parent(s): a8c86eb

feat: adding multiagent script.

Browse files
Files changed (1) hide show
  1. src/agent_hackathon/multiagent.py +146 -0
src/agent_hackathon/multiagent.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ from datetime import date
4
+
5
+ from consts import PROJECT_ROOT_DIR
6
+
7
+ # from dotenv import find_dotenv, load_dotenv
8
+ from generate_arxiv_responses import ArxivResponseGenerator
9
+ from llama_index.core.agent.workflow import AgentWorkflow, ReActAgent
10
+ from llama_index.core.tools import FunctionTool
11
+ from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
12
+ from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
13
+
14
+ from src.agent_hackathon.logger import get_logger
15
+
16
+ # _ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=False), override=True)
17
+
18
+ logger = get_logger(log_name="multiagent", log_dir=PROJECT_ROOT_DIR / "logs")
19
+
20
+
21
class MultiAgentWorkflow:
    """Multi-agent workflow for retrieving research papers and related events.

    The workflow combines two sources for a single user query:

    * arXiv papers, fetched through ``ArxivResponseGenerator``;
    * web search results, produced by a ReAct agent that is given the
      ``duckduckgo_full_search`` tool and run inside an ``AgentWorkflow``.
    """

    def __init__(self) -> None:
        """Initialize the workflow with LLM, tools, and generator."""
        logger.info("Initializing MultiAgentWorkflow.")
        self.llm = HuggingFaceInferenceAPI(
            model="meta-llama/Llama-3.3-70B-Instruct",
            provider="auto",
            # provider="nebius",
            temperature=0.1,
            top_p=0.95,
            max_tokens=8192,
            # api_key=os.getenv(key="NEBIUS_API_KEY"),
            # base_url="https://api.studio.nebius.com/v1/",
        )
        self._generator = ArxivResponseGenerator(
            vector_store_path=PROJECT_ROOT_DIR / "db/arxiv_docs.db"
        )
        # Disabled: exposing the arXiv retrieval as an agent tool. The papers
        # are currently fetched directly in ``run`` instead.
        # self._arxiv_rag_tool = FunctionTool.from_defaults(
        #     fn=self._arxiv_rag,
        #     name="arxiv_rag",
        #     description="Retrieves arxiv research papers.",
        #     return_direct=True,
        # )

        # Keep only the full-search tool out of everything the DuckDuckGo
        # tool spec exposes; the result is a list (possibly empty if the tool
        # name is not present in the installed version).
        self._duckduckgo_search_tool = [
            tool
            for tool in DuckDuckGoSearchToolSpec().to_tool_list()
            if tool.metadata.name == "duckduckgo_full_search"
        ]
        # Disabled: dedicated arXiv agent (depends on the tool disabled above).
        # self._arxiv_agent = ReActAgent(
        #     name="arxiv_agent",
        #     description="Retrieves information about arxiv research papers",
        #     system_prompt="You are arxiv research paper agent, who retrieves information "
        #     "about arxiv research papers.",
        #     tools=[self._arxiv_rag_tool],
        #     llm=self.llm,
        # )
        self._websearch_agent = ReActAgent(
            name="web_search",
            description="Searches the web",
            system_prompt="You are search engine who searches the web using duckduckgo tool",
            tools=self._duckduckgo_search_tool,
            llm=self.llm,
        )

        # Single-agent workflow for now; ``root_agent`` must match the agent's
        # ``name`` given above.
        self._workflow = AgentWorkflow(
            agents=[self._websearch_agent],
            root_agent="web_search",
            timeout=180,
        )
        # Disabled alternative: build the workflow straight from the tools.
        # AgentWorkflow.from_tools_or_functions(
        #     tools_or_functions=self._duckduckgo_search_tool,
        #     llm=self.llm,
        #     system_prompt="You are an expert that "
        #     "searches for any corresponding events related to the "
        #     "user query "
        #     "using the duckduckgo_search_tool and returns the final results." \
        #     "Don't return the steps but execute the necessary tools that you have " \
        #     "access to and return the results.",
        #     timeout=180,
        # )

        logger.info("MultiAgentWorkflow initialized.")

    def _arxiv_rag(self, query: str) -> str:
        """Retrieve research papers from arXiv based on the query.

        Thin wrapper that forwards to
        ``ArxivResponseGenerator.retrieve_arxiv_papers``.

        Args:
            query (str): The search query.

        Returns:
            str: Retrieved research papers as a string.
        """
        return self._generator.retrieve_arxiv_papers(query=query)

    def _clean_response(self, result: str) -> str:
        """Remove the leading ``<think>...</think>`` section from a response.

        Everything up to and including the first ``</think>`` tag is dropped.
        A response without that tag is returned unchanged.

        Args:
            result (str): The result, possibly with <think></think> content.

        Returns:
            str: The result without the <think></think> content.
        """
        # ``str.find`` returns -1 (truthy) when the tag is missing and 0
        # (falsy) when it is at the very start, so it must not be used as a
        # boolean. ``partition`` reports explicitly whether the tag was found.
        _, closing_tag, remainder = result.partition("</think>")
        return remainder if closing_tag else result

    async def run(self, user_query: str) -> str:
        """Run the multi-agent workflow for a given user query.

        The arXiv papers are retrieved first (a plain synchronous call), then
        the web-search workflow is awaited, and both texts are joined with a
        blank line between them.

        Args:
            user_query (str): The user's search query.

        Returns:
            str: The arXiv results followed by the web search results.

        Raises:
            Exception: Any error raised by retrieval or by the workflow is
                logged and re-raised unchanged.
        """
        logger.info("Running multi-agent workflow.")
        try:
            research_papers = self._arxiv_rag(query=user_query)
            user_msg = (
                f"search with the web search agent to find any relevant events related to: {user_query}.\n"
                f" The web search results relevant to the current year: {date.today().year}. \n"
            )
            web_search_results = await self._workflow.run(user_msg=user_msg)
            # Only the first content block of the agent's reply is used.
            final_res = (
                research_papers + "\n\n" + web_search_results.response.blocks[0].text
            )
            logger.info("Workflow run completed successfully.")
            return final_res
        except Exception as err:
            logger.error(f"Workflow run failed: {err}")
            raise
135
+
136
+
137
if __name__ == "__main__":
    # Manual smoke run: build the workflow, execute it once for a fixed
    # sample query and print the combined output.
    USER_QUERY = "i want to learn more about nlp"
    workflow = MultiAgentWorkflow()
    logger.info("Starting workflow for user query.")
    try:
        pending_run = workflow.run(user_query=USER_QUERY)
        result = asyncio.run(pending_run)
        logger.info("Workflow finished. Output below:")
        print(result)
    except Exception as err:
        # Failures are only logged here; the script does not re-raise.
        logger.error(f"Error during workflow execution: {err}")