File size: 2,500 Bytes
7b524eb
 
 
f16c710
7b524eb
 
 
f16c710
7b524eb
 
b84fae7
 
f16c710
 
 
 
 
b84fae7
f16c710
7b524eb
 
f16c710
7b524eb
f16c710
7b524eb
f16c710
 
7b524eb
f16c710
 
7b524eb
f16c710
 
7b524eb
f16c710
 
7b524eb
f16c710
 
 
b84fae7
 
 
f16c710
 
 
 
 
 
 
b84fae7
 
 
7b524eb
f16c710
7b524eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub checkpoint shared by the tokenizer and the model below.
_CHECKPOINT = "unsloth/Llama-3.2-1B"

# Load the tokenizer and the causal language model for that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)

# Bundle both into a text-generation pipeline; generate_text() calls this object.
text_gen_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# Text generation function with repetition penalty
def generate_text(prompt, max_length=50, temperature=0.7, top_p=0.9, top_k=50, repetition_penalty=1.2):
    """Generate a continuation of *prompt* with the module-level pipeline.

    Args:
        prompt: Text the model should continue.
        max_length: Maximum number of tokens to generate *after* the prompt.
            It is forwarded as ``max_new_tokens``. In transformers,
            ``max_length`` also counts the prompt tokens, so any prompt at
            least this long would otherwise fail generation-length validation
            instead of producing text.
        temperature: Sampling temperature; lower values are more deterministic.
        top_p: Nucleus-sampling cumulative probability cutoff.
        top_k: Number of highest-probability tokens kept when sampling.
        repetition_penalty: Penalty for repeating tokens (1.0 = no penalty).

    Returns:
        The ``generated_text`` field of the first (and only) returned sequence.
    """
    outputs = text_gen_pipeline(
        prompt,
        # Gradio sliders may hand numbers over as floats. Token budgets and
        # top_k must be ints (transformers rejects a non-integer top_k).
        max_new_tokens=int(max_length),
        top_k=int(top_k),
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,  # Penalty added
        # Enable sampling explicitly so temperature/top_p/top_k are honored
        # regardless of the checkpoint's default generation settings.
        do_sample=True,
        num_return_sequences=1,
    )
    return outputs[0]["generated_text"]

# Gradio Interface: a prompt box, sampling controls, an output box, and a
# button that runs generate_text() with the current control values.
with gr.Blocks() as demo:
    gr.Markdown("## Text Generation with Llama 3.2 - 1B")

    # Free-form prompt typed by the user.
    prompt_input = gr.Textbox(
        label="Input (Prompt)",
        placeholder="Enter your prompt here...",
    )

    # Generation controls; each slider feeds the generate_text() parameter
    # with the matching name (max_length, temperature, top_p, top_k,
    # repetition_penalty).
    max_length_input = gr.Slider(
        minimum=10, maximum=200, value=50, step=10,
        label="Maximum Length",
    )
    temperature_input = gr.Slider(
        minimum=0.1, maximum=1.0, value=0.7, step=0.1,
        label="Temperature (creativity)",
    )
    top_p_input = gr.Slider(
        minimum=0.1, maximum=1.0, value=0.9, step=0.1,
        label="Top-p (nucleus sampling)",
    )
    top_k_input = gr.Slider(
        minimum=1, maximum=100, value=50, step=1,
        label="Top-k (sampling diversity)",
    )
    repetition_penalty_input = gr.Slider(
        minimum=1.0, maximum=2.0, value=1.2, step=0.1,
        label="Repetition Penalty",
    )

    # Where the model's continuation is displayed.
    output_text = gr.Textbox(label="Generated Text")

    generate_button = gr.Button("Generate")

    # Order must match generate_text()'s positional parameters.
    generation_inputs = [
        prompt_input,
        max_length_input,
        temperature_input,
        top_p_input,
        top_k_input,
        repetition_penalty_input,
    ]
    generate_button.click(
        fn=generate_text,
        inputs=generation_inputs,
        outputs=output_text,
    )

# Launch the app only when this file is executed as a script, so that
# importing the module does not start a blocking web server as a side effect.
if __name__ == "__main__":
    demo.launch()