|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
from .ai_tools import Calculator, DocRetriever, WebSearcher |
|
from .graph import GaiaGraph |
|
|
|
class GaiaAgent:
    """Question-answering agent that routes queries through a GaiaGraph.

    Wraps a Hugging Face causal LM in a text-generation pipeline and wires
    it together with three tools (calculator, document retriever, web
    searcher) that the graph can dispatch to.
    """

    def __init__(self, model_name="HuggingFaceH4/zephyr-7b-beta"):
        """Load the model and tokenizer, build the tools, assemble the graph.

        Args:
            model_name: Hugging Face Hub identifier of the causal LM to load.
        """
        self.model_name = model_name

        # Language-model stack: tokenizer + model wrapped in a pipeline.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.llm_pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

        # Tool instances exposed to the graph (order of `tools` preserved
        # from the original wiring: calculator, searcher, retriever).
        self.calculator = Calculator()
        self.doc_retriever = DocRetriever()
        self.web_searcher = WebSearcher()
        self.tools = [self.calculator, self.web_searcher, self.doc_retriever]

        # Execution graph orchestrating LLM calls and tool use.
        self.graph = GaiaGraph(
            model=self.llm_pipeline,
            tokenizer=self.tokenizer,
            tools=self.tools,
        )

        print(f"GaiaAgent initialized with model: {model_name}")

    def load_document(self, document_text: str):
        """Hand *document_text* to the retriever so later queries can use it."""
        self.doc_retriever.load_document(document_text)
        print(f"Document loaded ({len(document_text)} characters)")

    def __call__(self, question: str) -> str:
        """Answer *question* by running it through the graph and return the result."""
        # Log a 50-character preview with an ellipsis only when truncated.
        q_suffix = '...' if len(question) > 50 else ''
        print(f"Agent received question: {question[:50]}{q_suffix}")

        answer = self.graph.run(question)

        a_suffix = '...' if len(answer) > 50 else ''
        print(f"Agent returning answer: {answer[:50]}{a_suffix}")
        return answer