"""LangGraph Agent""" | |
import os | |
from dotenv import load_dotenv | |
from langgraph.graph import START, StateGraph, MessagesState | |
from langgraph.prebuilt import tools_condition | |
from langgraph.prebuilt import ToolNode | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain_groq import ChatGroq | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_community.tools.tavily_search import TavilySearchResults | |
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader | |
from langchain_community.vectorstores import FAISS | |
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage | |
from langchain_core.tools import tool | |
from langchain.tools.retriever import create_retriever_tool | |
from transformers import AutoModelForCausalLM, AutoTokenizer, BlipProcessor, BlipForConditionalGeneration | |
from youtube_transcript_api import YouTubeTranscriptApi | |
from PIL import Image | |
import requests | |
import torch | |
import pandas as pd | |
import numpy as np | |
from sklearn.metrics.pairwise import cosine_similarity | |
#from load_agent import QAResponder | |
load_dotenv() | |
# Load QA pairs and compute their embeddings once at import time
qa_df = pd.read_csv("statics/qa_pairs.csv")
embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
qa_embeddings = embeddings_model.embed_documents(qa_df["question"].tolist())

# Candidate local chat models:
#   facebook/blenderbot-400M-distill
#   TinyLlama/TinyLlama-1.1B-Chat-v1.0
#   gpt2
#   mistralai/Mistral-Small-Instruct-2409

class LocalChatModel:
    """Minimal chat wrapper around a local Hugging Face causal LM (CPU only)."""

    def __init__(self, model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
        print(f"Loading {model_name} on CPU...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.model.eval()

    def invoke(self, messages: list) -> AIMessage:
        # Convert LangChain messages into the chat format expected by the tokenizer
        chat = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                chat.append({"role": "system", "content": msg.content})
            elif isinstance(msg, HumanMessage):
                chat.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                chat.append({"role": "assistant", "content": msg.content})
        prompt = self.tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True
        )
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                pad_token_id=self.tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens; slicing the decoded string by the
        # prompt length is unreliable once special tokens are stripped
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        return AIMessage(content=response)

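# Illustrative usage of LocalChatModel (not part of the agent flow; the message
# contents below are made up for the example):
#   local_llm = LocalChatModel()
#   reply = local_llm.invoke([
#       SystemMessage(content="You are a concise assistant."),
#       HumanMessage(content="What is 2 + 2?"),
#   ])
#   print(reply.content)
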
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first."""
    return a - b

@tool
def divide(a: int, b: int) -> float:
    """Divide the first integer by the second. Raises an error if the divisor is zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of the first integer divided by the second."""
    return a % b

@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return formatted results."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    return "\n\n---\n\n".join([doc.page_content for doc in search_docs])

@tool
def web_search(query: str) -> str:
    """Search the web via Tavily for a query and return formatted results."""
    # TavilySearchResults returns a list of dicts with "content" and "url" keys
    search_docs = TavilySearchResults(max_results=3).invoke(query)
    return "\n\n---\n\n".join([doc["content"] for doc in search_docs])

@tool
def arxiv_search(query: str) -> str:
    """Search arXiv for a query and return formatted results (first 1000 characters per paper)."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    return "\n\n---\n\n".join([doc.page_content[:1000] for doc in search_docs])

@tool
def youtube_summary(video_url: str) -> str:
    """Fetch the transcript of a YouTube video (if available) and return its first 3000 characters."""
    import re
    # Handle both "watch?v=<id>" and "youtu.be/<id>" URL forms; a capture group is used
    # because Python's re module does not allow variable-width lookbehind
    match = re.search(r"(?:v=|youtu\.be/)([^&#]+)", video_url)
    if not match:
        return "Invalid YouTube URL."
    video_id = match.group(1)
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
        return " ".join([seg["text"] for seg in transcript])[:3000]
    except Exception as e:
        return f"Transcript not available or error: {e}"

@tool
def image_caption(image_url: str) -> str:
    """Generate a description of an image from a public URL."""
    # The BLIP processor and model are loaded on every call; slow, but keeps startup light
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
    inputs = processor(image, return_tensors="pt")
    out = model.generate(**inputs)
    return processor.decode(out[0], skip_special_tokens=True)

@tool
def qa_reference(query: str) -> str:
    """Search the example QA dataset for similar questions and return the closest answer."""
    query_embedding = embeddings_model.embed_query(query)
    sims = cosine_similarity([query_embedding], qa_embeddings)[0]
    top_idx = int(np.argmax(sims))
    return f"Similar question: {qa_df.question[top_idx]}\nAnswer: {qa_df.answer[top_idx]}"

with open("statics/system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
sys_msg = SystemMessage(content=system_prompt)

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
vector_store = FAISS.from_texts(["Sample text 1", "Sample text 2"], embedding=embeddings)

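# create_retriever_tool is imported above but never wired in; this is a minimal sketch of
# how the sample FAISS store could be exposed as a tool. The name and description are
# illustrative, not from the original code, and the tool is intentionally not added to
# the tools list below.
retriever_tool = create_retriever_tool(
    vector_store.as_retriever(),
    name="sample_retriever",
    description="Look up snippets from the sample vector store.",
)
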
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arxiv_search,
    youtube_summary,
    image_caption,
    qa_reference,
]

def build_graph(provider: str = "huggingface"):
    """Build the LangGraph agent graph for the chosen LLM provider."""
    if provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    elif provider == "huggingface":
        llm = LocalChatModel()
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    # Bind the tools to the API-backed models so tools_condition can route tool calls;
    # the local model has no native tool-calling support, so it is used as-is.
    if provider in ("google", "groq"):
        llm = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        # Prepend the system prompt so every turn sees the agent instructions
        return {"messages": [llm.invoke([sys_msg] + state["messages"])]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()

if __name__ == "__main__":
    question = "Describe this image: https://example.com/sample.jpg"
    graph = build_graph(provider="huggingface")
    messages = [HumanMessage(content=question)]
    result = graph.invoke({"messages": messages})
    for m in result["messages"]:
        m.pretty_print()