Spaces:
Runtime error
Runtime error
Update space
Browse files
app.py
CHANGED
@@ -67,6 +67,7 @@ import torch
|
|
67 |
import gradio as gr
|
68 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
69 |
import os
|
|
|
70 |
|
71 |
# Define model names
|
72 |
# MODEL_1_PATH = "./adapter_model.safetensors" # Local path inside Space
|
@@ -85,7 +86,7 @@ def trim_adapter_weights(model_path):
|
|
85 |
# if not os.path.exists(model_path):
|
86 |
# raise FileNotFoundError(f"Adapter file not found: {model_path}")
|
87 |
|
88 |
-
checkpoint =
|
89 |
|
90 |
key_to_trim = "lm_head.lora_B.default.weight"
|
91 |
|
@@ -99,7 +100,7 @@ def trim_adapter_weights(model_path):
|
|
99 |
|
100 |
# Save the modified adapter
|
101 |
trimmed_adapter_path = os.path.join(model_path, "adapter_model_trimmed.safetensors")
|
102 |
-
|
103 |
return trimmed_adapter_path
|
104 |
|
105 |
return model_path
|
|
|
67 |
import gradio as gr
|
68 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
69 |
import os
|
70 |
+
from safetensors.torch import load_file, save_file
|
71 |
|
72 |
# Define model names
|
73 |
# MODEL_1_PATH = "./adapter_model.safetensors" # Local path inside Space
|
|
|
86 |
# if not os.path.exists(model_path):
|
87 |
# raise FileNotFoundError(f"Adapter file not found: {model_path}")
|
88 |
|
89 |
+
checkpoint = load_file(model_path)
|
90 |
|
91 |
key_to_trim = "lm_head.lora_B.default.weight"
|
92 |
|
|
|
100 |
|
101 |
# Save the modified adapter
|
102 |
trimmed_adapter_path = os.path.join(model_path, "adapter_model_trimmed.safetensors")
|
103 |
+
save_file(checkpoint, trimmed_adapter_path)
|
104 |
return trimmed_adapter_path
|
105 |
|
106 |
return model_path
|