|
import gc |
|
import unittest |
|
import weakref |
|
from unittest.mock import MagicMock |
|
|
|
import torch |
|
|
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline |
|
from transformers.generation.candidate_generator import ( |
|
AssistantToTargetTranslator, |
|
AssistantVocabTranslatorCache, |
|
UniversalSpeculativeDecodingGenerator, |
|
) |
|
from transformers.testing_utils import require_torch, torch_device |
|
|
|
|
|
@require_torch
class TestAssistantToTargetTranslator(unittest.TestCase):
    """Unit tests for ``AssistantToTargetTranslator`` built on small mocked vocabularies.

    ``hello``/``world``/``foo`` have ids 0-2 in both vocabularies; ``bar`` (id 3) exists only in the
    target vocabulary and ``baz`` (id 4) only in the assistant vocabulary.
    """

    def setUp(self):
        self.target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3}
        self.assistant_vocab = {"hello": 0, "world": 1, "foo": 2, "baz": 4}

        # The tokenizer mocks only have ``get_vocab`` configured; the model mock only has ``device``.
        self.target_tokenizer = MagicMock()
        self.target_tokenizer.get_vocab.return_value = self.target_vocab
        self.assistant_tokenizer = MagicMock()
        self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab
        self.assistant_model = MagicMock(device=torch_device)

        # Larger than the highest target id (3), so some target slots have no assistant counterpart.
        self.target_vocab_size = 6

        self.translator = AssistantToTargetTranslator(
            target_tokenizer=self.target_tokenizer,
            assistant_tokenizer=self.assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )

    def _on_device(self, rows, dtype):
        """Build a ``dtype`` tensor from nested lists, placed on the mocked assistant device."""
        return torch.tensor(rows, dtype=dtype, device=self.assistant_model.device)

    def test_get_assistant_to_target_input_ids(self):
        """Assistant ids 0-2 map to themselves; ids 3 and 4 map to SUPPRESS_TOKEN_ID."""
        suppressed = self.translator.SUPPRESS_TOKEN_ID
        self.assertEqual(
            self.translator._assistant_to_target_input_ids.tolist(),
            [0, 1, 2, suppressed, suppressed],
        )

    def test_get_suppress_input_ids(self):
        """The assistant ids lacking a target counterpart (3 and 4) are reported for suppression."""
        self.assertEqual(self.translator._get_suppress_input_ids().tolist(), [3, 4])

    def test_get_target_ids(self):
        """Shared candidate ids pass through unchanged; ``baz`` (4) becomes SUPPRESS_TOKEN_ID."""
        assistant_prompt = self._on_device([[0, 1, 2]], torch.long)
        target_prompt = self._on_device([[0, 1, 2]], torch.long)
        candidates = self._on_device([[0, 1, 2, 4]], torch.long)
        wanted = self._on_device([[0, 1, 2, self.translator.SUPPRESS_TOKEN_ID]], torch.long)

        translated = self.translator.get_target_ids(assistant_prompt, target_prompt, candidates)
        self.assertTrue(torch.equal(translated, wanted))

    def test_get_target_logits(self):
        """Logits of assistant ids 0-2 land on target ids 0-2; every other target slot is FILTER_VALUE."""
        filter_value = self.translator.FILTER_VALUE
        assistant_logits = self._on_device([[[0.1, 0.2, 0.3, 0.4, filter_value]]], torch.float32)

        wanted = torch.full((1, 1, self.target_vocab_size), filter_value).to(self.assistant_model.device)
        for target_id, logit in enumerate((0.1, 0.2, 0.3)):
            wanted[0, 0, target_id] = logit

        self.assertTrue(torch.equal(self.translator.get_target_logits(assistant_logits), wanted))
|
|
|
|
|
class MockTokenizer:
    """Lightweight tokenizer double; as a plain class, its instances can be weakly referenced."""

    def __init__(self, vocab=None):
        # Any falsy ``vocab`` (including ``None``) is replaced by a fresh empty mapping.
        self._vocab = vocab if vocab else {}

    def get_vocab(self):
        """Return the stored token -> id mapping itself (not a copy)."""
        return self._vocab

    def __call__(self, text, add_special_tokens=True):
        """Whitespace-split ``text`` and look up each piece, using id 0 for unknown pieces.

        ``add_special_tokens`` is accepted but ignored.
        """
        lookup = self._vocab.get
        ids = [lookup(piece, 0) for piece in text.split()]
        return {"input_ids": ids}
