# Spaces:
# Sleeping
# Sleeping
# app.py
import gradio as gr | |
import torch | |
from unsloth import FastLanguageModel | |
from peft import PeftModel | |
from transformers import AutoTokenizer | |
class ModelManager:
    """Lazily-created singleton that owns the model and tokenizer.

    The instance loads the 4-bit base model, tries to attach the fine-tuned
    adapter on top of it, and exposes :meth:`generate` for prompt-based
    text generation. Use :meth:`get_instance` rather than calling the
    constructor directly so the weights are loaded only once per process.
    """

    # Cached singleton; filled in by the first ``get_instance()`` call.
    _instance: "ModelManager | None" = None

    def __init__(self):
        """Select the device and load the model/tokenizer pair."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.tokenizer = self.load_model()

    @classmethod
    def get_instance(cls) -> "ModelManager":
        """Return the shared instance, creating it on first use.

        This must be a classmethod: it is invoked on the class itself
        (``ModelManager.get_instance()``), and without the decorator that
        call fails with a missing-argument ``TypeError``.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def load_model(self):
        """Load the base model and, if possible, the fine-tuned adapter.

        Returns:
            A ``(model, tokenizer)`` tuple. ``model`` is the adapter-wrapped
            model when the adapter loads, otherwise the plain base model.
        """
        # Load base model
        backbone, tokenizer = FastLanguageModel.from_pretrained(
            "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
            load_in_4bit=True,
            dtype=torch.float16,
            device_map=self.device,
        )

        # Load your fine-tuned adapter
        try:
            model = PeftModel.from_pretrained(
                backbone,
                "samith-a/Django-orm-code-gen",
                torch_dtype=torch.float16,
                device_map=self.device,
            )
            print("Adapter weights loaded successfully")
        except Exception as e:
            # Deliberate best-effort: keep serving with the base model when
            # the adapter cannot be loaded, but report why.
            print(f"Error loading adapter: {e}")
            model = backbone

        FastLanguageModel.for_inference(model)
        return model, tokenizer

    def generate(self, instruction: str, input_text: str, max_new_tokens: int = 128) -> str:
        """Generate a response for an instruction/input pair.

        Args:
            instruction: Text placed under the ``### Instruction:`` header.
            input_text: Text placed under the ``### Input:`` header.
            max_new_tokens: Upper bound on generated tokens, forwarded to
                ``model.generate``.

        Returns:
            The decoded text following the last ``### Response:`` marker,
            with surrounding whitespace stripped.
        """
        alpaca_template = (
            "### Instruction:\n{}\n\n"
            "### Input:\n{}\n\n"
            "### Response:\n"
        )
        prompt = alpaca_template.format(instruction, input_text)
        encoded = self.tokenizer([prompt], return_tensors="pt").to(self.device)
        outputs = self.model.generate(**encoded, max_new_tokens=max_new_tokens)
        raw = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The decoded sequence still contains the prompt; keep only what
        # comes after the final response marker.
        return raw.split("### Response:")[-1].strip()
# Build the shared ModelManager a single time at import; predict() below
# reuses this same object for every request.
manager: ModelManager = ModelManager.get_instance()
def predict(instruction, context, max_tokens=128):
    """Gradio callback: forward the form values to the shared model manager.

    Args:
        instruction: Text describing the requested code.
        context: Additional code/context text passed as the prompt input.
        max_tokens: Generation limit; coerced to ``int`` before use.

    Returns:
        The string produced by ``manager.generate``.
    """
    token_budget = int(max_tokens)
    return manager.generate(instruction, context, max_new_tokens=token_budget)
# Gradio UI / API
# Input widgets, in the positional order predict() receives their values.
_input_components = [
    gr.Textbox(lines=2, label="Instruction", placeholder="Describe what you want…"),
    gr.Textbox(lines=5, label="Input (code/context)", placeholder="Optional context…"),
    gr.Slider(minimum=16, maximum=512, step=16, label="Max new tokens", value=128),
]

# Single text area that shows the string returned by predict().
_output_component = gr.Textbox(label="Generated Code")

demo = gr.Interface(
    fn=predict,
    inputs=_input_components,
    outputs=_output_component,
    title="Django-ORM Code Generator",
    description="Ask the LoRA-finetuned LLaMA3.2 model to generate or modify Django ORM code.",
)
# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()