import os
import re
import subprocess

import gradio as gr
import torch
import torch.nn.functional as F

from config import ModelArgs
from inference import remove_prefix
from inference_sft import topk_sampling
from model import Llama
from tokenizer import Tokenizer
# Define model paths
model_paths = {
    "SFT": "weights/fine_tuned/models--YuvrajSingh9886--smol-llama-finetuned/snapshots/35df7811c322dab5f8df56f6a1f61b5bcf6a7339/snapshot_fine_tuned_model_900.pt",
    "DPO": "weights/DPO/models--YuvrajSingh9886--smol-llama-dpo/snapshots/13b9f059cf4630e6079e8822ea7cb7a703d77ed8/DPO_model_1650.pt",
    "Pretrained": "weights/pretrained/models--YuvrajSingh9886--smol-llama-base/snapshots/8e50a94e3c83649e48eb02549558130c87e75a87/snapshot_6750.pt",
}
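
# Note: the snapshot paths above mirror the Hugging Face Hub cache layout
# (models--<user>--<repo>/snapshots/<revision>/...), presumably produced by
# download_model_weight.py when it fetches the checkpoints.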

# Access token for the download step, read from the environment
ACCESS_TOKEN = os.getenv("GDRIVE_ACCESS_TOKEN")

# Download every checkpoint at startup via the helper script
for name in model_paths:
    subprocess.run(
        ["python", "download_model_weight.py", "--model_type", name.lower()],
        check=True,
    )

tk = Tokenizer().ready_tokenizer()


def beam_search(model, prompt, device, max_length=50, beam_width=5, top_k=50, temperature=1.0):
    input_ids = tk.encode(prompt, return_tensors='pt').to(device)
    # Start from beam_width copies of the prompt
    beams = input_ids.repeat(beam_width, 1)
    # All but the first beam start at -inf so the first expansion selects
    # beam_width distinct tokens instead of beam_width copies of the same one
    beam_scores = torch.full((beam_width,), float('-inf'), device=device)
    beam_scores[0] = 0.0
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(beams)
        logits = outputs[:, -1, :]  # logits at the last position of each beam
        # Temperature scaling, then log-probabilities
        log_probs = F.log_softmax(logits / temperature, dim=-1)
        # Top-k continuations for each beam
        topk_log_probs, topk_indices = torch.topk(log_probs, top_k, dim=-1)
        # Expand every beam by each of its top-k candidate tokens
        expanded_beams = beams.repeat_interleave(top_k, dim=0)
        new_tokens = topk_indices.view(-1, 1)
        candidate_beams = torch.cat([expanded_beams, new_tokens], dim=1)
        # Cumulative log-probability score of each candidate
        expanded_scores = beam_scores.repeat_interleave(top_k)
        candidate_scores = expanded_scores + topk_log_probs.view(-1)
        # Keep the beam_width highest-scoring candidates
        beam_scores, top_indices = candidate_scores.topk(beam_width)
        beams = candidate_beams[top_indices]
    # Return the highest-scoring beam
    best_sequence = beams[beam_scores.argmax()]
    return tk.decode(best_sequence, skip_special_tokens=True)
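
# beam_search is not used by the UI below (answer_question decodes with
# topk_sampling); a hypothetical call would look like:
#   text = beam_search(current_model, "### Instruction: ...", device=ModelArgs.device,
#                      max_length=256, beam_width=5)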


# Function to load the selected model
def load_model(model_type):
    model_path = model_paths[model_type]
    model = Llama(
        device=ModelArgs.device,
        embeddings_dims=ModelArgs.embeddings_dims,
        no_of_decoder_layers=ModelArgs.no_of_decoder_layers,
        block_size=ModelArgs.block_size,
        vocab_size=ModelArgs.vocab_size,
        dropout=ModelArgs.dropout,
    )
    model = model.to(ModelArgs.device)
    checkpoint = torch.load(model_path, weights_only=False)
    # Strip the '_orig_mod.' prefix that torch.compile adds during training
    checkpoint['MODEL_STATE'] = remove_prefix(checkpoint['MODEL_STATE'], '_orig_mod.')
    model.load_state_dict(checkpoint['MODEL_STATE'])
    model.eval()
    return model


# Load the default model once at startup; answer_question reloads on demand
current_model_type = "SFT"
current_model = load_model(current_model_type)


def clean_prompt(text):
    # Remove the instruction-format markers from the generated text
    text = re.sub(r'### Instruction:\s*', '', text)
    text = re.sub(r'### Input:\s*', '', text)
    text = re.sub(r'### Response:\s*', '', text)
    return text.strip()
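
# e.g. clean_prompt("### Response: Paris") -> "Paris"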


def answer_question(model_type, prompt, temperature, top_k, max_length):
    global current_model, current_model_type
    if model_type == "Base (Pretrained)":
        model_type = "Pretrained"
    # Reload only when a different model type is selected
    if model_type != current_model_type:
        current_model = load_model(model_type)
        current_model_type = model_type
    formatted_prompt = f"### Instruction: Answer the following query. \n\n ### Input: {prompt}.\n\n ### Response: "
    with torch.no_grad():
        # beam_search above can be swapped in here as an alternative decoder
        generated_text = topk_sampling(current_model, formatted_prompt, max_length=max_length,
                                       top_k=top_k, temperature=temperature, device=ModelArgs.device)
    return clean_prompt(generated_text)


iface = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Dropdown(choices=["SFT", "DPO", "Base (Pretrained)"], value="DPO", label="Select Model"),
        gr.Textbox(label="Prompt", lines=5),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(minimum=50, maximum=ModelArgs.vocab_size, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=10, maximum=ModelArgs.block_size, value=256, step=1, label="Max Length"),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="SmolLlama",
    description="Enter a prompt and select a model (SFT, DPO, or the pretrained base); the model will generate an answer.",
)

# Launch the Gradio app
if __name__ == "__main__":
    iface.launch()
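
# To try the app locally (assuming this file is saved as app.py and the
# helper scripts and weights are available):
#   python app.py
# then open the local URL that Gradio prints.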