# (Hosting-page header residue, kept as comments so the module parses.)
# superone001 · "Update agent.py" · commit 1d755bf (verified) · 1.68 kB
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from .ai_tools import Calculator, DocRetriever, WebSearcher
from .graph import GaiaGraph
class GaiaAgent:
    """Question-answering agent built on a causal LM and a GaiaGraph workflow.

    On construction, the tokenizer and model named by ``model_name`` are loaded
    with ``from_pretrained`` and wrapped in a ``"text-generation"`` pipeline.
    That pipeline, the tokenizer, and a tool list (calculator, web searcher,
    document retriever) are passed to ``GaiaGraph``. Calling the agent with a
    question delegates to ``GaiaGraph.run``.
    """

    def __init__(self, model_name="HuggingFaceH4/zephyr-7b-beta"):
        """Load the model and tools, then build the workflow graph.

        Args:
            model_name: Identifier passed to both
                ``AutoTokenizer.from_pretrained`` and
                ``AutoModelForCausalLM.from_pretrained``.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.llm_pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

        # Initialize tools. The retriever is also kept as an attribute so that
        # load_document() can feed text to the same instance the graph uses.
        self.calculator = Calculator()
        self.doc_retriever = DocRetriever()
        self.web_searcher = WebSearcher()

        # Create tool list.
        # NOTE(review): whether GaiaGraph depends on this order is not visible here.
        self.tools = [
            self.calculator,
            self.web_searcher,
            self.doc_retriever,
        ]

        # Build LangGraph workflow
        self.graph = GaiaGraph(
            model=self.llm_pipeline,
            tokenizer=self.tokenizer,
            tools=self.tools,
        )
        print(f"GaiaAgent initialized with model: {model_name}")

    @staticmethod
    def _preview(value, limit=50) -> str:
        """Return a short, printable preview of ``value`` for console output.

        The value is converted with ``str()`` first. That way a non-string
        (e.g. ``None`` from the graph) cannot make a diagnostic ``print`` raise.
        If the text is longer than ``limit`` characters, it is cut to ``limit``
        characters and ``"..."`` is appended.
        """
        text = str(value)
        if len(text) > limit:
            return f"{text[:limit]}..."
        return text

    def load_document(self, document_text: str):
        """Load document content for retrieval.

        Delegates to ``load_document`` on the retriever that was included in
        the graph's tool list, then prints the document's character count.
        """
        self.doc_retriever.load_document(document_text)
        print(f"Document loaded ({len(document_text)} characters)")

    def __call__(self, question: str) -> str:
        """Answer ``question`` by running it through the workflow graph.

        Prints truncated previews of the question and of the answer. Returns
        whatever ``self.graph.run`` produced, unchanged.
        """
        print(f"Agent received question: {self._preview(question)}")
        result = self.graph.run(question)
        print(f"Agent returning answer: {self._preview(result)}")
        return result