# NOTE(review): removed non-Python page-capture residue that preceded the
# module ("Spaces:" / "Sleeping" status lines from a hosting dashboard).
import base64
import logging
import os
import subprocess
import threading
from typing import Any, Dict, List, Optional, Union

import uvicorn
from fastapi import (
    FastAPI, APIRouter, Depends, UploadFile, File, Form, HTTPException, status
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel

from app import chat_memory
from app.agent import LocalLLMAgent
from app.audio_tool import transcribe_audio, text_to_speech
from app.auth import (
    USERS_DB, authenticate_user, get_password_hash,
    create_access_token, get_current_user, verify_token
)
# ✅ Option A - import the actual objects
from app.chat_memory import persistent_memory, chat_history, semantic_search
from app.email_tool import generate_email
from app.embeddings import DocStore, embed_file, query_file_chunks
from app.files_api import save_upload
from app.langchain_agent import make_local_agent
from app.tools import TOOLS, use_tool, get_tools
from app.vision import extract_text_from_image, describe_image, caption_image
# Initialize app
app = FastAPI(title="🧠 LLaMA Local Agent")

# Wide-open CORS for local development.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec and is unsafe for production —
# TODO restrict origins before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared document store plus two agents backed by the same local GGUF model.
# The path was previously duplicated in both constructor calls.
MODEL_PATH = "models/capybarahermes-2.5-mistral-7b.Q5_K_S.gguf"
docs_store = DocStore()
agent = LocalLLMAgent(MODEL_PATH, docs_store)
chain_agent = make_local_agent(MODEL_PATH)
# === Auth Models ===
class TokenResponse(BaseModel):
    """Bearer-token payload returned by the login/register endpoints."""
    access_token: str
    token_type: str = "bearer"
class RegisterRequest(BaseModel):
    """Body for user registration; role defaults to the unprivileged 'user'."""
    username: str
    password: str
    role: Optional[str] = "user"
# === Routes ===
# NOTE(review): the route decorators were lost in this paste; the paths used
# below are reconstructed guesses — TODO confirm against the original API.
@app.post("/token")
def login(username: str = Form(...), password: str = Form(...)):
    """Exchange form credentials for a JWT access token.

    Raises:
        HTTPException: 401 when authentication fails.
    """
    user = authenticate_user(username, password)
    if not user:
        raise HTTPException(status_code=401, detail="Invalid credentials")
    token = create_access_token({"sub": user["username"]})
    return {"access_token": token, "token_type": "bearer"}
@app.post("/register")  # NOTE(review): reconstructed path — TODO confirm
def register(payload: RegisterRequest):
    """Create a new user and immediately issue a token for it.

    Raises:
        HTTPException: 400 when the username is already taken.
    """
    if payload.username in USERS_DB:
        raise HTTPException(status_code=400, detail="User exists")
    USERS_DB[payload.username] = {
        "username": payload.username,
        "hashed_password": get_password_hash(payload.password),
        # NOTE(review): role comes straight from the client — a caller can
        # self-register with any role (e.g. "admin"); TODO validate/whitelist.
        "role": payload.role,
    }
    token = create_access_token({"sub": payload.username})
    return {"access_token": token, "token_type": "bearer"}
@app.get("/users/me")  # NOTE(review): reconstructed path — TODO confirm
def read_current_user(user: dict = Depends(get_current_user)):
    """Return the authenticated caller's username and role."""
    return {"username": user["username"], "role": user.get("role")}
@app.get("/tools")  # NOTE(review): reconstructed path — TODO confirm
def list_tools():
    """List the names of all registered tools."""
    return list(TOOLS.keys())
@app.get("/tools/details")  # NOTE(review): reconstructed path — TODO confirm
def tool_details():
    """Return name/description pairs for every registered tool object."""
    return [{"name": t.name, "description": t.description} for t in get_tools()]
class ToolRequest(BaseModel):
    """Free-form input for a tool call; shape depends on the target tool."""
    input: Optional[Union[str, Dict, List]] = None
@app.post("/tools/{tool_name}")  # NOTE(review): reconstructed path — TODO confirm
def call_tool(tool_name: str, req: ToolRequest, _=Depends(verify_token)):
    """Invoke a named tool with the request's free-form input.

    Raises:
        HTTPException: 404 when no tool with that name is registered.
    """
    # Validate the tool name before doing anything with the payload.
    if tool_name not in TOOLS:
        raise HTTPException(404, f"Tool '{tool_name}' not found")
    return {"tool": tool_name, "result": use_tool(tool_name, req.input)}
@app.post("/chat")  # NOTE(review): reconstructed path — TODO confirm
def chat(req: Dict[str, Any], _=Depends(verify_token)):
    """Run one chat turn against the local agent, recording both sides.

    Raises:
        HTTPException: 400 when the body has no 'prompt' key.
    """
    prompt = req.get("prompt")
    # Previously a missing prompt would be passed through as None.
    if not prompt:
        raise HTTPException(400, "Missing 'prompt'")
    chat_history.append_message("user", prompt)
    resp = agent.local_llm_chat(prompt)
    chat_history.append_message("ai", resp)
    return {"response": resp}
@app.post("/chain")  # NOTE(review): reconstructed path — TODO confirm
def run_chain(req: Dict[str, Any], _=Depends(verify_token)):
    """Run the prompt through the LangChain agent.

    Raises:
        HTTPException: 400 when the body has no 'prompt' key.
    """
    prompt = req.get("prompt")
    if not prompt:
        raise HTTPException(400, "Missing 'prompt'")
    return {"result": chain_agent.run(prompt)}
@app.post("/upload")  # NOTE(review): reconstructed path — TODO confirm
async def upload_file(file: UploadFile = File(...), _=Depends(verify_token)):
    """Persist the upload to disk and index it for retrieval."""
    path = save_upload(file)
    embed_file(path)
    return {"path": path}
@app.get("/files")  # NOTE(review): reconstructed path — TODO confirm
def list_docs():
    """List uploaded documents; empty list when nothing was uploaded yet."""
    # Guard: before the first upload the directory may not exist, and
    # os.listdir would raise FileNotFoundError (a 500 to the client).
    if not os.path.isdir("uploaded_files"):
        return []
    return os.listdir("uploaded_files")
@app.delete("/files/{name}")  # NOTE(review): reconstructed path — TODO confirm
def delete_doc(name: str):
    """Delete an uploaded document by filename.

    Raises:
        HTTPException: 400 for names containing path components,
            404 when the file does not exist.
    """
    # Security: `name` is untrusted; a value like "../app/main.py" would
    # escape uploaded_files/ via os.path.join. Reject anything that is not
    # a plain filename.
    if os.path.basename(name) != name or name in ("", ".", ".."):
        raise HTTPException(400, "Invalid filename")
    path = os.path.join("uploaded_files", name)
    if os.path.exists(path):
        os.remove(path)
        return {"deleted": name}
    raise HTTPException(404, "Not found")
@app.post("/ask-doc")  # NOTE(review): reconstructed path — TODO confirm
def ask_doc(req: Dict[str, Any], _=Depends(verify_token)):
    """Answer a question over the embedded documents via the agent.

    Raises:
        HTTPException: 400 when the body has no 'prompt' key.
    """
    prompt = req.get("prompt")
    if not prompt:
        raise HTTPException(400, "Missing 'prompt'")
    return agent.ask_doc(prompt)
@app.get("/query-file")  # NOTE(review): reconstructed path — TODO confirm
def query_file(filename: str, question: str):
    """Answer a question against a single indexed file's chunks."""
    return query_file_chunks(filename, question)
@app.post("/image/caption")  # NOTE(review): reconstructed path — TODO confirm
async def img_caption(file: UploadFile = File(...), _=Depends(verify_token)):
    """Caption an uploaded image file."""
    tmp = save_upload(file)
    return {"caption": caption_image(tmp)}
@app.post("/image/ocr")  # NOTE(review): reconstructed path — TODO confirm
def ocr_image(base64_image: str):
    """Extract text (OCR) from an image supplied as a base64 string."""
    # NOTE(review): the string is passed through un-decoded — presumably
    # extract_text_from_image accepts base64 itself; verify, since the
    # caption endpoint decodes before calling its helper.
    return {"text": extract_text_from_image(base64_image)}
@app.post("/image/caption-b64")  # NOTE(review): reconstructed path — TODO confirm
def caption_image_api(base64_image: str):
    """Caption an image supplied as a base64 string."""
    # describe_image receives the decoded raw bytes, unlike the OCR endpoint
    # which passes base64 straight through — NOTE(review): confirm both
    # helpers' expected input types.
    return {"caption": describe_image(base64.b64decode(base64_image))}
@app.post("/audio/transcribe")  # NOTE(review): reconstructed path — TODO confirm
async def transcribe(file: UploadFile = File(...), _=Depends(verify_token)):
    """Transcribe an uploaded audio file to text."""
    tmp = save_upload(file)
    return {"transcription": transcribe_audio(tmp)}
@app.get("/audio/speak")  # NOTE(review): reconstructed path — TODO confirm
def speak(text: str, _=Depends(verify_token)):
    """Synthesize speech for the given text and stream it back as MP3."""
    mp3 = text_to_speech(text)
    return FileResponse(mp3, media_type="audio/mpeg")
@app.post("/email")  # NOTE(review): reconstructed path — TODO confirm
def email_gen(to: str, product: str, discount: float, _=Depends(verify_token)):
    """Generate a marketing email for a recipient, product and discount."""
    return {"email": generate_email(to, product, discount)}
@app.get("/history/export")  # NOTE(review): reconstructed path — TODO confirm
def export_history():
    """Export the full chat history as plain text."""
    return {"text": chat_memory.chat_history.export_history()}
@app.get("/history/search")  # NOTE(review): reconstructed path — TODO confirm
def search_chat(query: str):
    """Search past chat messages for the query string."""
    return {"matches": chat_memory.chat_history.search_history(query)}
@app.get("/memory/stats")  # NOTE(review): reconstructed path — TODO confirm
def memory_stats(_=Depends(verify_token)):
    """Report how many records the agent's persistent memory holds."""
    return {"size": len(agent.mem.db.all())}
@app.post("/memory/reset")  # NOTE(review): reconstructed path — TODO confirm
def reset_memory(_=Depends(verify_token)):
    """Wipe the agent's memory and confirm."""
    agent.reset()
    return {"status": "cleared"}
# --- Launch both backend & frontend concurrently (optional) ---
def run_backend():
    """Serve the FastAPI app with uvicorn (blocks until shutdown)."""
    # NOTE(review): reload=True starts uvicorn's reloader, which installs
    # signal handlers and misbehaves when run from a non-main thread (as
    # done in __main__ below) — TODO disable reload outside of dev.
    uvicorn.run("app.main_api:app", host="0.0.0.0", port=8000, reload=True)
def run_frontend():
    """Launch the Streamlit UI as a child process (blocks until it exits)."""
    # List argv with shell=False default: no shell-injection surface.
    subprocess.run(["streamlit", "run", "frontend/streamlit_app.py"])
if __name__ == "__main__":
    # Run the API server and the Streamlit UI side by side; join both so the
    # parent process stays alive until each has exited.
    backend = threading.Thread(target=run_backend)
    frontend = threading.Thread(target=run_frontend)
    backend.start()
    frontend.start()
    backend.join()
    frontend.join()