Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -5,7 +5,7 @@ import gradio as gr
|
|
5 |
from gradio.themes import Base
|
6 |
from gradio.themes.utils import colors
|
7 |
|
8 |
-
from transformers import pipeline, TextIteratorStreamer
|
9 |
|
10 |
|
11 |
# Custom theme colors based on brand standards
|
@@ -433,7 +433,9 @@ paper_plane_svg = """<svg xmlns="http://www.w3.org/2000/svg" width="20" height="
|
|
433 |
|
434 |
|
435 |
# Pipeline loading
|
436 |
-
generator = pipeline("text-generation", model="openai-community/gpt2")
|
|
|
|
|
437 |
|
438 |
# Mock data function for chatbot
|
439 |
def send_message(message, history):
|
@@ -442,7 +444,6 @@ def send_message(message, history):
|
|
442 |
#history.append({"role": "user", "content": message})
|
443 |
#history.append({"role": "assistant", "content": f"This is a response about: {message}"})
|
444 |
#return history
|
445 |
-
tokenizer = generator.tokenizer
|
446 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
|
447 |
input_ids = tokenizer.encode(message, return_tensors="pt")
|
448 |
gen_kwargs = {
|
@@ -455,7 +456,7 @@ def send_message(message, history):
|
|
455 |
"repetition_penalty": 1.25,
|
456 |
}
|
457 |
partial = ""
|
458 |
-
thread = Thread(target=
|
459 |
thread.start()
|
460 |
#for token in generator(message, max_new_tokens=200):
|
461 |
for t in streamer:
|
|
|
5 |
from gradio.themes import Base
|
6 |
from gradio.themes.utils import colors
|
7 |
|
8 |
+
from transformers import pipeline, TextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer
|
9 |
|
10 |
|
11 |
# Custom theme colors based on brand standards
|
|
|
433 |
|
434 |
|
435 |
# Pipeline loading
|
436 |
+
#generator = pipeline("text-generation", model="openai-community/gpt2")
|
437 |
+
# Auto classes cannot be constructed directly; load the GPT-2 tokenizer via from_pretrained.
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
438 |
+
# Auto classes cannot be constructed directly; load the GPT-2 weights via from_pretrained.
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
439 |
|
440 |
# Mock data function for chatbot
|
441 |
def send_message(message, history):
|
|
|
444 |
#history.append({"role": "user", "content": message})
|
445 |
#history.append({"role": "assistant", "content": f"This is a response about: {message}"})
|
446 |
#return history
|
|
|
447 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
|
448 |
input_ids = tokenizer.encode(message, return_tensors="pt")
|
449 |
gen_kwargs = {
|
|
|
456 |
"repetition_penalty": 1.25,
|
457 |
}
|
458 |
partial = ""
|
459 |
+
thread = Thread(target=model.generate, kwargs=gen_kwargs)
|
460 |
thread.start()
|
461 |
#for token in generator(message, max_new_tokens=200):
|
462 |
for t in streamer:
|