import os
import threading
from collections import defaultdict
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
# Define model paths
model_name_to_path = {
"LeCarnet-3M": "MaxLSB/LeCarnet-3M",
"LeCarnet-8M": "MaxLSB/LeCarnet-8M",
"LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}
# Load Hugging Face token
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
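# Note: this assumes the HUGGINGFACEHUB_API_TOKEN secret is configured for the Space;
# os.environ[...] raises a KeyError at startup if it is missing.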
# Preload models and tokenizers
loaded_models = defaultdict(dict)
for name, path in model_name_to_path.items():
loaded_models[name]["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
loaded_models[name]["model"].eval()
def respond(message, history, model_name, max_tokens, temperature, top_p):
"""
Generate a response from the selected model, streaming the output and updating chat history.
Args:
message (str): User's input message.
history (list): Current chat history as list of (user_msg, bot_msg) tuples.
model_name (str): Selected model name.
max_tokens (int): Maximum number of tokens to generate.
temperature (float): Sampling temperature.
top_p (float): Top-p sampling parameter.
Yields:
list: Updated chat history with the user's message and streaming bot response.
"""
# Append user's message to history with an empty bot response
history = history + [(message, "")]
yield history # Display user's message immediately
# Select tokenizer and model
tokenizer = loaded_models[model_name]["tokenizer"]
model = loaded_models[model_name]["model"]
# Tokenize input
inputs = tokenizer(message, return_tensors="pt")
# Set up streaming
streamer = TextIteratorStreamer(
tokenizer,
skip_prompt=False,
skip_special_tokens=True,
)
# Configure generation parameters
generate_kwargs = dict(
**inputs,
streamer=streamer,
max_new_tokens=max_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
eos_token_id=tokenizer.eos_token_id,
)
# Start generation in a background thread
thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
# Stream the response with model name prefix
accumulated = f"**{model_name}:** "
for new_text in streamer:
accumulated += new_text
history[-1] = (message, accumulated)
yield history
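    # Optional hygiene: once the streamer is exhausted, generation has finished,
    # so joining the background thread should return immediately.
    thread.join()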
def submit(message, history, model_name, max_tokens, temperature, top_p):
"""
Handle form submission by calling respond and clearing the input box.
Args:
message (str): User's input message.
history (list): Current chat history.
model_name (str): Selected model name.
max_tokens (int): Max tokens parameter.
temperature (float): Temperature parameter.
top_p (float): Top-p parameter.
Yields:
tuple: (updated chat history, cleared user input)
"""
for updated_history in respond(message, history, model_name, max_tokens, temperature, top_p):
yield updated_history, ""
# Create the Gradio interface with Blocks
with gr.Blocks(css=".gr-button {margin: 5px; width: 100%;} .gr-column {padding: 10px;}") as demo:
    # Title and description
    gr.Markdown("# LeCarnet")
    gr.Markdown("Select a model on the right and type a message to chat.")

    # Two-column layout with specific widths
    with gr.Row():
        # Left column: Chat interface (80% width)
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                avatar_images=(None, "media/le-carnet.png"),  # User avatar: None, Bot avatar: Logo
                label="Chat",
                height=600,  # Increase chat height for larger display
            )
            user_input = gr.Textbox(placeholder="Type your message here...", label="Message")
            submit_btn = gr.Button("Send")
            # Example prompts
            examples = gr.Examples(
                examples=[
                    ["Il était une fois un petit garçon qui vivait dans un village paisible."],
                    ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
                    ["Il était une fois un petit lapin perdu"],
                ],
                inputs=user_input,
            )

        # Right column: Model selection and parameters (20% width)
        with gr.Column(scale=1, min_width=200):
            # Dropdown for model selection
            model_dropdown = gr.Dropdown(
                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                value="LeCarnet-8M",
                label="Select Model"
            )
            # Sliders for parameters
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # Event handling for submit button
    submit_btn.click(
        fn=submit,
        inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
    )
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)