import os
import threading

import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

MODEL_NAMES = ["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"]
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
MEDIA_PATH = "media/le-carnet.png"

# Load every model/tokenizer pair once at startup so switching models in the
# UI never blocks on a download.
models = {}
tokenizers = {}
for name in MODEL_NAMES:
    hub_id = f"MaxLSB/{name}"
    tokenizers[name] = AutoTokenizer.from_pretrained(hub_id, token=HF_TOKEN)
    models[name] = AutoModelForCausalLM.from_pretrained(hub_id, token=HF_TOKEN)
    models[name].eval()


def respond(
    prompt: str,
    chat_history,
    selected_model: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream a completion of `prompt` from the selected model."""
    tokenizer = tokenizers[selected_model]
    model = models[selected_model]

    inputs = tokenizer(prompt, return_tensors="pt")
    # skip_prompt=False: these are base completion models, so the prompt is
    # echoed back as the beginning of the streamed text.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )
    # generate() blocks until completion, so run it in a background thread
    # and consume the streamer incrementally here.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    prefix = f"{selected_model}: "
    accumulated = ""
    first = True
    for new_text in streamer:
        if first:
            accumulated = prefix + new_text
            first = False
        else:
            accumulated += new_text
        yield accumulated


# The settings controls are created unrendered so they can be passed to the
# ChatInterface as additional_inputs at construction time, then placed in the
# side panel with .render().
selected_model = gr.Dropdown(MODEL_NAMES, value="LeCarnet-8M", label="Model", render=False)
max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max new tokens", render=False)
temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature", render=False)
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p", render=False)

with gr.Blocks(css=".gr-chatbox {height: 600px !important;}") as demo:
    gr.Markdown("## LeCarnet")
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Row():
                toggle_btn = gr.Button("Show/hide parameters", elem_id="toggle-btn")
            # With additional_inputs wired in, each example row also carries
            # values for those inputs (here: the defaults defined above).
            chat = gr.ChatInterface(
                fn=respond,
                additional_inputs=[selected_model, max_tokens, temperature, top_p],
                examples=[
                    ["Il était une fois un petit garçon qui vivait dans un village paisible.", "LeCarnet-8M", 512, 0.7, 0.9],
                    ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang.", "LeCarnet-8M", 512, 0.7, 0.9],
                    ["Il était une fois un petit lapin perdu", "LeCarnet-8M", 512, 0.7, 0.9],
                ],
                cache_examples=False,
            )
        with gr.Column(scale=1, visible=True, elem_id="settings-panel") as param_panel:
            selected_model.render()
            max_tokens.render()
            temperature.render()
            top_p.render()

    # Attach a client-side click handler that shows/hides the settings panel.
    demo.load(
        fn=None,
        inputs=None,
        outputs=None,
        js="""
        () => {
            const toggleBtn = document.querySelector('#toggle-btn button')
                || document.querySelector('#toggle-btn');
            const panel = document.querySelector('#settings-panel');
            toggleBtn.addEventListener('click', () => {
                panel.style.display = (panel.style.display === 'none') ? 'flex' : 'none';
            });
        }
        """,
    )

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)
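# To run locally (a minimal sketch; the file name below is an assumption):
#   pip install gradio transformers torch
#   export HUGGINGFACEHUB_API_TOKEN=hf_...   # only needed if the MaxLSB/* repos are gated
#   python app.py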