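"""Gradio demo app for SmolLlama.

Loads the pretrained, SFT, and DPO checkpoints listed in `model_paths` and
serves an interactive text-generation interface using top-k sampling.
"""
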
import gradio as gr
from config import ModelArgs
from inference import remove_prefix
from model import Llama
import torch
from inference_sft import topk_sampling
import os
import subprocess
import re
from tokenizer import Tokenizer
import torch.nn.functional as F
import shutil

# Define model paths
model_paths = {
    "SFT": "weights/fine_tuned/models--YuvrajSingh9886--smol-llama-finetuned/snapshots/35df7811c322dab5f8df56f6a1f61b5bcf6a7339/snapshot_fine_tuned_model_900.pt",
    "DPO": "weights/DPO/models--YuvrajSingh9886--smol-llama-dpo/snapshots/13b9f059cf4630e6079e8822ea7cb7a703d77ed8/DPO_model_1650.pt",
    "Pretrained": "weights/pretrained/models--YuvrajSingh9886--smol-llama-base/snapshots/8e50a94e3c83649e48eb02549558130c87e75a87/snapshot_6750.pt"
}



ACCESS_TOKEN = os.getenv("GDRIVE_ACCESS_TOKEN")
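# NOTE: the token is not referenced elsewhere in this file; the
# download_model_weight.py helper presumably reads it from the environment.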


def download_models():
    # Fetch each checkpoint via the helper script
    for model_type in model_paths:
        subprocess.run(
            ["python", "download_model_weight.py", "--model_type", model_type.lower()],
            check=True,
        )

download_models()

tk = Tokenizer()
tk = tk.ready_tokenizer()



def beam_search(model, prompt, device, max_length=50, beam_width=5, top_k=50, temperature=1.0):
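    """Minimal beam search decoder (currently unused; the beam-search UI option is commented out below).

    Expands each beam with its top_k continuations, keeps the beam_width
    candidates with the highest cumulative log-probability, and returns the
    best sequence after max_length steps.
    """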
    input_ids = tk.encode(prompt, return_tensors='pt').to(device)

    # Start with beam_width copies of the prompt and zero cumulative scores
    beams = input_ids.repeat(beam_width, 1)
    beam_scores = torch.zeros(beam_width, device=device)

    for step in range(max_length):
        with torch.no_grad():
            outputs = model(beams)
            logits = outputs[:, -1, :]  # last-token logits for each beam

            # Temperature-scaled log probabilities
            log_probs = F.log_softmax(logits / temperature, dim=-1)

            # Top-k continuations for each beam
            topk_log_probs, topk_indices = torch.topk(log_probs, top_k, dim=-1)

            if step == 0:
                # All beams are identical on the first step, so expanding every
                # beam would only produce duplicate candidates; expand one beam.
                candidate_beams = torch.cat(
                    [beams[:1].repeat_interleave(top_k, dim=0), topk_indices[0].view(-1, 1)], dim=1
                )
                candidate_scores = topk_log_probs[0].view(-1)
            else:
                # Expand every beam with each of its top_k continuations
                expanded_beams = beams.repeat_interleave(top_k, dim=0)
                new_tokens = topk_indices.view(-1, 1)
                candidate_beams = torch.cat([expanded_beams, new_tokens], dim=1)

                # Cumulative log-probability of each candidate
                expanded_scores = beam_scores.repeat_interleave(top_k)
                candidate_scores = expanded_scores + topk_log_probs.view(-1)

            # Keep the beam_width highest-scoring candidates
            beam_scores, top_indices = candidate_scores.topk(beam_width)
            beams = candidate_beams[top_indices]

    # Return the highest-scoring beam
    best_sequence = beams[beam_scores.argmax()]
    return tk.decode(best_sequence, skip_special_tokens=True)

# Function to load the selected model
def load_model(model_type):
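    """Build a Llama model from ModelArgs and load the requested checkpoint weights."""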
    model_path = model_paths[model_type]

    # All checkpoints are fetched up front by download_models(), so no
    # per-checkpoint existence check is needed here.

    # Load the model
    model = Llama(
        device=ModelArgs.device, 
        embeddings_dims=ModelArgs.embeddings_dims, 
        no_of_decoder_layers=ModelArgs.no_of_decoder_layers, 
        block_size=ModelArgs.block_size, 
        vocab_size=ModelArgs.vocab_size, 
        dropout=ModelArgs.dropout
    )
    model = model.to(ModelArgs.device)

    # Load the checkpoint onto the target device and strip the torch.compile prefix
    dict_model = torch.load(model_path, map_location=ModelArgs.device, weights_only=False)
    dict_model['MODEL_STATE'] = remove_prefix(dict_model['MODEL_STATE'], '_orig_mod.')
    model.load_state_dict(dict_model['MODEL_STATE'])
    model.eval()

    return model


# Track which checkpoint is currently loaded so we only reload on a change
current_model_type = "SFT"
current_model = load_model(current_model_type)


def clean_prompt(text):
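    """Strip the '### Instruction/Input/Response' template markers from generated text."""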
    
    text = re.sub(r'### Instruction:\s*', '', text)
    text = re.sub(r'### Input:\s*', '', text)
    text = re.sub(r'### Response:\s*', '', text)
    return text.strip()

def answer_question(model_type, prompt, temperature, top_k, max_length):
    """Generate an answer to `prompt` with the selected checkpoint via top-k sampling."""
    global current_model, current_model_type

    # Map the UI label onto the key used in model_paths
    if model_type == "Base (Pretrained)":
        model_type = "Pretrained"

    # Reload weights only when a different checkpoint is requested
    if model_type != current_model_type:
        current_model = load_model(model_type)
        current_model_type = model_type


    formatted_prompt = f"### Instruction: Answer the following query. \n\n ### Input: {prompt}.\n\n ### Response: "
    
    with torch.no_grad():
        # if decoding_method == "Beam Search":
        #     generated_text = beam_search(current_model, formatted_prompt, device=ModelArgs.device, 
        #                                    max_length=max_length, beam_width=5, top_k=top_k, temperature=temperature)
        # else:
        generated_text = topk_sampling(current_model, formatted_prompt, max_length=max_length, 
                                           top_k=top_k, temperature=temperature, device=ModelArgs.device)
        generated_text = clean_prompt(generated_text)
        return generated_text


iface = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Dropdown(choices=["SFT", "DPO", "Base (Pretrained)"], value="DPO", label="Select Model"),
        # gr.Dropdown(choices=["Top-K Sampling", "Beam Search"], value="Top-K Sampling", label="Decoding Method"),
        gr.Textbox(label="Prompt", lines=5),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(minimum=50, maximum=ModelArgs.vocab_size, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=10, maximum=ModelArgs.block_size, value=256, step=1, label="Max Length")
    ],
    outputs=gr.Textbox(label="Answer"),
    title="SmolLlama",
    description="Enter a prompt, select a model (SFT, DPO, or Base), and the model will generate an answer."
)

# Launch the Gradio app
if __name__ == "__main__":
    iface.launch()