OrangeEye commited on
Commit
6154a6f
·
1 Parent(s): c0519e0

update to_device

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. utils.py +1 -1
app.py CHANGED
@@ -152,7 +152,7 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
152
  }
153
  ]
154
  # llama guard check for it
155
- # prompt_safety = moderate(chat_round, llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID)['generated_text']
156
  prompt_safety = "safe"
157
 
158
  if prompt_safety == "safe":
 
152
  }
153
  ]
154
  # llama guard check for it
155
+ prompt_safety = moderate(chat_round, llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID)['generated_text']
156
  prompt_safety = "safe"
157
 
158
  if prompt_safety == "safe":
utils.py CHANGED
@@ -158,7 +158,7 @@ def load_llama_guard(model_id = "meta-llama/Llama-Guard-3-1B"):
158
 
159
  logger.info("loading llama_guard")
160
  llama_guard_tokenizer = AutoTokenizer.from_pretrained(model_id)
161
- llama_guard = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map="cuda")
162
 
163
  # Get the id of the "unsafe" token, this will later be used to extract its probability
164
  UNSAFE_TOKEN_ID = llama_guard_tokenizer.convert_tokens_to_ids("unsafe")
 
158
 
159
  logger.info("loading llama_guard")
160
  llama_guard_tokenizer = AutoTokenizer.from_pretrained(model_id)
161
+ llama_guard = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to('cuda')
162
 
163
  # Get the id of the "unsafe" token, this will later be used to extract its probability
164
  UNSAFE_TOKEN_ID = llama_guard_tokenizer.convert_tokens_to_ids("unsafe")