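"""Gradio demo that streams chat replies from IlyaGusev/saiga_llama3_8b.

Assumed dependencies (normally pinned in the Space's requirements.txt):
gradio, torch, transformers.
"""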
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
tokenizer = AutoTokenizer.from_pretrained("IlyaGusev/saiga_llama3_8b")
model = AutoModelForCausalLM.from_pretrained("IlyaGusev/saiga_llama3_8b", torch_dtype=torch.bfloat16)
# Move the model to the GPU when one is available; bfloat16 inference on CPU is very slow.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
def predict(message, history):
    # `history` is a list of {"role": ..., "content": ...} dicts because the
    # ChatInterface below is created with type="messages".
    history_transformer_format = history + [{"role": "user", "content": message}]
    # add_generation_prompt=True makes the template open a fresh assistant turn;
    # appending an empty assistant message instead would be closed with an
    # end-of-turn token and generation would stop immediately.
    model_inputs = tokenizer.apply_chat_template(
        history_transformer_format, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        num_beams=1,
    )
    # Run generation in a background thread so tokens can be streamed as they arrive.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # Accumulate streamed tokens and yield the growing reply so the UI updates live;
    # skip_special_tokens=True already filters control tokens out of the stream.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message
# type="messages" matches the role/content dicts predict() expects;
# share=True also prints a public gradio.live URL alongside the local one.
gr.ChatInterface(predict, type="messages").launch(share=True)
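# Quick local smoke test (hypothetical; bypasses the web UI):
#   for chunk in predict("Hello!", []):
#       print(chunk)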