import os
import threading
from collections import defaultdict

import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

# Define model paths
model_name_to_path = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}

# Load the Hugging Face token (optional for public checkpoints, so avoid a
# hard KeyError when it is unset)
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Preload models and tokenizers
loaded_models = defaultdict(dict)

for name, path in model_name_to_path.items():
    loaded_models[name]["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
    loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
    loaded_models[name]["model"].eval()
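
# Optional: move the models to a GPU when one is available. A minimal
# sketch, assuming the standard torch CUDA check; these models are small
# enough that the default CPU placement above also works.
import torch

_device = "cuda" if torch.cuda.is_available() else "cpu"
for _entry in loaded_models.values():
    _entry["model"].to(_device)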

def respond(
    prompt: str,
    chat_history,
    model_name: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    # Select the appropriate model and tokenizer
    tokenizer = loaded_models[model_name]["tokenizer"]
    model = loaded_models[model_name]["model"]

    # Tokenize the input and move it to the model's device (CPU unless the
    # optional GPU placement above ran)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Set up streaming
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
    )
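    # Note: skip_prompt=False keeps the prompt in the streamed text, so the
    # UI shows the prompt continued into a story rather than the completion
    # alone.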

    # Configure generation parameters
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Run generation in a background thread
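    # (model.generate blocks until it finishes, so it runs on a worker thread
    # while this generator drains the streamer and yields partial text)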
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Stream partial results as they arrive
    accumulated = ""
    for new_text in streamer:
        accumulated += new_text
        yield accumulated

    # The streamer is exhausted once generation finishes; join the worker
    # thread so it does not outlive this request
    thread.join()

# Create Gradio Chat Interface
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Dropdown(
            choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
            value="LeCarnet-8M",
            label="Model",
        ),
        gr.Slider(1, 512, value=512, step=1, label="Max New Tokens"),
        gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
    ],
    title="LeCarnet",
    description="Select a model and enter text to get started.",
    examples=[
        ["Il était une fois un petit garçon qui vivait dans un village paisible."],
        ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
        ["Il était une fois un petit lapin perdu"],
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)
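
# Local usage sketch (assumptions: gradio and transformers installed; the
# token is only needed for private checkpoints):
#   HUGGINGFACEHUB_API_TOKEN=hf_... python app.py
# then open the printed local URL in a browser.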