Pulkit-bristol committed
Commit 8aa0b4b · 1 Parent(s): 9057373

second try
Files changed (2)
  1. agent/agent_1.py +60 -58
  2. requirements.txt +2 -1
agent/agent_1.py CHANGED
@@ -8,21 +8,30 @@ from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_groq import ChatGroq
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_community.tools.tavily_search import TavilySearchResults
-from langchain_community.document_loaders import WikipediaLoader
-from langchain_community.document_loaders import ArxivLoader
+from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
 from langchain_community.vectorstores import FAISS
 from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
 from langchain_core.tools import tool
 from langchain.tools.retriever import create_retriever_tool
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BlipProcessor, BlipForConditionalGeneration
+from youtube_transcript_api import YouTubeTranscriptApi
+from PIL import Image
+import requests
 import torch
+import pandas as pd
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
 
 load_dotenv()
 
+# Load QA pairs and compute embeddings once
+qa_df = pd.read_csv("/statics/qa_pairs.csv")
+embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+qa_embeddings = embeddings_model.embed_documents(qa_df["question"].tolist())
+
 class LocalChatModel:
-    # mistralai/Mistral-7B-Instruct-v0.3 or TinyLlama/TinyLlama-1.1B-Chat-v1.0
-    def __init__(self, model_name="mistralai/Mistral-7B-Instruct-v0.3"):
-        print("Loading LLM on CPU...")
+    def __init__(self, model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
+        print(f"Loading {model_name} on CPU...")
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForCausalLM.from_pretrained(model_name)
         self.model.eval()
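The module-level block added above embeds a QA CSV at import time. Below is a minimal sketch of a compatible file, with invented rows and a relative statics/ path assumed here (the commit itself reads "/statics/qa_pairs.csv" as written):

```python
# Illustrative only: a tiny qa_pairs.csv with the two columns the new
# qa_df / qa_embeddings setup expects. Rows are made up; the commit
# reads "/statics/qa_pairs.csv" (leading slash as written in the diff).
import os
import pandas as pd

os.makedirs("statics", exist_ok=True)
pd.DataFrame({
    "question": ["What is 2 + 2?", "Who wrote Hamlet?"],
    "answer": ["4", "William Shakespeare"],
}).to_csv("statics/qa_pairs.csv", index=False)
```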
@@ -47,7 +56,7 @@ class LocalChatModel:
         with torch.no_grad():
             outputs = self.model.generate(
                 **inputs,
-                max_new_tokens=256,
+                max_new_tokens=512,
                 do_sample=True,
                 temperature=0.7,
                 pad_token_id=self.tokenizer.eos_token_id
@@ -58,25 +67,21 @@ class LocalChatModel:
         return AIMessage(content=response)
 
 @tool
-
 def multiply(a: int, b: int) -> int:
     """Multiply two integers."""
     return a * b
 
 @tool
-
 def add(a: int, b: int) -> int:
     """Add two integers."""
     return a + b
 
 @tool
-
 def subtract(a: int, b: int) -> int:
     """Subtract second integer from first."""
     return a - b
 
 @tool
-
 def divide(a: int, b: int) -> float:
     """Divide first integer by second. Raises error if divisor is zero."""
     if b == 0:
@@ -84,61 +89,67 @@ def divide(a: int, b: int) -> float:
     return a / b
 
 @tool
-
 def modulus(a: int, b: int) -> int:
     """Get the modulus (remainder) of first integer divided by second."""
     return a % b
 
 @tool
-
 def wiki_search(query: str) -> str:
     """Search Wikipedia for a query and return formatted results."""
     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-    formatted_search_docs = "\n\n---\n\n".join(
-        [
-            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-            for doc in search_docs
-        ])
-    return {"wiki_results": formatted_search_docs}
+    return "\n\n---\n\n".join([doc.page_content for doc in search_docs])
 
 @tool
-
 def web_search(query: str) -> str:
     """Search Tavily for a query and return formatted results."""
     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
-    formatted_search_docs = "\n\n---\n\n".join(
-        [
-            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-            for doc in search_docs
-        ])
-    return {"web_results": formatted_search_docs}
+    return "\n\n---\n\n".join([doc.page_content for doc in search_docs])
 
 @tool
-
 def arvix_search(query: str) -> str:
     """Search Arxiv for a query and return formatted results."""
     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
-    formatted_search_docs = "\n\n---\n\n".join(
-        [
-            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
-            for doc in search_docs
-        ])
-    return {"arvix_results": formatted_search_docs}
+    return "\n\n---\n\n".join([doc.page_content[:1000] for doc in search_docs])
-
-#dir = os.getcwd()
-#print(dir.rsplit('/')[:-1])
-with open("statics/system_prompt.txt", "r", encoding="utf-8") as f:
+
+@tool
+def youtube_summary(video_url: str) -> str:
+    """Fetch and summarize a YouTube video using transcript (if available)."""
+    import re
+    match = re.search(r"(?<=v=|youtu.be/)[^&#]+", video_url)
+    if not match:
+        return "Invalid YouTube URL."
+    video_id = match.group()
+    try:
+        transcript = YouTubeTranscriptApi.get_transcript(video_id)
+        return " ".join([seg["text"] for seg in transcript])[:3000]
+    except Exception as e:
+        return f"Transcript not available or error: {e}"
+
+@tool
+def image_caption(image_url: str) -> str:
+    """Generate a description of an image from a public URL."""
+    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+    image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
+    inputs = processor(image, return_tensors="pt")
+    out = model.generate(**inputs)
+    return processor.decode(out[0], skip_special_tokens=True)
+
+@tool
+def qa_reference(query: str) -> str:
+    """Search example QA dataset for similar questions and return the closest answer."""
+    query_embedding = embeddings_model.embed_query(query)
+    sims = cosine_similarity([query_embedding], qa_embeddings)[0]
+    top_idx = int(np.argmax(sims))
+    return f"Similar question: {qa_df.question[top_idx]}\nAnswer: {qa_df.answer[top_idx]}"
+
+with open("system_prompt.txt", "r", encoding="utf-8") as f:
     system_prompt = f.read()
 
 sys_msg = SystemMessage(content=system_prompt)
 
 embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
 vector_store = FAISS.from_texts(["Sample text 1", "Sample text 2"], embedding=embeddings)
-create_retriever_tool = create_retriever_tool(
-    retriever=vector_store.as_retriever(),
-    name="Question Search",
-    description="A tool to retrieve similar questions from a vector store."
-)
 
 tools = [
     multiply,
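The tools added in this hunk can be exercised directly, outside the graph. A minimal sketch, assuming agent/agent_1.py imports cleanly (models downloadable, the QA CSV in place at the path the module reads); the import path and image URL are placeholders, not part of the commit:

```python
# Sketch: direct tool calls via the standard LangChain .invoke entry point.
from agent.agent_1 import image_caption, qa_reference

print(qa_reference.invoke("What is 2 + 2?"))                   # nearest stored QA pair
print(image_caption.invoke("https://example.com/sample.jpg"))  # placeholder URL
```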
@@ -149,6 +160,9 @@ tools = [
     wiki_search,
     web_search,
     arvix_search,
+    youtube_summary,
+    image_caption,
+    qa_reference,
 ]
 
 def build_graph(provider: str = "huggingface"):
@@ -157,36 +171,24 @@ def build_graph(provider: str = "huggingface"):
     elif provider == "groq":
         llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
     elif provider == "huggingface":
-        llm = LocalChatModel(model_name="mistralai/Mistral-7B-Instruct-v0.3")
+        llm = LocalChatModel()
     else:
         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
 
     def assistant(state: MessagesState):
         return {"messages": [llm.invoke(state["messages"])]}
 
-    def retriever(state: MessagesState):
-        similar_question = vector_store.similarity_search(state["messages"][0].content)
-        example_msg = HumanMessage(
-            content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
-        )
-        return {"messages": [sys_msg] + state["messages"] + [example_msg]}
-
     builder = StateGraph(MessagesState)
-    builder.add_node("retriever", retriever)
     builder.add_node("assistant", assistant)
     builder.add_node("tools", ToolNode(tools))
-    builder.add_edge(START, "retriever")
-    builder.add_edge("retriever", "assistant")
-    builder.add_conditional_edges(
-        "assistant",
-        tools_condition,
-    )
+    builder.add_edge(START, "assistant")
+    builder.add_conditional_edges("assistant", tools_condition)
     builder.add_edge("tools", "assistant")
 
     return builder.compile()
 
 if __name__ == "__main__":
-    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
+    question = "Describe this image: https://example.com/sample.jpg"
     graph = build_graph(provider="huggingface")
     messages = [HumanMessage(content=question)]
     messages = graph.invoke({"messages": messages})
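With the retriever node removed, the compiled graph is a plain assistant-to-tools loop: START goes straight to the assistant, and tools_condition routes to the tool node only when the model emits tool calls. A minimal run sketch, assuming the module imports cleanly; the Groq provider is used here for illustration (requires GROQ_API_KEY in the environment):

```python
# Sketch: running the rewired graph end to end. Provider choice and the
# question are illustrative, not part of the commit.
from langchain_core.messages import HumanMessage
from agent.agent_1 import build_graph

graph = build_graph(provider="groq")
result = graph.invoke({"messages": [HumanMessage(content="What is 6 * 7?")]})
print(result["messages"][-1].content)  # final assistant reply
```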
 
requirements.txt CHANGED
@@ -17,4 +17,5 @@ wikipedia
 pgvector
 python-dotenv
 faiss-cpu
-sentencepiece
+sentencepiece
+youtube-transcript-api
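A quick sanity check that the dependencies touched here resolve, illustrative only:

```python
# Smoke test: both packages import after `pip install -r requirements.txt`.
import sentencepiece  # noqa: F401
from youtube_transcript_api import YouTubeTranscriptApi  # noqa: F401

print("requirements import cleanly")
```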