import os
import threading
from collections import defaultdict

import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

# Define model paths
model_name_to_path = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}

# Load the Hugging Face token from the environment; default to None so that
# public models still download anonymously instead of failing on a bogus token.
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Preload models and tokenizers
loaded_models = defaultdict(dict)

for name, path in model_name_to_path.items():
    try:
        loaded_models[name]["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
        loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
        loaded_models[name]["model"].eval()
    except Exception as e:
        print(f"Error loading {name}: {str(e)}")

def respond(message, history, model_name, max_tokens, temperature, top_p):
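    # Append the new user turn with an empty assistant slot and yield right
    # away so the message appears in the chat before generation starts.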
    history = history + [(message, "")]
    yield history

    # Guard against models that failed to preload (see the loading loop above).
    if "model" not in loaded_models[model_name]:
        history[-1] = (message, f"Model '{model_name}' is unavailable; check the server logs.")
        yield history
        return

    tokenizer = loaded_models[model_name]["tokenizer"]
    model = loaded_models[model_name]["model"]

    inputs = tokenizer(message, return_tensors="pt")

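    # Stream tokens as they are produced; skip_prompt=False keeps the prompt
    # in the stream so the model visibly continues the user's sentence.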
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
    )

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

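    # Run generation on a background thread; iterating the streamer below
    # consumes text chunks as they arrive.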
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    accumulated = ""  # Removed model name prefix
    for new_text in streamer:
        accumulated += new_text
        history[-1] = (message, accumulated)
        yield history

def submit(message, history, model_name, max_tokens, temperature, top_p):
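    # Wrap respond() and also emit "" on each update to clear the input textbox.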
    for updated_history in respond(message, history, model_name, max_tokens, temperature, top_p):
        yield updated_history, ""

with gr.Blocks(css=".gr-button {margin: 5px; width: 100%;} .gr-column {padding: 10px;}") as demo:
    gr.Markdown("# LeCarnet")
    gr.Markdown("Select a model on the right and type a message to chat.")

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                avatar_images=(None, "https://raw.githubusercontent.com/maxlsb/le-carnet/main/media/le-carnet.png"),  # bot avatar hosted in the project repo
                label="Chat",
                height=600,
            )
            user_input = gr.Textbox(placeholder="Type your message here...", label="Message")
            submit_btn = gr.Button("Send")
            examples = gr.Examples(
                examples=[
                    ["Il était une fois un petit garçon qui vivait dans un village paisible."],
                    ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
                    ["Il était une fois un petit lapin perdu"],
                ],
                inputs=user_input,
            )

        with gr.Column(scale=1, min_width=200):
            model_dropdown = gr.Dropdown(
                choices=list(model_name_to_path.keys()),
                value="LeCarnet-8M",
                label="Select Model",
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # Submit button click
    submit_btn.click(
        fn=submit,
        inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
    )
    
    # Enter key press
    user_input.submit(
        fn=submit,
        inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
    )

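# Queue with a modest concurrency limit so several users can stream
# responses at once without exhausting CPU threads.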
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)