alakxender commited on
Commit
688d8ab
·
1 Parent(s): 056f706
Files changed (1) hide show
  1. content_gen.py +7 -3
content_gen.py CHANGED
@@ -51,7 +51,11 @@ def generate_content(prompt, max_new_tokens, num_beams, repetition_penalty, no_r
51
  repetition_penalty=repetition_penalty,
52
  no_repeat_ngram_size=no_repeat_ngram_size,
53
  do_sample=do_sample,
54
- early_stopping=True
55
  )
56
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
57
-
 
 
 
 
 
51
  repetition_penalty=repetition_penalty,
52
  no_repeat_ngram_size=no_repeat_ngram_size,
53
  do_sample=do_sample,
54
+ early_stopping=False
55
  )
56
+ output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
57
+ # Trim to the last period
58
+ if '.' in output_text:
59
+ last_period = output_text.rfind('.')
60
+ output_text = output_text[:last_period+1]
61
+ return output_text