# llama-models / app/main_api.py
# Uploaded by deniskiplimo816 ("Upload 27 files")
# Revision 293ab16 (verified)
import base64
import os
import subprocess
import threading
import logging
from typing import Any, Dict, List, Optional, Union
from app.vision import extract_text_from_image, describe_image
import uvicorn
from fastapi import (
FastAPI, APIRouter, Depends, UploadFile, File, Form, HTTPException, status
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from app import chat_memory
from app.auth import (
USERS_DB, authenticate_user, get_password_hash,
create_access_token, get_current_user, verify_token
)
from app.agent import LocalLLMAgent
from app.langchain_agent import make_local_agent
from app.email_tool import generate_email
from app.files_api import save_upload
from app.embeddings import DocStore, embed_file, query_file_chunks
# ✅ Option A - import the actual objects
from app.chat_memory import persistent_memory, chat_history, semantic_search
from app.tools import TOOLS, use_tool, get_tools
from app.vision import extract_text_from_image, caption_image
from app.audio_tool import transcribe_audio, text_to_speech
# Initialize app
# GGUF model file loaded by both the direct agent and the chain agent.
MODEL_PATH = "models/capybarahermes-2.5-mistral-7b.Q5_K_S.gguf"

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the API publicly.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI(title="🧠 LLaMA Local Agent")
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-level singletons shared by the route handlers below.
docs_store = DocStore()
agent = LocalLLMAgent(MODEL_PATH, docs_store)
chain_agent = make_local_agent(MODEL_PATH)
# === Auth Models ===
class TokenResponse(BaseModel):
    """Response schema used by the ``/login`` and ``/register`` routes."""

    # Token string produced by ``create_access_token``.
    access_token: str
    # Token scheme reported to the client; defaults to "bearer".
    token_type: str = "bearer"
class RegisterRequest(BaseModel):
    """JSON body accepted by the ``/register`` route."""

    # Key under which the new account is stored in ``USERS_DB``.
    username: str
    # Plain-text password; ``register`` hashes it before storing.
    password: str
    # Role stored on the account; "user" when the field is omitted.
    role: Optional[str] = "user"
# === Routes ===
@app.post("/login", response_model=TokenResponse)
def login(username: str = Form(...), password: str = Form(...)):
    """Exchange form-posted credentials for an access token.

    Args:
        username: Account name, read from the form body.
        password: Password, read from the form body.

    Returns:
        A dict with ``access_token`` and ``token_type`` ("bearer").

    Raises:
        HTTPException: 401 when ``authenticate_user`` returns a falsy value.
    """
    account = authenticate_user(username, password)
    if account:
        claims = {"sub": account["username"]}
        return {
            "access_token": create_access_token(claims),
            "token_type": "bearer",
        }
    raise HTTPException(status_code=401, detail="Invalid credentials")
@app.post("/register", response_model=TokenResponse)
def register(payload: RegisterRequest):
    """Create a new account in ``USERS_DB`` and return a token for it.

    Args:
        payload: Requested username, password and optional role.

    Returns:
        A dict with ``access_token`` and ``token_type`` ("bearer").

    Raises:
        HTTPException: 400 when the username or password is blank, or when
            the username is already present in ``USERS_DB``.
    """
    username = payload.username
    # A blank username or empty password would create an unusable account.
    if not username.strip() or not payload.password:
        raise HTTPException(
            status_code=400, detail="Username and password must not be empty"
        )
    if username in USERS_DB:
        raise HTTPException(status_code=400, detail="User exists")

    # NOTE(review): the role is taken from the unauthenticated request body,
    # so any caller can pick their own role -- confirm whether it should be
    # restricted before roles are used for authorization.
    USERS_DB[username] = {
        "username": username,
        "hashed_password": get_password_hash(payload.password),
        # An explicit ``null`` role falls back to the model's default.
        "role": payload.role or "user",
    }
    token = create_access_token({"sub": username})
    return {"access_token": token, "token_type": "bearer"}
@app.get("/me")
def read_current_user(user: dict = Depends(get_current_user)):
    """Return the username and role of the user resolved by ``get_current_user``."""
    profile = {"username": user["username"]}
    # ``role`` may be absent from the user record, in which case it is None.
    profile["role"] = user.get("role")
    return profile
@app.get("/tools", response_model=List[str])
def list_tools():
    """Return the names of all entries in ``TOOLS``."""
    tool_names = TOOLS.keys()
    return [*tool_names]
@app.get("/tools/details")
def tool_details():
    """Return the name and description of every tool from ``get_tools()``."""
    details = []
    for tool in get_tools():
        details.append({"name": tool.name, "description": tool.description})
    return details
class ToolRequest(BaseModel):
    """JSON body accepted by the ``/tool/{tool_name}`` route."""

    # Value handed to ``use_tool`` as-is; may be omitted (None).
    input: Optional[Union[str, Dict, List]] = None
@app.post("/tool/{tool_name}")
def call_tool(tool_name: str, req: ToolRequest, _=Depends(verify_token)):
    """Run the named tool on the request's ``input`` value.

    Args:
        tool_name: Key looked up in ``TOOLS``.
        req: Body carrying the tool input.

    Returns:
        A dict with the tool name and whatever ``use_tool`` returned.

    Raises:
        HTTPException: 404 when ``tool_name`` is not a key of ``TOOLS``.
    """
    if tool_name in TOOLS:
        outcome = use_tool(tool_name, req.input)
        return {"tool": tool_name, "result": outcome}
    raise HTTPException(404, f"Tool '{tool_name}' not found")
@app.post("/chat")
def chat(req: Dict[str, Any], _=Depends(verify_token)):
    """Send a prompt to the local LLM and record both sides in chat history.

    Args:
        req: JSON object; its ``prompt`` key holds the user's message.

    Returns:
        A dict whose ``response`` key holds the agent's reply.

    Raises:
        HTTPException: 422 when ``prompt`` is missing, not a string, or blank.
    """
    prompt = req.get("prompt")
    # Validate before touching history so that a bad request never stores
    # ``None`` as a user message or reaches the model.
    if not isinstance(prompt, str) or not prompt.strip():
        raise HTTPException(
            status_code=422, detail="'prompt' must be a non-empty string"
        )
    chat_history.append_message("user", prompt)
    resp = agent.local_llm_chat(prompt)
    chat_history.append_message("ai", resp)
    return {"response": resp}
@app.post("/agent")
def run_chain(req: Dict[str, Any], _=Depends(verify_token)):
    """Run the chain agent on the request's prompt.

    Args:
        req: JSON object; its ``prompt`` key holds the text to run.

    Returns:
        A dict whose ``result`` key holds the value of ``chain_agent.run``.

    Raises:
        HTTPException: 422 when ``prompt`` is missing, not a string, or blank.
    """
    prompt = req.get("prompt")
    # Without this check a missing key would pass ``None`` to the agent.
    if not isinstance(prompt, str) or not prompt.strip():
        raise HTTPException(
            status_code=422, detail="'prompt' must be a non-empty string"
        )
    return {"result": chain_agent.run(prompt)}
@app.post("/upload")
async def upload_file(file: UploadFile = File(...), _=Depends(verify_token)):
    """Store an uploaded file and pass its path to ``embed_file``.

    Returns:
        A dict whose ``path`` key holds the value returned by ``save_upload``.
    """
    # NOTE(review): both helpers are called without ``await`` inside an async
    # handler; if they do heavy work they will run on the event loop -- confirm.
    stored_path = save_upload(file)
    embed_file(stored_path)
    return dict(path=stored_path)
# FastAPI serves its interactive Swagger UI at "/docs" by default and this app
# does not override ``docs_url``, so the original "/docs" registration is
# shadowed by that built-in route. "/documents" is a reachable alias; the
# original path is kept so existing registrations are unchanged.
@app.get("/documents")
@app.get("/docs")
def list_docs():
    """Return the names of the entries in the ``uploaded_files`` directory.

    Returns:
        A list of entry names; empty when the directory does not exist yet.
    """
    try:
        return os.listdir("uploaded_files")
    except FileNotFoundError:
        # Nothing has been uploaded yet; report an empty list, not a 500.
        return []
@app.delete("/docs/{name}")
def delete_doc(name: str):
    """Delete one file from the ``uploaded_files`` directory.

    Args:
        name: Bare file name (no directory components).

    Returns:
        A dict whose ``deleted`` key echoes ``name``.

    Raises:
        HTTPException: 400 when ``name`` is not a bare file name, 404 when no
            such regular file exists.
    """
    # NOTE(review): unlike /upload, this route has no token dependency --
    # confirm whether unauthenticated deletion is intended.

    # Refuse anything that could escape the upload directory: separators
    # (including backslash, which os.path.join honours on Windows) and the
    # special "." / ".." entries.
    has_separator = "/" in name or "\\" in name or os.path.basename(name) != name
    if has_separator or name in ("", ".", ".."):
        raise HTTPException(400, "Invalid file name")

    path = os.path.join("uploaded_files", name)
    # isfile (not exists) so that a directory yields 404 instead of making
    # os.remove raise.
    if not os.path.isfile(path):
        raise HTTPException(404, "Not found")
    try:
        os.remove(path)
    except FileNotFoundError:
        # The file vanished between the check and the removal.
        raise HTTPException(404, "Not found") from None
    return {"deleted": name}
@app.post("/ask-doc")
def ask_doc(req: Dict[str, Any], _=Depends(verify_token)):
    """Forward the request's prompt to ``agent.ask_doc``.

    Args:
        req: JSON object; its ``prompt`` key holds the question.

    Returns:
        Whatever ``agent.ask_doc`` returns.

    Raises:
        HTTPException: 422 when ``prompt`` is missing, not a string, or blank.
    """
    prompt = req.get("prompt")
    # Without this check a missing key would pass ``None`` to the agent.
    if not isinstance(prompt, str) or not prompt.strip():
        raise HTTPException(
            status_code=422, detail="'prompt' must be a non-empty string"
        )
    return agent.ask_doc(prompt)
@app.get("/query_file")
def query_file(filename: str, question: str):
    """Pass ``filename`` and ``question`` to ``query_file_chunks`` and return its result.

    Both values arrive as query-string parameters.
    """
    answer = query_file_chunks(filename, question)
    return answer
@app.post("/image-caption")
async def img_caption(file: UploadFile = File(...), _=Depends(verify_token)):
    """Save an uploaded image and return the output of ``caption_image`` for it."""
    image_path = save_upload(file)
    return {"caption": caption_image(image_path)}
@app.post("/ocr")
def ocr_image(base64_image: str):
    """Return the text that ``extract_text_from_image`` produces for the input.

    ``base64_image`` is passed through untouched; it is not decoded here.
    """
    recognized = extract_text_from_image(base64_image)
    return {"text": recognized}
@app.post("/caption")
def caption_image_api(base64_image: str):
    """Decode a base64 image and return the output of ``describe_image`` for it.

    Args:
        base64_image: Base64-encoded image data.

    Returns:
        A dict whose ``caption`` key holds the value from ``describe_image``.

    Raises:
        HTTPException: 400 when ``base64_image`` cannot be decoded.
    """
    try:
        image_bytes = base64.b64decode(base64_image)
    except ValueError as exc:
        # binascii.Error (bad padding) is a ValueError subclass; non-ASCII
        # input raises a plain ValueError. Both are client errors, not 500s.
        raise HTTPException(
            status_code=400, detail="base64_image is not valid base64"
        ) from exc
    return {"caption": describe_image(image_bytes)}
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...), _=Depends(verify_token)):
    """Save an uploaded audio file and return the output of ``transcribe_audio``."""
    audio_path = save_upload(file)
    transcript = transcribe_audio(audio_path)
    return {"transcription": transcript}
