Priyanka6 commited on
Commit
19780e1
Β·
1 Parent(s): 5277acd

Update space

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -67,6 +67,7 @@ import torch
67
  import gradio as gr
68
  from transformers import AutoModelForCausalLM, AutoTokenizer
69
  import os
 
70
 
71
  # Define model names
72
  # MODEL_1_PATH = "./adapter_model.safetensors" # Local path inside Space
@@ -85,7 +86,7 @@ def trim_adapter_weights(model_path):
85
  # if not os.path.exists(model_path):
86
  # raise FileNotFoundError(f"Adapter file not found: {model_path}")
87
 
88
- checkpoint = torch.load(model_path, map_location="cpu",weights_only=False)
89
 
90
  key_to_trim = "lm_head.lora_B.default.weight"
91
 
@@ -99,7 +100,7 @@ def trim_adapter_weights(model_path):
99
 
100
  # Save the modified adapter
101
  trimmed_adapter_path = os.path.join(model_path, "adapter_model_trimmed.safetensors")
102
- torch.save(checkpoint, trimmed_adapter_path)
103
  return trimmed_adapter_path
104
 
105
  return model_path
 
67
  import gradio as gr
68
  from transformers import AutoModelForCausalLM, AutoTokenizer
69
  import os
70
+ from safetensors.torch import load_file, save_file
71
 
72
  # Define model names
73
  # MODEL_1_PATH = "./adapter_model.safetensors" # Local path inside Space
 
86
  # if not os.path.exists(model_path):
87
  # raise FileNotFoundError(f"Adapter file not found: {model_path}")
88
 
89
+ checkpoint = load_file(model_path)
90
 
91
  key_to_trim = "lm_head.lora_B.default.weight"
92
 
 
100
 
101
  # Save the modified adapter
102
  trimmed_adapter_path = os.path.join(model_path, "adapter_model_trimmed.safetensors")
103
+ save_file(checkpoint, trimmed_adapter_path)
104
  return trimmed_adapter_path
105
 
106
  return model_path