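"""Gradio demo for the LeCarnet language models.

Streams continuations from three small French story models (3M, 8M, and 21M
parameters) hosted on the Hugging Face Hub under MaxLSB, with a side panel
for sampling parameters.
"""
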
import os
import threading
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

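# Model variants selectable in the UI; all are hosted under MaxLSB on the
# Hugging Face Hub.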
MODEL_NAMES = ["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"]
# os.getenv is an alias of os.environ.get, so a single lookup suffices.
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")  # optional for public repos
MEDIA_PATH = "media/le-carnet.png"

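# Load every checkpoint once at startup so switching models in the UI never
# blocks on a download.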
models = {}
tokenizers = {}
for name in MODEL_NAMES:
    hub_id = f"MaxLSB/{name}"
    tokenizers[name] = AutoTokenizer.from_pretrained(hub_id, token=HF_TOKEN)
    models[name] = AutoModelForCausalLM.from_pretrained(hub_id, token=HF_TOKEN)
    models[name].eval()


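# ChatInterface callback: receives (message, history, *additional_inputs)
# and yields the partial reply as it streams.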
def respond(
    prompt: str,
    chat_history,
    selected_model: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    tokenizer = tokenizers[selected_model]
    model = models[selected_model]
    inputs = tokenizer(prompt, return_tensors="pt")

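    # Stream decoded tokens as generate() produces them; skip_prompt=False
    # keeps the prompt in the stream so the chat shows the full continued
    # story, not just the newly generated text.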
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
    )

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

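    # generate() blocks, so run it in a worker thread and consume the
    # streamer from this generator.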
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

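    # Prepend the model icon and name once, then grow the reply string and
    # re-yield it so Gradio updates the chat bubble in place.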
    prefix = f"<img src='{MEDIA_PATH}' width='24' style='display:inline; vertical-align:middle; margin-right:6px;'/> <strong>{selected_model}</strong>: "
    accumulated = ""
    first = True
    for new_text in streamer:
        if first:
            accumulated = prefix + new_text
            first = False
        else:
            accumulated += new_text
        yield accumulated


with gr.Blocks(css=".gr-chatbox {height: 600px !important;}") as demo:
    gr.Markdown("## LeCarnet")

    with gr.Row():
        with gr.Column(scale=4) as chat_col:
            with gr.Row():
                toggle_btn = gr.Button("Show/hide parameters", elem_id="toggle-btn")

        with gr.Column(scale=1, visible=True, elem_id="settings-panel") as param_panel:
            selected_model = gr.Dropdown(MODEL_NAMES, value="LeCarnet-8M", label="Model")
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # ChatInterface reads additional_inputs at construction time; assigning
    # chat.additional_inputs after the fact has no effect. The controls are
    # therefore created first, and the chat is added to the left column by
    # re-entering its context.
    with chat_col:
        chat = gr.ChatInterface(
            fn=respond,
            additional_inputs=[selected_model, max_tokens, temperature, top_p],
            examples=[
                ["Il était une fois un petit garçon qui vivait dans un village paisible."],
                ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
                ["Il était une fois un petit lapin perdu"],
            ],
            cache_examples=False,
        )

    # Wire up the toggle button with a JS snippet on page load. Gradio 4
    # renamed the `_js` argument to `js`.
    demo.load(js="""
        () => {
            const toggleBtn = document.querySelector('#toggle-btn button') || document.querySelector('#toggle-btn');
            const panel = document.querySelector('#settings-panel');
            if (!toggleBtn || !panel) return;
            toggleBtn.addEventListener('click', () => {
                panel.style.display = (panel.style.display === 'none') ? 'flex' : 'none';
            });
        }
    """)

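# Queue limits cap how many generation requests run concurrently.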
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)