@app.post("/speak")
def speak(text: str, _=Depends(verify_token)):
    """Run ``text_to_speech`` on ``text`` and return the result as an ``audio/mpeg`` file response."""
    audio_path = text_to_speech(text)
    return FileResponse(audio_path, media_type="audio/mpeg")
@app.get("/generate_email")
def email_gen(to: str, product: str, discount: float, _=Depends(verify_token)):
    """Return the output of ``generate_email`` for the given query parameters."""
    body = generate_email(to, product, discount)
    return {"email": body}
@app.get("/history/export")
def export_history():
    """Return the chat history export under the ``text`` key."""
    history = chat_memory.chat_history
    return {"text": history.export_history()}
@app.get("/search")
def search_chat(query: str):
    """Return the result of ``search_history(query)`` under the ``matches`` key."""
    history = chat_memory.chat_history
    found = history.search_history(query)
    return {"matches": found}
@app.get("/memory/stats")
def memory_stats(_=Depends(verify_token)):
    """Return the number of records returned by ``agent.mem.db.all()``."""
    records = agent.mem.db.all()
    return {"size": len(records)}
@app.post("/reset")
def reset_memory(_=Depends(verify_token)):
    """Call ``agent.reset()`` and acknowledge with a fixed status payload."""
    agent.reset()
    acknowledgement = {"status": "cleared"}
    return acknowledgement
# --- Launch both backend & frontend concurrently (optional) ---
def run_backend():
    """Serve ``app.main_api:app`` with uvicorn on 0.0.0.0:8000 with reload enabled."""
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("app.main_api:app", **server_options)
def run_frontend():
    """Run the Streamlit frontend script and block until that process exits."""
    command = ["streamlit", "run", "frontend/streamlit_app.py"]
    subprocess.run(command)
if __name__ == "__main__":
    # Only the frontend goes into a worker thread. ``run_backend`` starts
    # uvicorn with reload enabled, and the reload supervisor registers signal
    # handlers -- which Python only permits from the main thread -- so the
    # backend must run here rather than in a ``threading.Thread``.
    frontend_thread = threading.Thread(target=run_frontend)
    frontend_thread.start()
    try:
        run_backend()
    finally:
        # As before, do not exit until the frontend process has finished.
        frontend_thread.join()