# # import gradio as gr
# # from huggingface_hub import InferenceClient

# # """
# # For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
# # """
# # client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


# # def respond(
# #     message,
# #     history: list[tuple[str, str]],
# #     system_message,
# #     max_tokens,
# #     temperature,
# #     top_p,
# # ):
# #     messages = [{"role": "system", "content": system_message}]

# #     for val in history:
# #         if val[0]:
# #             messages.append({"role": "user", "content": val[0]})
# #         if val[1]:
# #             messages.append({"role": "assistant", "content": val[1]})

# #     messages.append({"role": "user", "content": message})

# #     response = ""

# #     for message in client.chat_completion(
# #         messages,
# #         max_tokens=max_tokens,
# #         stream=True,
# #         temperature=temperature,
# #         top_p=top_p,
# #     ):
# #         token = message.choices[0].delta.content

# #         response += token
# #         yield response


# # """
# # For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
# # """
# # demo = gr.ChatInterface(
# #     respond,
# #     additional_inputs=[
# #         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
# #         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
# #         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
# #         gr.Slider(
# #             minimum=0.1,
# #             maximum=1.0,
# #             value=0.95,
# #             step=0.05,
# #             label="Top-p (nucleus sampling)",
# #         ),
# #     ],
# # )


# # if __name__ == "__main__":
# #     demo.launch()

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
from safetensors.torch import load_file, save_file

# Define model names
# MODEL_1_PATH = "./adapter_model.safetensors"  # Local path inside Space
###
MODEL_1_PATH = "Priyanka6/fine-tuning-inference"
###
MODEL_2_NAME = "sarvamai/sarvam-1"  # The base model on Hugging Face Hub
# MODEL_3_NAME = 

def trim_adapter_weights(model_path):
    """
    Trim the last row of the adapter's lm_head.lora_B.default.weight when the
    adapter's vocabulary is one token larger than the base model's.
    Returns the path to the (possibly trimmed) adapter file.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Adapter file not found: {model_path}")

    checkpoint = load_file(model_path)
    print("Keys in checkpoint:", list(checkpoint.keys()))

    key_to_trim = "lm_head.lora_B.default.weight"

    if key_to_trim in checkpoint:
        original_size = checkpoint[key_to_trim].shape[0]
        expected_size = original_size - 1  # drop the extra (last) token row

        print(f"Trimming {key_to_trim}: {original_size} -> {expected_size}")

        checkpoint[key_to_trim] = checkpoint[key_to_trim][:-1]  # trim the last row

        # Save the modified adapter next to the original file
        trimmed_adapter_path = os.path.join(
            os.path.dirname(model_path) or ".", "adapter_model_trimmed.safetensors"
        )
        save_file(checkpoint, trimmed_adapter_path)
        return trimmed_adapter_path

    print(f"{key_to_trim} not found in checkpoint; using the adapter unchanged.")
    return model_path

# The adapter file is expected to be present locally (at the root of this Space);
# MODEL_1_PATH above refers to the Hub repo it was fine-tuned in.
model_path = "./adapter_model.safetensors"
trimmed_adapter_path = trim_adapter_weights(model_path)
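
# Optional sketch (not called anywhere): if the adapter file is not bundled with
# the Space, it could instead be fetched from the MODEL_1_PATH repo on the Hub.
# This assumes that repo contains a file named "adapter_model.safetensors".
def fetch_adapter_from_hub(repo_id=MODEL_1_PATH, filename="adapter_model.safetensors"):
    from huggingface_hub import hf_hub_download
    # Downloads the file into the local HF cache and returns its local path,
    # which can then be passed to trim_adapter_weights().
    return hf_hub_download(repo_id=repo_id, filename=filename)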

# Load the tokenizer (same for both models)
TOKENIZER_NAME = "sarvamai/sarvam-1"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# Function to load the selected model
def load_model(model_choice):
    if model_choice == "Hugging face dataset":
        # Load the base model and attach the (trimmed) fine-tuned adapter prepared above.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_2_NAME, torch_dtype=torch.float16, device_map="auto"
        )
        # Note: load_adapter expects an adapter directory/repo containing an
        # adapter_config.json alongside the safetensors weights.
        model.load_adapter(trimmed_adapter_path)
    else:
        # Plain base model, no adapter.
        model = AutoModelForCausalLM.from_pretrained(MODEL_2_NAME)
    model.eval()
    return model
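
# Alternative sketch using the peft library directly (not called anywhere).
# It assumes the adapter source contains an adapter_config.json next to the
# weights, and that the `peft` package is installed in the Space.
def load_finetuned_with_peft(adapter_source=MODEL_1_PATH):
    from peft import PeftModel
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_2_NAME, torch_dtype=torch.float16, device_map="auto"
    )
    # Wraps the base model with the LoRA adapter weights.
    model = PeftModel.from_pretrained(base, adapter_source)
    model.eval()
    return model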

# Load the default model on startup and remember which choice it corresponds to
current_model = load_model("Hugging face dataset")
current_choice = "Hugging face dataset"

# Chatbot response function
def respond(message, history, model_choice, max_tokens, temperature, top_p):
    global current_model, current_choice

    # Switch model if the user selects a different one. Tracking the choice
    # directly is more reliable than comparing config.name_or_path, since both
    # variants are built from the same base checkpoint.
    if model_choice != current_choice:
        current_model = load_model(model_choice)
        current_choice = model_choice

    # Convert the chat history into the chat-message format
    messages = [{"role": "system", "content": "You are a friendly AI assistant."}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    # Tokenize and generate a response
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_tokens = tokenizer(prompt, return_tensors="pt").to(current_model.device)

    output_tokens = current_model.generate(
        **input_tokens,
        max_new_tokens=max_tokens,
        do_sample=True,  # required for temperature/top_p to take effect
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = output_tokens[0][input_tokens["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response
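
# Quick sanity check for the pipeline above, usable without the Gradio UI.
# Only a debugging sketch (not called anywhere); the prompt and settings are
# illustrative, not part of the Space's behaviour.
def smoke_test():
    reply = respond(
        message="Hello! Who are you?",
        history=[],
        model_choice="Hugging face dataset",
        max_tokens=64,
        temperature=0.7,
        top_p=0.95,
    )
    print("Bot:", reply)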

# Define Gradio Chat Interface
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Dropdown(choices=["Hugging face dataset", "Proprietary dataset1"], value="Hugging face dataset", label="Select Model"),
        gr.Slider(minimum=1, maximum=1024, value=256, step=1, label="Max Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ],
)

if __name__ == "__main__":
    demo.launch()


# # Command-line test loop (local debugging; calls respond() defined above)
# if __name__ == "__main__":
#     while True:
#         query = input("User: ")
#         if query.lower() in ["exit", "quit"]:
#             break
#         response = respond(query, [], "Hugging face dataset", 256, 0.7, 0.95)
#         print(f"Bot: {response}")