File size: 4,321 Bytes
36942d4
a7a20a5
39c555f
7b4f2fa
 
39c555f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b4f2fa
 
 
 
 
 
 
644b0a5
 
7b4f2fa
39c555f
 
644b0a5
a7a20a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b4f2fa
a7a20a5
39c555f
644b0a5
 
a7a20a5
644b0a5
 
 
 
7b4f2fa
 
 
 
 
 
 
 
 
 
 
6ecb51d
 
644b0a5
7b4f2fa
52a9a97
6ecb51d
7b4f2fa
644b0a5
7b4f2fa
6ecb51d
7b4f2fa
644b0a5
 
 
6ecb51d
7b4f2fa
 
 
 
 
644b0a5
 
 
 
 
7b4f2fa
644b0a5
 
7b4f2fa
 
 
 
6ecb51d
 
341bd22
f5f805b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import threading
from collections import defaultdict
from PIL import Image
import tempfile
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

# Display name shown in the UI -> Hugging Face Hub repository id.
model_name_to_path = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}

# Required at import time: a missing variable raises KeyError here.
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]

# model display name -> {"tokenizer": ..., "model": ...}; every model is loaded
# eagerly when this module is imported.
loaded_models = defaultdict(dict)

for name, path in model_name_to_path.items():
    entry = loaded_models[name]
    entry["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
    entry["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
    # Inference only: switch the model to evaluation mode.
    entry["model"].eval()

def resize_logo(input_path, size=(100, 100)):
    """Resize an image and write it to a new temporary PNG file.

    The image is resized to exactly ``size`` with Lanczos resampling (the
    aspect ratio is not preserved).

    Args:
        input_path: Path of the source image.
        size: Target ``(width, height)`` in pixels.

    Returns:
        Path of the temporary PNG. The file is created with ``delete=False``,
        so it outlives this call and is never removed by this function.
    """
    with Image.open(input_path) as img:
        # resize() returns a new, fully loaded image, so it stays usable after
        # the source file is closed.
        resized = img.resize(size, Image.LANCZOS)
    # Close the temp file handle (the original leaked it). Write through the
    # open file object instead of reopening by name, which also fails on
    # Windows while the handle is still open.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        resized.save(temp_file, format="PNG")
    return temp_file.name

def respond(message, history, model_name, max_tokens, temperature, top_p):
    """Stream a sampled continuation of ``message`` into the chat history.

    First yields the history with ``(message, "")`` appended. Then it yields
    again after every decoded text chunk, with the last entry's reply set to a
    bold model-name header followed by the text generated so far. The prompt
    is echoed (``skip_prompt=False``), so the reply contains it too.

    Args:
        message: Prompt text to continue.
        history: Existing chat history as a list of ``(user, bot)`` pairs;
            ``None`` is treated as an empty history.
        model_name: Key into ``loaded_models``.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Raises:
        Whatever ``model.generate`` raised in the worker thread. It is re-raised
        here once the stream is closed, instead of leaving this generator
        blocked forever on the streamer.
    """
    history = list(history or []) + [(message, "")]
    yield history
    tokenizer = loaded_models[model_name]["tokenizer"]
    model = loaded_models[model_name]["model"]
    inputs = tokenizer(message, return_tensors="pt")
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        # Coerce defensively in case the UI slider delivers a float.
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )
    errors = []

    def _generate():
        # Run generation in the worker thread; on failure, record the error and
        # close the stream, since the iteration below only stops on end().
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()
    accumulated = f"**{model_name}**\n\n"
    for new_text in streamer:
        accumulated += new_text
        history[-1] = (message, accumulated)
        yield history
    thread.join()
    if errors:
        raise errors[0]

def submit(message, history, model_name, max_tokens, temperature, top_p):
    """Handle a chat message by streaming ``respond``'s output.

    Yields ``(history, "")`` pairs: the updated chat history for the Chatbot
    and an empty string that clears the message Textbox.
    """
    stream = respond(message, history, model_name, max_tokens, temperature, top_p)
    yield from ((updated, "") for updated in stream)

def start_with_example(example, model_name, max_tokens, temperature, top_p):
    """Start a fresh chat (empty history) from a chosen example prompt.

    ``gr.Dataset(type="values")`` delivers the selected sample row, i.e. a
    list with one value per component (here ``[prompt]``), rather than the
    bare string. Unwrap it so the prompt and the displayed user message are
    plain text. A bare string is accepted unchanged.

    Yields ``(history, "")`` pairs: the updated chat history and an empty
    string that clears the message Textbox.
    """
    if isinstance(example, (list, tuple)):
        if not example:
            # Nothing to generate from an empty sample row.
            return
        example = example[0]
    for updated_history in respond(example, [], model_name, max_tokens, temperature, top_p):
        yield updated_history, ""

# 100x100 copy of the logo, used as the bot's avatar in the Chatbot.
resized_logo_path = resize_logo("media/le-carnet.png", size=(100, 100))

# Story openers offered as one-click prompts.
examples = [
    "Il était une fois un petit garçon qui vivait dans un village paisible.",
    "Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang.",
    "Il était une fois un petit lapin perdu",
]

with gr.Blocks(css=".gr-button {margin: 5px; width: 100%;} .gr-column {padding: 10px;}") as demo:
    gr.Markdown("# LeCarnet")
    gr.Markdown("Select a model on the right and type a message to chat, or choose an example below.")
    with gr.Row():
        with gr.Column(scale=4):
            dataset = gr.Dataset(components=[gr.Textbox(visible=False)], samples=[[ex] for ex in examples], type="values")
            chatbot = gr.Chatbot(
                avatar_images=(None, resized_logo_path),
                label="Chat",
                height=600,
            )
            user_input = gr.Textbox(placeholder="Type your message here...", label="Message")
            submit_btn = gr.Button("Send")
        with gr.Column(scale=1, min_width=200):
            model_dropdown = gr.Dropdown(
                choices=list(model_name_to_path.keys()),
                value="LeCarnet-8M",
                label="Model"
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # The Send button and Enter in the message box share one handler.
    # Outputs are (chat history, cleared message box).
    chat_inputs = [user_input, chatbot, model_dropdown, max_tokens, temperature, top_p]
    chat_outputs = [chatbot, user_input]
    submit_btn.click(fn=submit, inputs=chat_inputs, outputs=chat_outputs)
    user_input.submit(fn=submit, inputs=chat_inputs, outputs=chat_outputs)
    # gr.Dataset's documented events are click/select, not change. click fires
    # on every selection, including re-clicking the same example.
    dataset.click(
        fn=start_with_example,
        inputs=[dataset, model_dropdown, max_tokens, temperature, top_p],
        outputs=chat_outputs,
    )

if __name__ == "__main__":
    # Queue: at most 10 concurrent events per handler and 10 waiting requests.
    app = demo.queue(default_concurrency_limit=10, max_size=10)
    app.launch(ssr_mode=False, max_threads=10)