|
|
|
|
|
@require_torch
class TestAssistantVocabTranslatorCache(unittest.TestCase):
    """Tests for ``AssistantVocabTranslatorCache``.

    ``_cache`` is accessed on the class itself, so it is shared by every test in the process. Each
    test starts AND ends with an empty cache so translators created here do not leak into other tests.
    """

    def setUp(self):
        # Start clean, and drop whatever this test inserts afterwards (even if the test fails).
        AssistantVocabTranslatorCache._cache.clear()
        self.addCleanup(AssistantVocabTranslatorCache._cache.clear)

        self.target_tokenizer = MockTokenizer({"hello": 0, "world": 1})
        self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2})
        self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3})
        self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5})
        self.assistant_model = MagicMock(device=torch_device)

        self.target_vocab_size = 6

    def _get_translator(self, target_tokenizer, assistant_tokenizer):
        """Fetch a translator for the given tokenizer pair using this fixture's model and vocab size."""
        return AssistantVocabTranslatorCache.get_translator(
            target_tokenizer,
            assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )

    def test_same_instance_for_same_tokenizers(self):
        """Test that the same translator is returned for the same tokenizers."""
        translator1 = self._get_translator(self.target_tokenizer, self.assistant_tokenizer)
        translator2 = self._get_translator(self.target_tokenizer, self.assistant_tokenizer)
        self.assertIs(translator1, translator2, "Translators should be cached and identical")

    def test_different_instances_for_different_tokenizers(self):
        """Test that different tokenizers produce different translators."""
        translator1 = self._get_translator(self.target_tokenizer, self.assistant_tokenizer)
        translator2 = self._get_translator(self.other_target_tokenizer, self.other_assistant_tokenizer)
        self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers")

    def test_cache_with_weakref_key(self):
        """A new tokenizer pair adds one cache entry, and that entry survives ``cleanup()``.

        The test drops every local reference to the tokenizers and the translator, then runs
        ``gc.collect()`` and ``cleanup()``. It asserts that the entry is NOT evicted. This pins current
        behavior: presumably something reachable from the cache still holds strong references to the
        tokenizers (TODO confirm against ``AssistantVocabTranslatorCache``).
        """
        initial_cache_size = len(AssistantVocabTranslatorCache._cache)
        target_tokenizer = MockTokenizer({"hello": 0})
        assistant_tokenizer = MockTokenizer({"hello": 0})

        translator = self._get_translator(target_tokenizer, assistant_tokenizer)
        self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)

        # Drop all of this frame's strong references before collecting.
        del target_tokenizer
        del assistant_tokenizer
        del translator
        gc.collect()

        AssistantVocabTranslatorCache.cleanup()

        # The entry remains: the objects are still strongly referenced (see docstring).
        self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)

    def test_weakref_cache_cleanup(self):
        """Check which objects survive once the caller's references are gone and ``cleanup()`` runs.

        Objects are created inside a nested function so that no frame of this test keeps them alive;
        only weak references are returned. The assertions pin current behavior: the translator and
        both tokenizers remain alive, which the assertion messages attribute to strong references
        (presumably held via the cache entry — TODO confirm).
        """

        def create_translator():
            target_tokenizer = MockTokenizer({"hello": 0})
            assistant_tokenizer = MockTokenizer({"hello": 0})
            translator = self._get_translator(target_tokenizer, assistant_tokenizer)

            refs = (weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer))

            del target_tokenizer
            del assistant_tokenizer
            del translator
            return refs

        translator_ref, target_ref, assistant_ref = create_translator()

        gc.collect()
        AssistantVocabTranslatorCache.cleanup()

        self.assertIsNotNone(target_ref(), "Target tokenizer should still be alive due to strong references")
        self.assertIsNotNone(assistant_ref(), "Assistant tokenizer should still be alive due to strong references")
        self.assertIsNotNone(translator_ref(), "Translator should still be alive due to strong references")
