Spaces: Running on Zero
update to_device
app.py (CHANGED)
```diff
@@ -152,7 +152,7 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
             }
         ]
         # llama guard check for it
-
+        prompt_safety = moderate(chat_round, llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID)['generated_text']
         prompt_safety = "safe"

         if prompt_safety == "safe":
```
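Note that the very next line still hard-codes `prompt_safety = "safe"`, so the guard's verdict is computed but immediately overwritten. The call site does pin down `moderate`'s interface, though: it takes the chat round, the Llama Guard model and tokenizer, and `UNSAFE_TOKEN_ID`, and returns a dict with a `generated_text` key. The helper itself lives in utils.py and is not part of this diff; below is a minimal sketch of what such a helper could look like, assuming the usual Llama Guard chat-template flow. The body (including the `unsafe_probability` key) is an assumption; only the signature and the `generated_text` key come from this diff.

```python
import torch

def moderate(chat_round, llama_guard, llama_guard_tokenizer, unsafe_token_id):
    # Hypothetical sketch -- the real moderate() is defined in utils.py and is
    # not shown in this diff. Only the signature and the 'generated_text'
    # return key are taken from the call site above.
    input_ids = llama_guard_tokenizer.apply_chat_template(
        chat_round, return_tensors="pt"
    ).to(llama_guard.device)

    with torch.no_grad():
        output = llama_guard.generate(
            input_ids,
            max_new_tokens=20,
            output_scores=True,
            return_dict_in_generate=True,
            pad_token_id=llama_guard_tokenizer.eos_token_id,
        )

    # Probability of the "unsafe" verdict, read from the logits of the first
    # generated token (this is presumably what unsafe_token_id is for; it
    # assumes the verdict is the first token Llama Guard emits).
    probs = torch.softmax(output.scores[0], dim=-1)
    unsafe_prob = probs[0, unsafe_token_id].item()

    # Decode only the newly generated tokens: Llama Guard answers "safe" or
    # "unsafe" followed by the violated category (e.g. "S1").
    verdict = llama_guard_tokenizer.decode(
        output.sequences[0][input_ids.shape[-1]:], skip_special_tokens=True
    ).strip()
    return {"generated_text": verdict, "unsafe_probability": unsafe_prob}
```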
utils.py (CHANGED)
```diff
@@ -158,7 +158,7 @@ def load_llama_guard(model_id = "meta-llama/Llama-Guard-3-1B"):

     logger.info("loading llama_guard")
     llama_guard_tokenizer = AutoTokenizer.from_pretrained(model_id)
-    llama_guard = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
+    llama_guard = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to('cuda')

     # Get the id of the "unsafe" token, this will later be used to extract its probability
     UNSAFE_TOKEN_ID = llama_guard_tokenizer.convert_tokens_to_ids("unsafe")
```
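The `.to('cuda')` at load time is what the commit title ("update to_device") refers to, and it matches the pattern ZeroGPU Spaces document: move the model to CUDA at startup, and the `spaces` package ensures the weights only occupy a real GPU while a `@spaces.GPU`-decorated function is running. A minimal sketch of that load-time pattern under those assumptions (the `check` function and `dtype` value here are illustrative, not part of this diff):

```python
import spaces  # Hugging Face ZeroGPU helper
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-Guard-3-1B"
dtype = torch.bfloat16  # utils.py reads a `dtype` defined elsewhere in the file

# On ZeroGPU, .to('cuda') at module level is the supported pattern: the
# spaces package intercepts the move, and a GPU is attached only while a
# @spaces.GPU function executes.
llama_guard_tokenizer = AutoTokenizer.from_pretrained(model_id)
llama_guard = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to('cuda')

@spaces.GPU
def check(chat_round):
    # GPU is attached only for the duration of this call.
    ...
```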