import os
import torch
from fastapi import FastAPI, Request
from fastapi.responses import FileResponse, JSONResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import uvicorn

# ------------------------------
# Config
# ------------------------------
BASE_MODEL = "deepseek-ai/deepseek-coder-6.7b-base"        # full-precision base checkpoint
ADAPTER_PATH = "Agasthya0/colabmind-coder-6.7b-ml-qlora"   # QLoRA adapter on the HF Hub

# ------------------------------
# Load Model + Tokenizer
# ------------------------------
print("πŸš€ Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
print("🧠 Loading base model in 4-bit...")
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
load_in_4bit=True,
device_map="auto",
torch_dtype=torch.float16
)
print("πŸ”— Attaching LoRA adapter...")
model = PeftModel.from_pretrained(model, ADAPTER_PATH)
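# Inference-only service: switch to eval mode so dropout in the LoRA layers is off.
model.eval()
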
# ------------------------------
# Inference Function
# ------------------------------
def generate_code(prompt: str) -> str:
    if not prompt.strip():
        return "⚠️ Please enter a prompt."
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.2,  # low temperature: mostly deterministic code output
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
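
# Quick smoke test for local debugging (hypothetical prompt, not run on import):
# >>> print(generate_code("Write a Python function that checks whether a number is prime."))
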
# ------------------------------
# FastAPI App
# ------------------------------
app = FastAPI()

@app.get("/")
def serve_frontend():
return FileResponse("index.html")
@app.post("/run/predict")
async def predict(request: Request):
    # Gradio-style payload: {"data": ["<prompt>"]}
    data = await request.json()
    prompt = data.get("data", [""])[0]
    output = generate_code(prompt)
    return JSONResponse({"data": [output]})
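
# Example request, assuming the app is reachable on localhost:7860:
#   curl -X POST http://localhost:7860/run/predict \
#     -H "Content-Type: application/json" \
#     -d '{"data": ["Write a function to reverse a string."]}'
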
# ------------------------------
# Run (for local debugging, Spaces ignores this)
# ------------------------------
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)