|
|
|
|
|
@require_torch
class TestUniversalSpeculativeDecoding(unittest.TestCase):
    """Integration tests for ``UniversalSpeculativeDecodingGenerator`` using tiny Hub checkpoints.

    The target (Llama) and assistant (Phi) checkpoints each bring their own tokenizer. Candidates
    are therefore produced through an ``AssistantToTargetTranslator`` obtained from
    ``AssistantVocabTranslatorCache``.
    """

    @classmethod
    def setUpClass(cls):
        cls.target_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
        cls.assistant_name = "hf-internal-testing/tiny-random-PhiForCausalLM"

    def setUp(self):
        self.target_tokenizer = AutoTokenizer.from_pretrained(self.target_name)
        self.target_config = AutoConfig.from_pretrained(self.target_name)
        self.assistant_model = AutoModelForCausalLM.from_pretrained(self.assistant_name).to(torch_device)
        self.assistant_tokenizer = AutoTokenizer.from_pretrained(self.assistant_name)

        self.generation_config = GenerationConfig()

        # Fall back to EOS for any missing pad/bos id so both tokenizers define them.
        if self.target_tokenizer.pad_token_id is None:
            self.target_tokenizer.pad_token_id = self.target_tokenizer.eos_token_id
        if self.target_tokenizer.bos_token_id is None:
            self.target_tokenizer.bos_token_id = self.target_tokenizer.eos_token_id
        if self.assistant_tokenizer.pad_token_id is None:
            self.assistant_tokenizer.pad_token_id = self.assistant_tokenizer.eos_token_id
        if self.assistant_tokenizer.bos_token_id is None:
            self.assistant_tokenizer.bos_token_id = self.assistant_tokenizer.eos_token_id

        self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
        self.model_kwargs = {
            "attention_mask": torch.ones_like(self.input_ids).to(torch_device),
        }
        atm_translator = AssistantVocabTranslatorCache.get_translator(
            target_tokenizer=self.target_tokenizer,
            assistant_tokenizer=self.assistant_tokenizer,
            assistant_model=self.assistant_model,
            target_vocab_size=self.target_config.vocab_size,
        )
        self.generator = UniversalSpeculativeDecodingGenerator(
            input_ids=self.input_ids,
            assistant_model=self.assistant_model,
            target_tokenizer=self.target_tokenizer,
            assistant_tokenizer=self.assistant_tokenizer,
            generation_config=self.generation_config,
            model_kwargs=self.model_kwargs,
            atm_translator=atm_translator,
        )

    def test_basic_generation(self):
        """Test basic speculative decoding works"""
        input_text = "The quick brown fox"
        # Place the prompt on torch_device like setUp does, so the test does not rely on
        # get_candidates moving CPU inputs to the assistant model's device itself.
        input_ids = self.target_tokenizer.encode(input_text, return_tensors="pt").to(torch_device)
        self.generator.input_ids = input_ids
        candidates, scores = self.generator.get_candidates(input_ids)

        self.assertIsNotNone(candidates)
        self.assertIsNotNone(scores)
        self.assertTrue(torch.is_tensor(candidates))
        self.assertTrue(torch.is_tensor(scores))

    def test_mismatched_vocabularies(self):
        """Test handling of mismatched vocabularies between models"""
        # Hoist the lookups out of the scan: otherwise both would be recomputed for every target token.
        assistant_vocab = self.assistant_tokenizer.get_vocab()
        target_special_tokens = set(self.target_tokenizer.all_special_tokens)

        # Pick a regular target token (not special, not a "reserved_" placeholder) that the
        # assistant vocabulary lacks.
        missing_token = next(
            (
                token
                for token in self.target_tokenizer.get_vocab()
                if token not in assistant_vocab
                and token not in target_special_tokens
                and "reserved_" not in token
            ),
            None,
        )
        # Without a default, an empty search used to surface as an opaque StopIteration error.
        self.assertIsNotNone(
            missing_token, "Expected a regular target token that is absent from the assistant vocabulary"
        )

        input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]]).to(torch_device)
        self.generator.input_ids = input_ids
        candidates, _ = self.generator.get_candidates(input_ids)
        self.assertIsNotNone(candidates)

    def test_speculation_depth(self):
        """Test different speculation depths"""
        input_ids = self.target_tokenizer.encode("Test text", return_tensors="pt").to(torch_device)
        self.generator.input_ids = input_ids

        for depth in [1, 8, 17]:
            # subTest reports which depth failed instead of stopping at the first failure.
            with self.subTest(depth=depth):
                self.generator.num_assistant_tokens = depth
                candidates, _ = self.generator.get_candidates(input_ids)
                # The number of newly proposed tokens must not exceed the requested depth.
                self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth)

    def test_device_consistency(self):
        """Test handling of inputs on different devices"""
        input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
        self.generator.input_ids = input_ids
        candidates, _ = self.generator.get_candidates(input_ids)
        self.assertEqual(candidates.device, input_ids.device)

    def test_usd_vs_vanilla_sampling(self):
        """Test that USD matches vanilla sampling with temperature set to nearly 0"""
        prompt = "Test text"

        pipe_vanilla = pipeline(
            "text-generation",
            model=self.target_name,
        )
        pipe_vanilla_output = pipe_vanilla(prompt, max_new_tokens=5, do_sample=False)
        vanilla_text = pipe_vanilla_output[0]["generated_text"]

        pipe_usd = pipeline(
            "text-generation",
            model=self.target_name,
            assistant_model=self.assistant_name,
        )
        # Sampling at a near-zero temperature is expected to behave like greedy decoding.
        pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9)
        usd_text = pipe_usd_output[0]["generated_text"]

        self.assertEqual(usd_text, vanilla_text)
|
|