import os
import threading
from collections import defaultdict
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
# Define model paths
model_name_to_path = {
"LeCarnet-3M": "MaxLSB/LeCarnet-3M",
"LeCarnet-8M": "MaxLSB/LeCarnet-8M",
"LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}
# Load the Hugging Face token if one is set; None falls back to anonymous
# access, which is enough for these public checkpoints. (Passing a dummy
# string like "default_token" would trigger auth errors instead of avoiding them.)
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
# Preload models and tokenizers
loaded_models = defaultdict(dict)
for name, path in model_name_to_path.items():
    try:
        loaded_models[name]["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
        loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
        loaded_models[name]["model"].eval()
    except Exception as e:
        print(f"Error loading {name}: {e}")

def respond(message, history, model_name, max_tokens, temperature, top_p):
    """Stream a completion for `message`, yielding the growing chat history."""
    # Echo the user's message immediately with an empty assistant slot.
    history = history + [(message, "")]
    yield history
    tokenizer = loaded_models[model_name]["tokenizer"]
    model = loaded_models[model_name]["model"]
    inputs = tokenizer(message, return_tensors="pt")
    # skip_prompt=False keeps the prompt in the stream, so the reply reads as a
    # continuation of the user's opening sentence.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Run generation on a background thread so tokens can be streamed as they arrive.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    accumulated = ""
    for new_text in streamer:
        accumulated += new_text
        history[-1] = (message, accumulated)
        yield history

def submit(message, history, model_name, max_tokens, temperature, top_p):
    # Relay each streamed history update and clear the input box.
    for updated_history in respond(message, history, model_name, max_tokens, temperature, top_p):
        yield updated_history, ""
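
# Hypothetical smoke test, left commented out: drive the generator directly,
# without the UI. Assumes the "LeCarnet-8M" checkpoint loaded successfully.
#
# last = None
# for last in respond("Il était une fois", [], "LeCarnet-8M", 64, 0.7, 0.9):
#     pass
# print(last[-1][1])  # the fully streamed reply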

with gr.Blocks(css=".gr-button {margin: 5px; width: 100%;} .gr-column {padding: 10px;}") as demo:
    gr.Markdown("# LeCarnet")
    gr.Markdown("Select a model on the right and type a message to chat.")
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                avatar_images=(None, "https://raw.githubusercontent.com/maxlsb/le-carnet/main/media/le-carnet.png"),  # hosted URL for reliability
                label="Chat",
                height=600,
            )
            user_input = gr.Textbox(placeholder="Type your message here...", label="Message")
            submit_btn = gr.Button("Send")
            examples = gr.Examples(
                examples=[
                    ["Il était une fois un petit garçon qui vivait dans un village paisible."],
                    ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
                    ["Il était une fois un petit lapin perdu"],
                ],
                inputs=user_input,
            )
        with gr.Column(scale=1, min_width=200):
            model_dropdown = gr.Dropdown(
                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                value="LeCarnet-8M",
                label="Select Model",
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # Submit button click
    submit_btn.click(
        fn=submit,
        inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
    )
    # Enter key press
    user_input.submit(
        fn=submit,
        inputs=[user_input, chatbot, model_dropdown, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
    )

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)