OrangeEye committed
Commit a8bbba9 · Parent: 5361e2c

update gpu control

Files changed (2):
  1. app.py +0 -1
  2. utils.py +3 -2
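
This commit moves GPU control from the Gradio handler in app.py to the Llama Guard moderation helper in utils.py. On a Hugging Face ZeroGPU Space, the spaces.GPU decorator attaches a GPU only while the decorated function runs (here capped at 60 seconds), so relocating it means the GPU is reserved around the moderation call rather than the whole RAG handler. A minimal sketch of the decorator pattern, assuming the spaces package that ZeroGPU Spaces provide; the function name and arguments below are illustrative, not code from this repo:

import spaces
import torch

@spaces.GPU(duration=60)  # a GPU is attached only while this function executes
def run_inference(model, inputs):
    # Illustrative only: CUDA work placed inside the decorated function runs on the
    # Space's shared GPU, and the device is released when the function returns.
    model = model.to("cuda")
    with torch.no_grad():
        return model(**{k: v.to("cuda") for k, v in inputs.items()})
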
app.py CHANGED
@@ -141,7 +141,6 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
     input = gr.Textbox(visible=False) # placeholder
     gr_md = gr.Markdown(mark_text + md_text_initial)

-    @spaces.GPU(duration=60)
     def update_with_rag_md(message, llm_results_use = 5, database_choice = index_info, llm_model_picked = 'Trust-Align-Qwen2.5'):
         chat_round = [
             {"role": "user",
utils.py CHANGED
@@ -1,3 +1,4 @@
+import spaces
 import datetime
 import string

@@ -163,8 +164,9 @@ def load_llama_guard(model_id = "meta-llama/Llama-Guard-3-1B"):
     UNSAFE_TOKEN_ID = llama_guard_tokenizer.convert_tokens_to_ids("unsafe")

     return llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID
-

+
+@spaces.GPU(duration=60)
 def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):

     prompt = tokenizer.apply_chat_template(chat, return_tensors="pt", tokenize=False)
@@ -179,7 +181,6 @@ def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):
         return_dict_in_generate=True,
         pad_token_id=tokenizer.eos_token_id,
         output_logits=True, # get logits
-        do_sample=False
     )
     ######
     # Get generated text
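
With the decorator now on utils.moderate, GPU time is requested only around the Llama Guard generate call. Below is a hedged reconstruction of the resulting function shape, pieced together from the context lines above; the tokenization step, any generation arguments not shown in the diff, and the return value are assumptions, not the repo's exact code:

import spaces
import torch

@spaces.GPU(duration=60)  # reserve the Space's GPU only for this moderation call
def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):
    # Render the chat turns into a Llama Guard prompt string.
    prompt = tokenizer.apply_chat_template(chat, return_tensors="pt", tokenize=False)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            return_dict_in_generate=True,
            pad_token_id=tokenizer.eos_token_id,
            output_logits=True,  # keep per-step logits so the "unsafe" token can be scored
        )
    # The code that follows in utils.py (not shown in this diff) reads the generated
    # text and the UNSAFE_TOKEN_ID logit to decide whether the chat is flagged.
    return output

Note that the same hunk drops do_sample=False from the generate call; transformers then falls back to the model's generation config, whose default is greedy decoding unless the checkpoint overrides it.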