fCola committed on
Commit
ccd20ce
·
verified ·
1 Parent(s): 1677d8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -5,7 +5,7 @@ import gradio as gr
5
  from gradio.themes import Base
6
  from gradio.themes.utils import colors
7
 
8
- from transformers import pipeline, TextIteratorStreamer
9
 
10
 
11
  # Custom theme colors based on brand standards
@@ -433,7 +433,9 @@ paper_plane_svg = """<svg xmlns="http://www.w3.org/2000/svg" width="20" height="
433
 
434
 
435
  # Pipeline loading
436
- generator = pipeline("text-generation", model="openai-community/gpt2")
 
 
437
 
438
  # Mock data function for chatbot
439
  def send_message(message, history):
@@ -442,7 +444,6 @@ def send_message(message, history):
442
  #history.append({"role": "user", "content": message})
443
  #history.append({"role": "assistant", "content": f"This is a response about: {message}"})
444
  #return history
445
- tokenizer = generator.tokenizer
446
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
447
  input_ids = tokenizer.encode(message, return_tensors="pt")
448
  gen_kwargs = {
@@ -455,7 +456,7 @@ def send_message(message, history):
455
  "repetition_penalty": 1.25,
456
  }
457
  partial = ""
458
- thread = Thread(target=generator, kwargs=gen_kwargs)
459
  thread.start()
460
  #for token in generator(message, max_new_tokens=200):
461
  for t in streamer:
 
5
  from gradio.themes import Base
6
  from gradio.themes.utils import colors
7
 
8
+ from transformers import pipeline, TextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer
9
 
10
 
11
  # Custom theme colors based on brand standards
 
433
 
434
 
435
  # Pipeline loading
436
+ #generator = pipeline("text-generation", model="openai-community/gpt2")
437
+ tokenizer = AutoTokenizer("openai-community/gpt2")
438
+ model = AutoModelForCausalLM("openai-community/gpt2")
439
 
440
  # Mock data function for chatbot
441
  def send_message(message, history):
 
444
  #history.append({"role": "user", "content": message})
445
  #history.append({"role": "assistant", "content": f"This is a response about: {message}"})
446
  #return history
 
447
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
448
  input_ids = tokenizer.encode(message, return_tensors="pt")
449
  gen_kwargs = {
 
456
  "repetition_penalty": 1.25,
457
  }
458
  partial = ""
459
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
460
  thread.start()
461
  #for token in generator(message, max_new_tokens=200):
462
  for t in streamer: