LeCarnet-Demo / app.py
MaxLSB's picture
Update app.py
7b4f2fa verified
raw
history blame
4.32 kB
import os
import threading
from collections import defaultdict
from PIL import Image
import tempfile
import gradio as gr
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
)
# Display name used as the dropdown choice -> Hugging Face Hub repository id.
model_name_to_path = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}

# Plain subscript on purpose: a missing variable raises KeyError at import
# time instead of failing later during a download.
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]

# Every checkpoint is loaded once, up front, and kept in memory so that a
# request only has to look its model up by display name.
loaded_models = defaultdict(dict)
for name, path in model_name_to_path.items():
    loaded_models[name] = {
        "tokenizer": AutoTokenizer.from_pretrained(path, token=hf_token),
        "model": AutoModelForCausalLM.from_pretrained(path, token=hf_token),
    }
    # Switch the freshly loaded model to evaluation mode.
    loaded_models[name]["model"].eval()
def resize_logo(input_path, size=(100, 100)):
    """Resize an image and store the result in a new temporary PNG file.

    Args:
        input_path: Path of the image to read.
        size: Target ``(width, height)`` in pixels.

    Returns:
        The path of the PNG file that was written. The file is created with
        ``delete=False``, so it is not removed automatically; the caller owns
        it.
    """
    # Reserve a unique path, then close the handle straight away: the
    # original kept it open forever (descriptor leak) and saved to the same
    # path while it was still open, which is not portable to Windows.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        output_path = temp_file.name

    with Image.open(input_path) as img:
        resized = img.resize(size, Image.LANCZOS)
        resized.save(output_path, format="PNG")
    return output_path
def respond(message, history, model_name, max_tokens, temperature, top_p):
    """Stream a model completion for ``message`` into the chat history.

    Generator used by the Gradio handlers. Only ``message`` is tokenized and
    sent to the model; earlier turns in ``history`` are displayed but are not
    part of the prompt. The streamer is created with ``skip_prompt=False``,
    so the streamed reply contains the prompt followed by its continuation.

    Args:
        message: Prompt text for this turn.
        history: Existing chat history as a list of ``(user, bot)`` pairs;
            ``None`` is treated as an empty history. The input list is never
            mutated.
        model_name: Key into ``loaded_models`` selecting tokenizer and model.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The updated history list: once with an empty reply right after the
        user turn is appended, then once per streamed text chunk. For an
        empty ``message`` the unchanged history is yielded once and nothing
        is generated.
    """
    history = list(history or [])

    if not message:
        # Nothing to complete: an empty prompt can tokenize to zero tokens,
        # and a failed generate() call would leave the stream open.
        yield history
        return

    history.append((message, ""))
    yield history

    tokenizer = loaded_models[model_name]["tokenizer"]
    model = loaded_models[model_name]["model"]
    inputs = tokenizer(message, return_tensors="pt")

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

    def _generate():
        # generate() closes the streamer itself on success. If it raises
        # first, the consumer loop below would wait on the streamer forever,
        # so close the stream explicitly before re-raising.
        try:
            model.generate(**generate_kwargs)
        except Exception:
            streamer.end()
            raise

    # Generation runs in a worker thread so this generator can forward text
    # chunks to the UI as soon as the streamer produces them.
    thread = threading.Thread(target=_generate)
    thread.start()

    # The reply is prefixed with the model name in bold.
    accumulated = f"**{model_name}**\n\n"
    for new_text in streamer:
        accumulated += new_text
        history[-1] = (message, accumulated)
        yield history
def submit(message, history, model_name, max_tokens, temperature, top_p):
    """Handle a sent message: stream the reply and blank the input box.

    Every history update produced by ``respond`` is re-emitted as a pair
    ``(history, "")``; the empty string is the new value for the second
    output component.
    """
    stream = respond(message, history, model_name, max_tokens, temperature, top_p)
    yield from ((chat_state, "") for chat_state in stream)
def start_with_example(example, model_name, max_tokens, temperature, top_p):
    """Start a fresh conversation from a selected example prompt.

    Args:
        example: The example prompt. Either the prompt string itself or the
            sample row delivered by the dataset component (a list/tuple whose
            first element is the prompt).
        model_name: Key of the model to use.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        ``(history, "")`` pairs: the streamed history, which starts from an
        empty conversation, and an empty string for the second output.
    """
    # The dataset is declared with type="values" and one component per
    # sample, so the handler receives the whole row (e.g. ["Il était ..."])
    # rather than the bare string. Unwrap it so the chat history stores text.
    if isinstance(example, (list, tuple)):
        example = example[0] if example else ""

    for updated_history in respond(example, [], model_name, max_tokens, temperature, top_p):
        yield updated_history, ""
# Story openers offered as one-click prompts in the UI.
examples = [
    "Il était une fois un petit garçon qui vivait dans un village paisible.",
    "Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang.",
    "Il était une fois un petit lapin perdu",
]

# 100x100 copy of the logo, written to a temporary PNG at import time; the
# returned path is what the chatbot uses as its avatar image.
resized_logo_path = resize_logo("media/le-carnet.png", size=(100, 100))
# UI definition: a wide left column with the example picker, chat window and
# message box, and a narrow right column with the model and sampling controls.
with gr.Blocks(css=".gr-button {margin: 5px; width: 100%;} .gr-column {padding: 10px;}") as demo:
    gr.Markdown("# LeCarnet")
    gr.Markdown("Select a model on the right and type a message to chat, or choose an example below.")

    with gr.Row():
        with gr.Column(scale=4):
            # One hidden Textbox column; each sample row holds one prompt.
            dataset = gr.Dataset(
                components=[gr.Textbox(visible=False)],
                samples=[[ex] for ex in examples],
                type="values",
            )
            chatbot = gr.Chatbot(
                avatar_images=(None, resized_logo_path),
                label="Chat",
                height=600,
            )
            user_input = gr.Textbox(placeholder="Type your message here...", label="Message")
            submit_btn = gr.Button("Send")

        with gr.Column(scale=1, min_width=200):
            model_dropdown = gr.Dropdown(
                choices=list(model_name_to_path.keys()),
                value="LeCarnet-8M",
                label="Model",
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # "Send" streams the reply into the chat and clears the message box.
    submit_btn.click(
        fn=submit,
        inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
    )

    # Picking an example starts a new conversation from that prompt. Dataset
    # documents `click` and `select` listeners rather than `change`, so the
    # handler is attached to `click`.
    dataset.click(
        fn=start_with_example,
        inputs=[dataset, model_dropdown, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
    )
if __name__ == "__main__":
    # Configure the request queue first, then start the server on the
    # object that queue() returns.
    queued_demo = demo.queue(default_concurrency_limit=10, max_size=10)
    queued_demo.launch(ssr_mode=False, max_threads=10)