Leo8613 commited on
Commit
b84fae7
·
verified ·
1 Parent(s): f16c710

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -8,13 +8,14 @@ model = AutoModelForCausalLM.from_pretrained("unsloth/Llama-3.2-1B")
8
  # Use a pipeline for text generation
9
  text_gen_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
10
 
11
- # Text generation function with additional parameters
12
- def generate_text(prompt, max_length=50, temperature=0.7, top_p=0.9, top_k=50):
13
  generated_text = text_gen_pipeline(prompt,
14
  max_length=max_length,
15
  temperature=temperature,
16
  top_p=top_p,
17
  top_k=top_k,
 
18
  num_return_sequences=1)
19
  return generated_text[0]['generated_text']
20
 
@@ -37,6 +38,9 @@ with gr.Blocks() as demo:
37
  # Slider for top_k (controls diversity)
38
  top_k_input = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k (sampling diversity)")
39
 
 
 
 
40
  # Output box for the generated text
41
  output_text = gr.Textbox(label="Generated Text")
42
 
@@ -44,7 +48,9 @@ with gr.Blocks() as demo:
44
  generate_button = gr.Button("Generate")
45
 
46
  # Action on button click
47
- generate_button.click(generate_text, inputs=[prompt_input, max_length_input, temperature_input, top_p_input, top_k_input], outputs=output_text)
 
 
48
 
49
  # Launch the app
50
  demo.launch()
 
8
  # Use a pipeline for text generation
9
  text_gen_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
10
 
11
+ # Text generation function with repetition penalty
12
+ def generate_text(prompt, max_length=50, temperature=0.7, top_p=0.9, top_k=50, repetition_penalty=1.2):
13
  generated_text = text_gen_pipeline(prompt,
14
  max_length=max_length,
15
  temperature=temperature,
16
  top_p=top_p,
17
  top_k=top_k,
18
+ repetition_penalty=repetition_penalty, # Penalty added
19
  num_return_sequences=1)
20
  return generated_text[0]['generated_text']
21
 
 
38
  # Slider for top_k (controls diversity)
39
  top_k_input = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k (sampling diversity)")
40
 
41
+ # Slider for repetition penalty
42
+ repetition_penalty_input = gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty")
43
+
44
  # Output box for the generated text
45
  output_text = gr.Textbox(label="Generated Text")
46
 
 
48
  generate_button = gr.Button("Generate")
49
 
50
  # Action on button click
51
+ generate_button.click(generate_text,
52
+ inputs=[prompt_input, max_length_input, temperature_input, top_p_input, top_k_input, repetition_penalty_input],
53
+ outputs=output_text)
54
 
55
  # Launch the app
56
  demo.launch()