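"""Gradio demo for the LeCarnet story models.

Preloads the LeCarnet-3M/8M/21M checkpoints, then streams completions
token by token into a chat interface, with model selection and sampling
controls in a side panel.
"""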
import os
import threading
from collections import defaultdict

import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

# Define model paths
model_name_to_path = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}

# Load the Hugging Face token (optional here: a missing variable falls back to
# anonymous access, which is enough for public checkpoints)
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Preload models and tokenizers
loaded_models = defaultdict(dict)

for name, path in model_name_to_path.items():
    loaded_models[name]["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
    loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
    loaded_models[name]["model"].eval()

def respond(
    prompt: str,
    chat_history,  # required by gr.ChatInterface, unused: the model is a plain completion LM
    model_name: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream a completion of `prompt` from the selected model."""
    # Select the appropriate model and tokenizer
    tokenizer = loaded_models[model_name]["tokenizer"]
    model = loaded_models[model_name]["model"]

    # Tokenize the prompt on the same device as the model
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Set up streaming; skip_prompt=False echoes the prompt back so the
    # streamed text reads as one continuous story
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
    )

    # Configure generation parameters
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Run generation in a background thread so this one can stream the output;
    # daemon=True keeps a stuck generation from blocking interpreter shutdown
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs, daemon=True)
    thread.start()

    # Stream results
    accumulated = ""
    for new_text in streamer:
        accumulated += new_text
        yield accumulated
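
# Quick sanity check of `respond` without launching the UI (hypothetical usage,
# e.g. from a REPL); prints the growing completion as it streams:
#
#     for text in respond("Il était une fois", [], "LeCarnet-8M", 32, 0.7, 0.9):
#         print(text)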

# Create Gradio Chat Interface
with gr.Blocks() as demo:
    # Custom title with logo; the image is served through Gradio's /file= route,
    # which requires the path to be whitelisted via allowed_paths in launch()
    with gr.Row():
        gr.HTML(
            '<div style="display: flex; align-items: center;">'
            f'<img src="/file={os.path.abspath("media/le-carnet.png")}" style="height: 50px; margin-right: 10px;" />'
            '<h1 style="margin: 0;">LeCarnet</h1>'
            '</div>'
        )
    
    # Sidebar controls, defined first so they can be wired into the chat interface
    with gr.Column(variant="panel"):
        gr.Markdown("### Model Configuration")
        model_dropdown = gr.Dropdown(
            choices=list(model_name_to_path),
            value="LeCarnet-8M",
            label="Model",
        )
        max_tokens_slider = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
        temperature_slider = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # Chat interface; the controls above are forwarded to `respond` as
    # additional inputs, in order, after (message, history)
    chatbot = gr.ChatInterface(
        fn=respond,
        additional_inputs=[model_dropdown, max_tokens_slider, temperature_slider, top_p_slider],
        examples=[
            ["Il était une fois un petit garçon qui vivait dans un village paisible."],
            ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
            ["Il était une fois un petit lapin perdu"],
        ],
        cache_examples=False,
    )

# Launch the app
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(
        ssr_mode=False,
        max_threads=10,
        allowed_paths=["media"],  # let the /file= route serve the logo image
    )