samith-a's picture
django model try, no-access-token
14dbea3
raw
history blame
2.67 kB
# app.py
import gradio as gr
import torch
from unsloth import FastLanguageModel
from peft import PeftModel
from transformers import AutoTokenizer
class ModelManager:
    """Owns the language model and tokenizer used to generate Django ORM code.

    Constructing an instance loads the 4-bit base model and then tries to
    attach the fine-tuned adapter, so construction is expensive. Use
    ``get_instance()`` to share a single instance across callers.
    """

    # Shared instance, created on the first get_instance() call.
    _instance = None

    def __init__(self):
        # Run on the GPU when torch can see one, otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.tokenizer = self.load_model()

    @classmethod
    def get_instance(cls):
        """Return the shared ``ModelManager``, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def load_model(self):
        """Load the base model and tokenizer, then attach the adapter.

        Returns:
            A ``(model, tokenizer)`` tuple. ``model`` is the adapter-wrapped
            model when the adapter loads, otherwise the plain base model.
        """
        # Load base model
        backbone, tokenizer = FastLanguageModel.from_pretrained(
            "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
            load_in_4bit=True,
            dtype=torch.float16,
            device_map=self.device,
        )
        # Load the fine-tuned adapter. This is deliberately best-effort: any
        # failure is reported and the app keeps serving the base model.
        try:
            model = PeftModel.from_pretrained(
                backbone,
                "samith-a/Django-orm-code-gen",
                torch_dtype=torch.float16,
                device_map=self.device,
            )
        except Exception as e:
            print(f"Error loading adapter: {e}")
            model = backbone
        else:
            print("Adapter weights loaded successfully")
        FastLanguageModel.for_inference(model)
        return model, tokenizer

    def generate(self, instruction: str, input_text: str, max_new_tokens: int = 128) -> str:
        """Generate a completion for an Alpaca-style prompt.

        Args:
            instruction: Text placed under the ``### Instruction:`` heading.
            input_text: Text placed under the ``### Input:`` heading.
            max_new_tokens: Upper bound on generated tokens, forwarded to the
                model's ``generate`` call.

        Returns:
            The newly generated text with surrounding whitespace stripped.
            The prompt itself is never part of the result.
        """
        alpaca_template = (
            "### Instruction:\n{}\n\n"
            "### Input:\n{}\n\n"
            "### Response:\n"
        )
        prompt = alpaca_template.format(instruction, input_text)
        encoded = self.tokenizer([prompt], return_tensors="pt").to(self.device)
        outputs = self.model.generate(**encoded, max_new_tokens=max_new_tokens)
        # The returned sequence begins with the prompt tokens. Drop them by
        # length rather than splitting the decoded text on "### Response:",
        # which would discard part of the answer whenever the model emits
        # that marker again in its own output.
        prompt_len = encoded["input_ids"].shape[-1]
        completion_ids = outputs[0][prompt_len:]
        completion = self.tokenizer.decode(completion_ids, skip_special_tokens=True)
        return completion.strip()
# Instantiate once
# Module-level shared manager. The first get_instance() call constructs the
# ModelManager, whose __init__ loads the model, so this line performs the
# model load when the module is executed.
manager = ModelManager.get_instance()
def predict(instruction, context, max_tokens=128):
    """Generate code for ``instruction`` and ``context`` via the shared manager.

    ``max_tokens`` is coerced with ``int()`` before it is forwarded as
    ``max_new_tokens``, so numeric values that are not already ints
    (e.g. floats) are accepted.
    """
    token_budget = int(max_tokens)
    return manager.generate(instruction, context, max_new_tokens=token_budget)
# Gradio UI / API: two text inputs plus a token-budget slider feed `predict`,
# and the generated text is shown in a single output textbox.
demo = gr.Interface(
    title="Django-ORM Code Generator",
    description="Ask the LoRA-finetuned LLaMA3.2 model to generate or modify Django ORM code.",
    fn=predict,
    inputs=[
        gr.Textbox(label="Instruction", placeholder="Describe what you want…", lines=2),
        gr.Textbox(label="Input (code/context)", placeholder="Optional context…", lines=5),
        gr.Slider(label="Max new tokens", value=128, minimum=16, maximum=512, step=16),
    ],
    outputs=gr.Textbox(label="Generated Code"),
)

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()