# agent.py - Enhanced LLaMA model wrapper and LangChain agent support
import os
from typing import Optional, Dict, List

# Guard the llama_cpp import so a missing or broken native build is reported
# instead of crashing the whole module at import time.
try:
    from llama_cpp import Llama
    llama_cpp_available = True
except Exception as e:
    print("❌ Failed to load llama_cpp:", e)
    Llama = None  # keeps the module importable; model classes will fail fast if used
    llama_cpp_available = False

from app.chat_memory import PersistentMemory as Memory
from app.embeddings import DocStore
from app.tools import get_tools
from app.langchain_agent import create_langchain_agent
from app.model_utils import download_model_if_missing, list_available_models
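
# The app.* helpers imported above are assumed to expose the following minimal
# interfaces (inferred from how they are used in this module; see the actual
# implementations under app/ for details):
#   Memory (PersistentMemory): .get_last() -> str, .add(prompt, answer), .clear()
#   DocStore:                  .retrieve(question) -> (source_metadata, text_chunk)
#   get_tools():               returns the LangChain tools used by the agent
#   create_langchain_agent(tools, memory): returns an executor exposing .run(message)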

# ===============================
# Configuration & Utilities
# ===============================
MODEL_DIR = os.getenv("MODEL_DIR", "models")

# Prefer an explicit MODEL_PATH; otherwise try to download (or locate) the
# default model, falling back to the expected local filename. Chaining with
# `or` ensures the download helper only runs when MODEL_PATH is unset.
DEFAULT_MODEL_PATH = (
    os.getenv("MODEL_PATH")
    or download_model_if_missing()
    or os.path.join(MODEL_DIR, "capybarahermes-2.5-mistral-7b.Q5_K_S.gguf")
)
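
# Both values can be overridden via environment variables before this module
# is imported, e.g. (hypothetical paths):
#   export MODEL_DIR=/data/models
#   export MODEL_PATH=/data/models/my-model.Q4_K_M.gguf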


def list_models() -> List[str]:
    """List available .gguf models in the model directory."""
    if not os.path.exists(MODEL_DIR):
        return []
    return [f for f in os.listdir(MODEL_DIR) if f.endswith(".gguf")]


def set_model_path(name: str) -> str:
    """Build and verify the full path to a model file."""
    path = os.path.join(MODEL_DIR, name)
    if not os.path.exists(path):
        raise FileNotFoundError(f"⚠️ Model not found: {path}")
    return path
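

# A minimal usage sketch for the helpers above (illustrative; assumes at least
# one .gguf file is present under MODEL_DIR):
def _pick_first_model() -> Optional[str]:
    """Return the full path of the first available .gguf model, or None."""
    models = list_models()
    return set_model_path(models[0]) if models else None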


# ===============================
# Core Local LLaMA Wrapper Class
# ===============================
class LocalLLMAgent:
    """Local LLaMA wrapper with persistent memory and optional document Q&A."""

    def __init__(self, model_path: str = DEFAULT_MODEL_PATH, docstore: Optional[DocStore] = None):
        self.model_path = model_path
        self.llm = self._load_llm()
        self.mem = Memory()
        self.docs = docstore

    def _load_llm(self) -> Llama:
        """Initialize and return the LLaMA model."""
        return Llama(
            model_path=self.model_path,
            n_ctx=2048,
            n_threads=8,
            n_gpu_layers=40,
            verbose=False,
        )

    def chat(self, prompt: str) -> str:
        """Chat with context-aware memory."""
        ctx = self.mem.get_last()
        full_prompt = f"{ctx}\nUser: {prompt}\nAI:"
        response = self.llm(full_prompt, max_tokens=256, stop=["User:", "\n"])
        answer = response["choices"][0]["text"].strip()
        self.mem.add(prompt, answer)
        return answer

    def ask(self, question: str) -> str:
        """Simple Q&A without memory."""
        response = self.llm(f"Q: {question}\nA:", max_tokens=256, stop=["Q:", "\n"])
        return response["choices"][0]["text"].strip()

    def ask_doc(self, question: str) -> Dict[str, str]:
        """Ask a question against the document store."""
        if not self.docs:
            raise ValueError("❌ Document store not initialized.")
        meta, chunk = self.docs.retrieve(question)
        context = f"Relevant content:\n{chunk}\nQuestion: {question}\nAnswer:"
        response = self.llm(context, max_tokens=256, stop=["Question:", "\n"])
        return {
            "source": meta,
            "answer": response["choices"][0]["text"].strip(),
        }

    def reset_memory(self):
        """Clear memory context."""
        self.mem.clear()

    def switch_model(self, model_name: str):
        """Dynamically switch the model being used."""
        self.model_path = set_model_path(model_name)
        self.llm = self._load_llm()
        print(f"✅ Model switched to {model_name}")


# ===============================
# Lightweight One-Shot Chat
# ===============================
_basic_llm: Optional[Llama] = None


def _get_basic_llm() -> Llama:
    """Lazily load the shared one-shot model so importing this module stays cheap."""
    global _basic_llm
    if _basic_llm is None:
        _basic_llm = Llama(
            model_path=DEFAULT_MODEL_PATH,
            n_ctx=2048,
            n_threads=8,
            n_gpu_layers=40,
            verbose=False,
        )
    return _basic_llm


def local_llm_chat(prompt: str) -> str:
    """Simple one-shot LLaMA call without memory."""
    llm = _get_basic_llm()
    response = llm(f"[INST] {prompt} [/INST]", stop=["</s>"], max_tokens=1024)
    return response["choices"][0]["text"].strip()
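
# local_llm_chat() wraps the prompt in [INST] ... [/INST] instruction tags;
# whether that matches the chat template expected by the configured GGUF
# depends on the model. Illustrative one-liner (loads the model on first call):
#
#   print(local_llm_chat("Explain what a context window is, in one sentence."))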


# ===============================
# LangChain Tool Agent Interface
# ===============================
def run_agent(message: str) -> str:
    """Execute the LangChain agent with tools and memory."""
    tools = get_tools()
    memory = Memory()
    agent_executor = create_langchain_agent(tools, memory)
    return agent_executor.run(message)
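
# run_agent() builds a fresh tool set, memory, and executor on every call; for
# repeated use, caching the executor may be preferable. Illustrative call
# (assumes the tools returned by get_tools() are fully configured):
#
#   print(run_agent("What tools do you have access to?"))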


# ===============================
# Optional Debug/Test Mode
# ===============================
if __name__ == "__main__":
    print("📁 Available Models:", list_models())
    agent = LocalLLMAgent()
    print("🤖", agent.chat("Hello! Who are you?"))
    print("🧠", agent.ask("What is the capital of France?"))