"""LangGraph Agent"""
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import FAISS
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from transformers import AutoModelForCausalLM, AutoTokenizer, BlipProcessor, BlipForConditionalGeneration
from youtube_transcript_api import YouTubeTranscriptApi
from PIL import Image
import requests
import torch
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
#from load_agent import QAResponder
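# load_dotenv() reads API keys from a local .env file; the providers and the
# Tavily search tool below presumably expect GOOGLE_API_KEY, GROQ_API_KEY,
# and TAVILY_API_KEY to be set there.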
load_dotenv()
# Load QA pairs and compute embeddings once
qa_df = pd.read_csv("statics/qa_pairs.csv")
embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
qa_embeddings = embeddings_model.embed_documents(qa_df["question"].tolist())
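# qa_pairs.csv is assumed to contain "question" and "answer" columns; the
# question embeddings are computed once at import time and reused by the
# qa_reference tool below.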
# Alternative local models that could be passed to LocalChatModel below:
# facebook/blenderbot-400M-distill
# TinyLlama/TinyLlama-1.1B-Chat-v1.0
# gpt2
# mistralai/Mistral-Small-Instruct-2409
class LocalChatModel:
    """Minimal chat wrapper around a local Hugging Face causal LM, exposing
    an invoke(messages) -> AIMessage interface like LangChain chat models."""

    def __init__(self, model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
        print(f"Loading {model_name} on CPU...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.model.eval()

    def invoke(self, messages: list) -> AIMessage:
        # Convert LangChain messages into the role/content dicts expected by
        # the tokenizer's chat template.
        chat = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                chat.append({"role": "system", "content": msg.content})
            elif isinstance(msg, HumanMessage):
                chat.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                chat.append({"role": "assistant", "content": msg.content})
        prompt = self.tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Slice off the prompt tokens before decoding; decoding the full
        # sequence and stripping len(prompt) characters is fragile because
        # decode() does not always reproduce the prompt string exactly.
        generated = outputs[0][inputs["input_ids"].shape[1]:]
        response = self.tokenizer.decode(generated, skip_special_tokens=True).strip()
        return AIMessage(content=response)
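# Example usage (sketch; model weights download on first run):
#   chat = LocalChatModel()
#   reply = chat.invoke([HumanMessage(content="Hello!")])
#   print(reply.content)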
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract second integer from first."""
    return a - b

@tool
def divide(a: int, b: int) -> float:
    """Divide first integer by second. Raises error if divisor is zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus (remainder) of first integer divided by second."""
    return a % b
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return formatted results."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    return "\n\n---\n\n".join([doc.page_content for doc in search_docs])
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return formatted results."""
    # TavilySearchResults.invoke() takes the query as its single input and
    # returns a list of dicts with "url" and "content" keys, not Documents.
    results = TavilySearchResults(max_results=3).invoke(query)
    return "\n\n---\n\n".join([res["content"] for res in results])
@tool
def arxiv_search(query: str) -> str:
    """Search arXiv for a query and return formatted results."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    return "\n\n---\n\n".join([doc.page_content[:1000] for doc in search_docs])
@tool
def youtube_summary(video_url: str) -> str:
    """Fetch a YouTube video's transcript (if available) and return up to 3000 characters."""
    import re
    # A variable-width lookbehind is invalid in Python's re module, so capture
    # the video ID with a non-capturing group instead.
    match = re.search(r"(?:v=|youtu\.be/)([^&#?]+)", video_url)
    if not match:
        return "Invalid YouTube URL."
    video_id = match.group(1)
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
        return " ".join([seg["text"] for seg in transcript])[:3000]
    except Exception as e:
        return f"Transcript not available or error: {e}"
@tool
def image_caption(image_url: str) -> str:
    """Generate a description of an image from a public URL."""
    # Note: the BLIP processor and model are reloaded on every call; caching
    # them at module level would avoid repeated loading.
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
    inputs = processor(image, return_tensors="pt")
    out = model.generate(**inputs)
    return processor.decode(out[0], skip_special_tokens=True)
@tool
def qa_reference(query: str) -> str:
    """Search example QA dataset for similar questions and return the closest answer."""
    query_embedding = embeddings_model.embed_query(query)
    sims = cosine_similarity([query_embedding], qa_embeddings)[0]
    top_idx = int(np.argmax(sims))
    return f"Similar question: {qa_df.question[top_idx]}\nAnswer: {qa_df.answer[top_idx]}"
with open("statics/system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
sys_msg = SystemMessage(content=system_prompt)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
vector_store = FAISS.from_texts(["Sample text 1", "Sample text 2"], embedding=embeddings)
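# Note: vector_store (and the create_retriever_tool import above) look like a
# placeholder for a retrieval tool; the store is not wired into the graph below.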
tools = [
multiply,
add,
subtract,
divide,
modulus,
wiki_search,
web_search,
arxiv_search,
youtube_summary,
image_caption,
qa_reference,
]
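# The tools above are bound to the chat model inside build_graph for providers
# that support tool calling (see below).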
def build_graph(provider: str = "huggingface"):
    """Build the LangGraph agent graph for the chosen LLM provider."""
    if provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    elif provider == "huggingface":
        llm = LocalChatModel()
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    # Bind the tools for providers that support tool calling; LocalChatModel
    # has no bind_tools, so the local model answers directly and never routes
    # to the tool node.
    if provider in ("google", "groq"):
        llm = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        # Prepend the system prompt so every turn sees the instructions.
        return {"messages": [llm.invoke([sys_msg] + state["messages"])]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
if __name__ == "__main__":
    question = "Describe this image: https://example.com/sample.jpg"
    graph = build_graph(provider="huggingface")
    messages = [HumanMessage(content=question)]
    messages = graph.invoke({"messages": messages})
    for m in messages["messages"]:
        m.pretty_print()