File size: 8,246 Bytes
e0be88b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
# Copyright 2023 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from queue import Empty
from threading import Thread
from unittest.mock import patch
import pytest
from transformers import (
AsyncTextIteratorStreamer,
AutoTokenizer,
TextIteratorStreamer,
TextStreamer,
is_torch_available,
)
from transformers.testing_utils import CaptureStdout, require_torch, torch_device
from transformers.utils.logging import _get_library_root_logger
from ..test_modeling_common import ids_tensor
if is_torch_available():
import torch
from transformers import AutoModelForCausalLM
@require_torch
class StreamerTester(unittest.TestCase):
def test_text_streamer_matches_non_streaming(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
model.config.eos_token_id = -1
input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
greedy_text = tokenizer.decode(greedy_ids[0])
with CaptureStdout() as cs:
streamer = TextStreamer(tokenizer)
model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)
# The greedy text should be printed to stdout, except for the final "\n" in the streamer
streamer_text = cs.out[:-1]
self.assertEqual(streamer_text, greedy_text)
def test_iterator_streamer_matches_non_streaming(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
model.config.eos_token_id = -1
input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
greedy_text = tokenizer.decode(greedy_ids[0])
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
streamer_text = ""
for new_text in streamer:
streamer_text += new_text
self.assertEqual(streamer_text, greedy_text)
def test_text_streamer_skip_prompt(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
model.config.eos_token_id = -1
input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
new_greedy_ids = greedy_ids[:, input_ids.shape[1] :]
new_greedy_text = tokenizer.decode(new_greedy_ids[0])
with CaptureStdout() as cs:
streamer = TextStreamer(tokenizer, skip_prompt=True)
model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)
# The greedy text should be printed to stdout, except for the final "\n" in the streamer
streamer_text = cs.out[:-1]
self.assertEqual(streamer_text, new_greedy_text)
def test_text_streamer_decode_kwargs(self):
# Tests that we can pass `decode_kwargs` to the streamer to control how the tokens are decoded. Must be tested
# with actual models -- the dummy models' tokenizers are not aligned with their models, and
# `skip_special_tokens=True` has no effect on them
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2").to(torch_device)
model.config.eos_token_id = -1
input_ids = torch.ones((1, 5), device=torch_device).long() * model.config.bos_token_id
root = _get_library_root_logger()
with patch.object(root, "propagate", False):
with CaptureStdout() as cs:
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)
# The prompt contains a special token, so the streamer should not print it. As such, the output text, when
# re-tokenized, must only contain one token
streamer_text = cs.out[:-1] # Remove the final "\n"
streamer_text_tokenized = tokenizer(streamer_text, return_tensors="pt")
self.assertEqual(streamer_text_tokenized.input_ids.shape, (1, 1))
def test_iterator_streamer_timeout(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
model.config.eos_token_id = -1
input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
streamer = TextIteratorStreamer(tokenizer, timeout=0.001)
generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# The streamer will timeout after 0.001 seconds, so an exception will be raised
with self.assertRaises(Empty):
streamer_text = ""
for new_text in streamer:
streamer_text += new_text
@require_torch
@pytest.mark.asyncio(loop_scope="class")
class AsyncStreamerTester(unittest.IsolatedAsyncioTestCase):
async def test_async_iterator_streamer_matches_non_streaming(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
model.config.eos_token_id = -1
input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
greedy_text = tokenizer.decode(greedy_ids[0])
streamer = AsyncTextIteratorStreamer(tokenizer)
generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
streamer_text = ""
async for new_text in streamer:
streamer_text += new_text
self.assertEqual(streamer_text, greedy_text)
async def test_async_iterator_streamer_timeout(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
model.config.eos_token_id = -1
input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
streamer = AsyncTextIteratorStreamer(tokenizer, timeout=0.001)
generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# The streamer will timeout after 0.001 seconds, so TimeoutError will be raised
with self.assertRaises(TimeoutError):
streamer_text = ""
async for new_text in streamer:
streamer_text += new_text
|