OrangeEye committed
Commit 4d0d167 · 1 parent: 6154a6f
Files changed (2):
  1. app.py +1 -0
  2. utils.py +3 -2
app.py CHANGED
@@ -141,6 +141,7 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
     input = gr.Textbox(visible=False) # placeholder
     gr_md = gr.Markdown(mark_text + md_text_initial)
 
+    @spaces.GPU(duration=60)
     def update_with_rag_md(message, llm_results_use = 5, database_choice = index_info, llm_model_picked = 'Trust-Align-Qwen2.5'):
         chat_round = [
             {"role": "user",
utils.py CHANGED
@@ -166,7 +166,7 @@ def load_llama_guard(model_id = "meta-llama/Llama-Guard-3-1B"):
     return llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID
 
 
-@spaces.GPU(duration=60)
+@spaces.GPU(duration=120)
 def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):
 
     prompt = tokenizer.apply_chat_template(chat, return_tensors="pt", tokenize=False)
@@ -188,11 +188,12 @@ def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):
     ######
     # Get generated text
     ######
-
+    logger.info(outputs)
     # Number of tokens that correspond to the input prompt
     input_length = inputs.input_ids.shape[1]
     # Ignore the tokens from the input to get the tokens generated by the model
     generated_token_ids = outputs.sequences[:, input_length:].cpu()
+    logger.info(generated_token_ids)
     generated_text = tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)
     logger.info(generated_text)
     ######
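
The two new logger.info calls look like debugging aids around the prompt-stripping step, and raising the budget to duration=120 gives the moderation call more headroom on ZeroGPU. The slicing being logged works because generate() echoes the prompt tokens ahead of the newly generated ones; a standalone sketch of that step, again with gpt2 standing in for the Llama Guard checkpoint:

    # Why outputs.sequences is sliced at input_length (gpt2 is illustrative).
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("Hello, world", return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=8,
        return_dict_in_generate=True,  # returns an object with .sequences, as in moderate()
        output_scores=True,
    )

    # generate() returns the prompt tokens followed by the new tokens, so drop the prompt
    input_length = inputs.input_ids.shape[1]
    generated_token_ids = outputs.sequences[:, input_length:]
    print(tokenizer.decode(generated_token_ids[0], skip_special_tokens=True))
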