# NOTE: page-scrape residue from the Hugging Face Spaces file viewer
# (Space status "Sleeping", file size 4,107 bytes, commit-hash gutter,
# line-number gutter) — commented out so the module parses as Python.
import os
import torch
import threading
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face token (optional). The MaxLSB/LeCarnet checkpoints are fetched
# with this token; fall back to None instead of crashing with a bare KeyError
# at import time when the variable is unset — `token=None` means anonymous
# access, which works for public repositories.
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
# Keep PyTorch to a single intra-op thread (small shared-CPU Space).
torch.set_num_threads(1)
# Module-level globals holding the currently active checkpoint;
# populated by load_model() and read by respond().
tokenizer = None
model = None
current_model_name = None
# Fetch (or reuse from cache) the requested checkpoint and make it current.
def load_model(model_name):
    """Load MaxLSB/<model_name> into the module-level tokenizer/model globals."""
    global tokenizer, model, current_model_name
    repo_id = f"MaxLSB/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(repo_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(repo_id, token=hf_token)
    # Inference only: disable dropout / switch norm layers to eval behavior.
    model.eval()
    current_model_name = model_name
# Eagerly load the default model at import time so the first request
# has no cold-start delay (downloads once, then served from cache).
load_model("LeCarnet-8M")
# Streaming generation function
def respond(message, max_tokens, temperature, top_p):
    """Stream a sampled continuation of *message*.

    Yields progressively longer markdown strings of the form
    "**<model name>**\n\n<text so far>" as tokens arrive from the streamer.
    """
    # Snapshot the globals once: update_model() can swap them at any time
    # from another request, and mixing one checkpoint's tokenizer with
    # another's weights mid-generation would corrupt the output.
    local_tokenizer = tokenizer
    local_model = model
    model_label = current_model_name
    inputs = local_tokenizer(message, return_tensors="pt")
    # skip_prompt=False: the prompt is echoed back, so the UI shows the
    # user's text continued seamlessly by the model.
    streamer = TextIteratorStreamer(local_tokenizer, skip_prompt=False, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=local_tokenizer.eos_token_id,
    )

    def run():
        # Pure inference: no autograd graph needed.
        with torch.no_grad():
            local_model.generate(**generate_kwargs)

    # daemon=True so a wedged generation cannot block interpreter shutdown.
    thread = threading.Thread(target=run, daemon=True)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        # Prepend the model name on its own line so the chat shows which
        # checkpoint produced the answer.
        yield f"**{model_label}**\n\n{response}"
    # The streamer is exhausted once generate() finishes; reap the worker
    # instead of leaking a thread per request.
    thread.join()
# User input handler: record the new message and clear the textbox.
def user(message, chat_history):
    """Append *message* as an unanswered [user, None] turn; reset the input."""
    chat_history += [[message, None]]
    return "", chat_history
# Bot response handler: stream the model reply into the newest chat turn.
def bot(chatbot, max_tokens, temperature, top_p):
    """Fill the assistant slot of the last turn, yielding the history each step."""
    last_user_message = chatbot[-1][0]
    for partial in respond(last_user_message, max_tokens, temperature, top_p):
        chatbot[-1][1] = partial
        yield chatbot
# Model selector handler: swap in the chosen checkpoint.
def update_model(model_name):
    """Load *model_name*; the empty-list return matches the empty outputs list."""
    load_model(model_name)
    return list()
# Gradio UI: left column = model/sampling controls + examples,
# right column = chat transcript.
with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    with gr.Row():
        # Centered page title banner.
        gr.HTML("""
        <div style="text-align: center; width: 100%;">
            <h1 style="margin: 0;">LeCarnet Demo 📊</h1>
        </div>
        """ )
    # Prompt box: the model continues this French story opening.
    msg_input = gr.Textbox(
        placeholder="Il était une fois un petit garçon",
        label="User Input"
    )
    with gr.Row():
        with gr.Column(scale=1, min_width=150):
            model_selector = gr.Dropdown(
                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                value="LeCarnet-8M",
                label="Select Model"
            )
            # Sampling controls passed straight through to respond().
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.4, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
            clear_button = gr.Button("Clear Chat")
            # Clicking an example fills the textbox (it does not auto-submit).
            gr.Examples(
                examples=[
                    ["Il était une fois un petit phoque nommé Zoom. Zoom était très habile et aimait jouer dans l'eau."],
                    ["Il était une fois un petit écureuil nommé Pipo. Pipo adorait grimper aux arbres."],
                    ["Il était une fois un petit garçon nommé Tom. Tom aimait beaucoup dessiner."],
                ],
                inputs=msg_input,
                label="Example Prompts"
            )
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                height=500
            )
    # NOTE(review): update_model returns [] but outputs=[] discards it, so the
    # chat history is NOT cleared when switching models — confirm intended.
    model_selector.change(fn=update_model, inputs=[model_selector], outputs=[])
    # Submit chain: first record the user turn (fast, unqueued so the textbox
    # clears immediately), then stream the bot reply into the transcript.
    msg_input.submit(fn=user, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot, max_tokens, temperature, top_p], outputs=[chatbot]
    )
    # Returning None resets the Chatbot component to empty.
    clear_button.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)
if __name__ == "__main__":
    # Queue caps concurrent generations (10 running, 10 waiting) so the small
    # CPU Space is not overwhelmed. NOTE(review): ssr_mode=False presumably
    # disables Gradio's server-side rendering — confirm against the installed
    # Gradio version's launch() documentation.
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)