import os
import tempfile
import threading
from collections import defaultdict

import gradio as gr
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
# Display names mapped to their Hugging Face Hub repositories.
model_name_to_path = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}
# Hugging Face token, expected in the environment.
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]

# Load every model and tokenizer eagerly at startup so no request waits on a download.
loaded_models = defaultdict(dict)
for name, path in model_name_to_path.items():
    loaded_models[name]["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
    loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
    loaded_models[name]["model"].eval()
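# Hedged alternative (not in the original script): for larger checkpoints,
# loading in half precision would roughly halve memory use, assuming PyTorch
# with float16 support is available:
#   model = AutoModelForCausalLM.from_pretrained(
#       path, token=hf_token, torch_dtype=torch.float16
#   )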
def resize_logo(input_path, size=(100, 100)):
    """Resize the logo to `size` and return the path of a temporary PNG copy."""
    with Image.open(input_path) as img:
        img = img.resize(size, Image.LANCZOS)
        temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        img.save(temp_file.name, format="PNG")
        return temp_file.name
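# Note: delete=False means the temporary PNG is never removed automatically.
# A hedged cleanup sketch (not in the original) would register the file for
# deletion at interpreter exit, once resized_logo_path exists below:
#   import atexit
#   atexit.register(os.remove, resized_logo_path)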
def respond(message, history, model_name, max_tokens, temperature, top_p):
    """Stream a completion for `message`, yielding the updated chat history."""
    # Echo the user's message immediately with an empty assistant reply.
    history = history + [(message, "")]
    yield history

    tokenizer = loaded_models[model_name]["tokenizer"]
    model = loaded_models[model_name]["model"]
    inputs = tokenizer(message, return_tensors="pt")

    # skip_prompt=False keeps the prompt in the stream, so the model's output
    # continues the user's opening sentence rather than replacing it.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Run generation in a background thread; the streamer yields decoded text
    # as soon as it is produced, so the chat updates token by token.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    accumulated = f"**{model_name}**\n\n"
    for new_text in streamer:
        accumulated += new_text
        history[-1] = (message, accumulated)
        yield history
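# Illustration only (hypothetical call, not part of the app): the generator can
# be driven outside Gradio as well, printing the partial reply as it streams:
#   for hist in respond("Il était une fois", [], "LeCarnet-8M", 64, 0.7, 0.9):
#       print(hist[-1][1])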
def submit(message, history, model_name, max_tokens, temperature, top_p):
    """Handle the Send button: stream the reply and clear the input box."""
    for updated_history in respond(message, history, model_name, max_tokens, temperature, top_p):
        yield updated_history, ""


def start_with_example(example, model_name, max_tokens, temperature, top_p):
    """Start a fresh conversation from a clicked example prompt."""
    for updated_history in respond(example, [], model_name, max_tokens, temperature, top_p):
        yield updated_history, ""
resized_logo_path = resize_logo("media/le-carnet.png", size=(100, 100))
examples = [
    "Il était une fois un petit garçon qui vivait dans un village paisible.",
    "Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang.",
    "Il était une fois un petit lapin perdu",
]
with gr.Blocks(css=".gr-button {margin: 5px; width: 100%;} .gr-column {padding: 10px;}") as demo:
    gr.Markdown("# LeCarnet")
    gr.Markdown("Select a model on the right and type a message to chat, or choose an example below.")
    with gr.Row():
        with gr.Column(scale=4):
            dataset = gr.Dataset(
                components=[gr.Textbox(visible=False)],
                samples=[[ex] for ex in examples],
                type="values",
            )
            chatbot = gr.Chatbot(
                avatar_images=(None, resized_logo_path),
                label="Chat",
                height=600,
            )
            user_input = gr.Textbox(placeholder="Type your message here...", label="Message")
            submit_btn = gr.Button("Send")
        with gr.Column(scale=1, min_width=200):
            model_dropdown = gr.Dropdown(
                choices=list(model_name_to_path.keys()),
                value="LeCarnet-8M",
                label="Model",
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
    # The Send button streams the reply into the chat and clears the textbox.
    submit_btn.click(
        fn=submit,
        inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
    )
    # Clicking an example starts a new conversation from that prompt.
    dataset.change(
        fn=start_with_example,
        inputs=[dataset, model_dropdown, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
    )
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)