"""FastAPI backend for the local LLaMA agent.

Exposes auth, chat, tool invocation, document upload/embedding, vision
(OCR/captioning), audio (transcription/TTS), and chat-history endpoints.
"""
import base64
import os
import subprocess
import threading
import logging
from typing import Any, Dict, List, Optional, Union
from app.vision import extract_text_from_image, describe_image
import uvicorn
from fastapi import (
FastAPI, APIRouter, Depends, UploadFile, File, Form, HTTPException, status
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from app import chat_memory
from app.auth import (
USERS_DB, authenticate_user, get_password_hash,
create_access_token, get_current_user, verify_token
)
from app.agent import LocalLLMAgent
from app.langchain_agent import make_local_agent
from app.email_tool import generate_email
from app.files_api import save_upload
from app.embeddings import DocStore, embed_file, query_file_chunks
# ✅ Option A - import the actual objects
from app.chat_memory import persistent_memory, chat_history, semantic_search
from app.tools import TOOLS, use_tool, get_tools
from app.vision import extract_text_from_image, caption_image
from app.audio_tool import transcribe_audio, text_to_speech
# --- Application setup (module-level side effects: model loading below) ---
app = FastAPI(title="🧠 LLaMA Local Agent")

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# disallowed by the CORS spec (browsers reject the combination) and is unsafe
# beyond local development — pin explicit origins before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared document store plus two agent flavours, both backed by the same
# local GGUF model file (path is relative to the working directory, so the
# server must be started from the project root).
docs_store = DocStore()
agent = LocalLLMAgent("models/capybarahermes-2.5-mistral-7b.Q5_K_S.gguf", docs_store)
chain_agent = make_local_agent("models/capybarahermes-2.5-mistral-7b.Q5_K_S.gguf")
# === Auth Models ===
class TokenResponse(BaseModel):
    """Bearer-token response returned by /login and /register."""

    # Token produced by create_access_token (JWT-style "sub" claim — see
    # the login/register handlers below).
    access_token: str
    token_type: str = "bearer"
class RegisterRequest(BaseModel):
    """Payload accepted by /register."""

    username: str
    password: str
    # NOTE(review): the role is client-supplied — without server-side
    # validation a caller could self-assign "admin"; verify against the
    # auth module's expectations.
    role: Optional[str] = "user"
# === Routes ===
@app.post("/login", response_model=TokenResponse)
def login(username: str = Form(...), password: str = Form(...)):
    """Exchange form credentials for a bearer token (401 on failure)."""
    account = authenticate_user(username, password)
    if not account:
        raise HTTPException(status_code=401, detail="Invalid credentials")
    return {
        "access_token": create_access_token({"sub": account["username"]}),
        "token_type": "bearer",
    }
@app.post("/register", response_model=TokenResponse)
def register(payload: RegisterRequest):
    """Create a new user and return a bearer token.

    Raises:
        HTTPException 400: when the username is already taken or the
            requested role is not a recognised role.
    """
    if payload.username in USERS_DB:
        raise HTTPException(status_code=400, detail="User exists")
    # Security: the role string comes from the untrusted client.  Reject
    # anything outside the known set so arbitrary role strings cannot enter
    # USERS_DB.  NOTE(review): self-registering as "admin" is still possible
    # here — gate admin creation behind an existing admin if that matters.
    role = payload.role or "user"
    if role not in {"user", "admin"}:
        raise HTTPException(status_code=400, detail="Invalid role")
    USERS_DB[payload.username] = {
        "username": payload.username,
        "hashed_password": get_password_hash(payload.password),
        "role": role,
    }
    token = create_access_token({"sub": payload.username})
    return {"access_token": token, "token_type": "bearer"}
@app.get("/me")
def read_current_user(user: dict = Depends(get_current_user)):
    """Return the authenticated caller's username and role."""
    username = user["username"]
    role = user.get("role")
    return {"username": username, "role": role}
@app.get("/tools", response_model=List[str])
def list_tools():
    """Names of every registered tool."""
    return [name for name in TOOLS]
@app.get("/tools/details")
def tool_details():
    """Name/description pairs for every registered tool."""
    details = []
    for tool in get_tools():
        details.append({"name": tool.name, "description": tool.description})
    return details
class ToolRequest(BaseModel):
    """Generic tool-invocation payload; the expected shape of ``input``
    depends on the target tool (string, mapping, or list)."""

    input: Optional[Union[str, Dict, List]] = None
@app.post("/tool/{tool_name}")
def call_tool(tool_name: str, req: ToolRequest, _=Depends(verify_token)):
    """Invoke a registered tool by name with the request payload (404 if unknown)."""
    if tool_name not in TOOLS:
        raise HTTPException(404, f"Tool '{tool_name}' not found")
    return {"tool": tool_name, "result": use_tool(tool_name, req.input)}
@app.post("/chat")
def chat(req: Dict[str, Any], _=Depends(verify_token)):
    """Run one chat turn and record both sides in the shared history.

    Raises:
        HTTPException 400: when the body has no (non-empty) "prompt" —
            previously None was written into chat_history and handed to
            the LLM, producing a confusing downstream failure.
    """
    prompt = req.get("prompt")
    if not prompt:
        raise HTTPException(status_code=400, detail="Missing 'prompt'")
    chat_history.append_message("user", prompt)
    resp = agent.local_llm_chat(prompt)
    chat_history.append_message("ai", resp)
    return {"response": resp}
@app.post("/agent")
def run_chain(req: Dict[str, Any], _=Depends(verify_token)):
    """Run the LangChain-style agent on the supplied prompt.

    Raises:
        HTTPException 400: when the body has no (non-empty) "prompt" —
            previously None was passed straight into chain_agent.run.
    """
    prompt = req.get("prompt")
    if not prompt:
        raise HTTPException(status_code=400, detail="Missing 'prompt'")
    return {"result": chain_agent.run(prompt)}
@app.post("/upload")
async def upload_file(file: UploadFile = File(...), _=Depends(verify_token)):
    """Persist an uploaded file and embed it for later retrieval."""
    stored_path = save_upload(file)
    embed_file(stored_path)
    return {"path": stored_path}
@app.get("/docs")
def list_docs():
    """List the files in the upload directory.

    Returns [] when the directory does not exist yet (previously this
    raised FileNotFoundError → HTTP 500 on a fresh install).

    NOTE(review): GET /docs shadows FastAPI's interactive Swagger UI,
    which is served at /docs by default — rename this route or move the
    docs URL via FastAPI(docs_url=...).
    """
    if not os.path.isdir("uploaded_files"):
        return []
    return os.listdir("uploaded_files")
@app.delete("/docs/{name}")
def delete_doc(name: str):
    """Delete an uploaded document by name (404 if absent).

    Security: the raw path segment is reduced to its basename so crafted
    names containing separators (e.g. "..%2F..%2Fetc%2Fpasswd") cannot
    escape the upload directory — the original joined ``name`` unchecked,
    a path-traversal hole.
    """
    safe_name = os.path.basename(name)
    path = os.path.join("uploaded_files", safe_name)
    if os.path.exists(path):
        os.remove(path)
        return {"deleted": safe_name}
    raise HTTPException(404, "Not found")
@app.post("/ask-doc")
def ask_doc(req: Dict[str, Any], _=Depends(verify_token)):
    """Answer a question against the embedded documents."""
    question = req.get("prompt")
    return agent.ask_doc(question)
@app.get("/query_file")
def query_file(filename: str, question: str):
    """Return the chunks of *filename* most relevant to *question*."""
    matches = query_file_chunks(filename, question)
    return matches
@app.post("/image-caption")
async def img_caption(file: UploadFile = File(...), _=Depends(verify_token)):
    """Save an uploaded image and return a generated caption for it."""
    stored = save_upload(file)
    return {"caption": caption_image(stored)}
@app.post("/ocr")
def ocr_image(base64_image: str):
    """Run OCR over a base64-encoded image and return the extracted text."""
    extracted = extract_text_from_image(base64_image)
    return {"text": extracted}
@app.post("/caption")
def caption_image_api(base64_image: str):
    """Caption a base64-encoded image.

    Raises:
        HTTPException 400: when the payload is not valid base64 —
            previously base64.b64decode raised binascii.Error (a
            ValueError subclass) unhandled, surfacing as HTTP 500.
    """
    try:
        raw = base64.b64decode(base64_image)
    except ValueError as err:  # binascii.Error subclasses ValueError
        raise HTTPException(status_code=400, detail="Invalid base64 image") from err
    return {"caption": describe_image(raw)}
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...), _=Depends(verify_token)):
    """Save an uploaded audio file and return its transcription."""
    stored = save_upload(file)
    result = transcribe_audio(stored)
    return {"transcription": result}
