OrangeEye commited on
Commit
6a2e657
·
1 Parent(s): e6b01c8

update generate

Browse files
Files changed (2) hide show
  1. __pycache__/utils.cpython-310.pyc +0 -0
  2. utils.py +5 -3
__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
 
utils.py CHANGED
@@ -182,6 +182,8 @@ def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):
182
  pad_token_id=tokenizer.eos_token_id,
183
  output_logits=True, # get logits
184
  do_sample=False,
 
 
185
  )
186
  ######
187
  # Get generated text
@@ -192,7 +194,7 @@ def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):
192
  # Ignore the tokens from the input to get the tokens generated by the model
193
  generated_token_ids = outputs.sequences[:, input_length:].cpu()
194
  generated_text = tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)
195
-
196
  ######
197
  # Get Probability of "unsafe" token
198
  ######
@@ -206,9 +208,9 @@ def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):
206
  unsafe_probability = first_token_probs[0, UNSAFE_TOKEN_ID]
207
  unsafe_probability = unsafe_probability.item()
208
 
209
- ######
210
  # Result
211
- ######
212
  return {
213
  "unsafe_score": unsafe_probability,
214
  "generated_text": generated_text
 
182
  pad_token_id=tokenizer.eos_token_id,
183
  output_logits=True, # get logits
184
  do_sample=False,
185
+ temperature=None,
186
+ top_p=None
187
  )
188
  ######
189
  # Get generated text
 
194
  # Ignore the tokens from the input to get the tokens generated by the model
195
  generated_token_ids = outputs.sequences[:, input_length:].cpu()
196
  generated_text = tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)
197
+ logger.info(generated_text)
198
  ######
199
  # Get Probability of "unsafe" token
200
  ######
 
208
  unsafe_probability = first_token_probs[0, UNSAFE_TOKEN_ID]
209
  unsafe_probability = unsafe_probability.item()
210
 
211
+ ########
212
  # Result
213
+ ########
214
  return {
215
  "unsafe_score": unsafe_probability,
216
  "generated_text": generated_text