File size: 1,681 Bytes
1310897 1d755bf 420e08f 1d755bf 420e08f 1d755bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from .ai_tools import Calculator, DocRetriever, WebSearcher
from .graph import GaiaGraph
class GaiaAgent:
    """Tool-using question-answering agent backed by a Hugging Face model.

    On construction it loads a causal language model and its tokenizer,
    wraps them in a ``text-generation`` pipeline, instantiates the
    calculator, document-retriever and web-searcher tools, and passes all
    of them to a ``GaiaGraph``. Calling the instance with a question runs
    that question through ``self.graph.run``.
    """

    def __init__(self, model_name="HuggingFaceH4/zephyr-7b-beta"):
        """Load the model, create the tools, and build the workflow graph.

        Args:
            model_name: Model identifier passed to both
                ``AutoTokenizer.from_pretrained`` and
                ``AutoModelForCausalLM.from_pretrained``.
        """
        self.model_name = model_name
        # NOTE(review): no dtype/device arguments are passed, so transformers'
        # loading defaults apply; for a 7B model this may need a lot of
        # memory — confirm that is intended.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.llm_pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

        # Tools are also kept as individual attributes so that
        # load_document() can reach the retriever directly.
        self.calculator = Calculator()
        self.doc_retriever = DocRetriever()
        self.web_searcher = WebSearcher()
        self.tools = [
            self.calculator,
            self.web_searcher,
            self.doc_retriever,
        ]

        # Build the LangGraph workflow over the pipeline, tokenizer and tools.
        self.graph = GaiaGraph(
            model=self.llm_pipeline,
            tokenizer=self.tokenizer,
            tools=self.tools,
        )
        print(f"GaiaAgent initialized with model: {model_name}")

    @staticmethod
    def _preview(value, limit=50):
        """Return ``str(value)`` cut to *limit* characters, plus '...' if cut.

        Used only for console echo. It accepts any object rather than
        requiring a ``str``, so a non-string question or graph result
        cannot make the diagnostic print raise.
        """
        text = str(value)
        return text if len(text) <= limit else f"{text[:limit]}..."

    def load_document(self, document_text: str):
        """Load document content for retrieval.

        Args:
            document_text: Text forwarded to
                ``self.doc_retriever.load_document``; its length is echoed
                to stdout afterwards.
        """
        self.doc_retriever.load_document(document_text)
        print(f"Document loaded ({len(document_text)} characters)")

    def __call__(self, question: str) -> str:
        """Answer *question* by running it through the workflow graph.

        Args:
            question: The question to answer.

        Returns:
            Exactly what ``self.graph.run(question)`` returns (annotated as
            ``str``); the value is passed through unchanged.
        """
        print(f"Agent received question: {self._preview(question)}")
        result = self.graph.run(question)
        # Preview via str() so an unexpected non-str result is still returned
        # to the caller instead of crashing inside this log line.
        print(f"Agent returning answer: {self._preview(result)}")
        return result