@app.post("/speak")
def speak(text: str, _=Depends(verify_token)):
    """Synthesise speech for *text* and stream the resulting MP3 back."""
    audio_path = text_to_speech(text)
    return FileResponse(audio_path, media_type="audio/mpeg")
@app.get("/generate_email")
def email_gen(to: str, product: str, discount: float, _=Depends(verify_token)):
    """Generate a marketing email for *product* with the given discount."""
    body = generate_email(to, product, discount)
    return {"email": body}
@app.get("/history/export")
def export_history():
    """Export the full chat history as plain text."""
    exported = chat_memory.chat_history.export_history()
    return {"text": exported}
@app.get("/search")
def search_chat(query: str):
    """Search the stored chat history for *query*."""
    hits = chat_memory.chat_history.search_history(query)
    return {"matches": hits}
@app.get("/memory/stats")
def memory_stats(_=Depends(verify_token)):
    """Number of records in the agent's persistent memory store."""
    record_count = len(agent.mem.db.all())
    return {"size": record_count}
@app.post("/reset")
def reset_memory(_=Depends(verify_token)):
    """Clear the agent's memory and report success."""
    agent.reset()
    return {"status": "cleared"}
# --- Launch both backend & frontend concurrently (optional) ---
def run_backend():
    """Serve the FastAPI app with uvicorn (blocking call).

    reload is disabled because this function is started from a background
    thread in the ``__main__`` block below: uvicorn's auto-reloader spawns
    a supervisor process and installs signal handlers, both of which must
    run in the main thread — reload=True here fails at startup.
    """
    uvicorn.run("app.main_api:app", host="0.0.0.0", port=8000, reload=False)
def run_frontend():
    """Launch the Streamlit UI as a child process (blocks until it exits)."""
    # list argv + shell=False (default): no shell-injection surface; assumes
    # `streamlit` is on PATH and the repo root is the working directory.
    subprocess.run(["streamlit", "run", "frontend/streamlit_app.py"])
if __name__ == "__main__":
    # Run API server and Streamlit UI side by side; block until both exit.
    workers = [
        threading.Thread(target=run_backend),
        threading.Thread(target=run_frontend),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()