init
This commit is contained in:
0
transformers/tests/generation/__init__.py
Normal file
0
transformers/tests/generation/__init__.py
Normal file
114
transformers/tests/generation/test_beam_constraints.py
Normal file
114
transformers/tests/generation/test_beam_constraints.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# Copyright 2020 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.generation import DisjunctiveConstraint
|
||||
|
||||
|
||||
@require_torch
class ConstraintTest(unittest.TestCase):
    """Unit tests for `DisjunctiveConstraint` input validation and step-by-step progression.

    Uses the specific assertion methods (`assertEqual`, `assertIs`, `assertFalse`, ...)
    instead of `assertTrue(expr)` so that a failing test reports the mismatching values.
    """

    def test_input_types(self):
        # For consistency across different places the DisjunctiveConstraint is called,
        # dc.token_ids is a list of integers. It is also initialized only by integers.
        cset = [[1, 2, 4], [1, 2, 3, 4]]
        dc = DisjunctiveConstraint(cset)
        self.assertIsInstance(dc.token_ids, list)

        # Tensors (or lists of tensors) are rejected -- only nested lists of python ints.
        with self.assertRaises(ValueError):
            DisjunctiveConstraint(torch.LongTensor([[1, 2, 4], [1, 2, 3]]))

        with self.assertRaises(ValueError):
            DisjunctiveConstraint([torch.LongTensor([1, 2, 4]), torch.LongTensor([1, 2, 3, 4, 5])])

    def test_check_illegal_input(self):
        # We can't have constraints that are complete subsets of another. This leads to a perverse
        # interpretation of "constraint fulfillment": does generating [1,2,3] fulfill the constraint?
        # It would mean that it generated [1,2] which fulfills it, but it's in the middle of potentially
        # fulfilling [1,2,3,4]. If we believe that [1,2,3] does fulfill the constraint, then the algorithm
        # will necessarily never reach [1,2,3,4], giving users a false sense of control (better to just not allow it).
        cset = [[1, 2], [1, 2, 3, 4]]

        with self.assertRaises(ValueError):
            DisjunctiveConstraint(cset)  # fails here

    def test_example_progression(self):
        """Following one branch token-by-token completes the constraint exactly at its last token."""
        cset = [[1, 2, 3], [1, 2, 4]]

        dc = DisjunctiveConstraint(cset)

        stepped, completed, reset = dc.update(1)
        # Advanced by one token; nothing completed, no reset triggered.
        self.assertIs(stepped, True)
        self.assertIs(completed, False)
        self.assertIs(reset, False)
        self.assertFalse(dc.completed)
        self.assertEqual(dc.current_seq, [1])

        stepped, completed, reset = dc.update(2)
        self.assertIs(stepped, True)
        self.assertIs(completed, False)
        self.assertIs(reset, False)
        self.assertFalse(dc.completed)
        self.assertEqual(dc.current_seq, [1, 2])

        stepped, completed, reset = dc.update(3)
        self.assertIs(stepped, True)
        self.assertIs(completed, True)
        self.assertIs(reset, False)
        self.assertTrue(dc.completed)  # Completed!
        self.assertEqual(dc.current_seq, [1, 2, 3])

    def test_example_progression_unequal_three_mid_and_reset(self):
        """Branches of unequal length can be followed, and `reset()` restarts the progression."""
        cset = [[1, 2, 3], [1, 2, 4, 5], [1, 2, 5]]

        dc = DisjunctiveConstraint(cset)

        stepped, completed, reset = dc.update(1)
        self.assertFalse(dc.completed)
        self.assertEqual(dc.current_seq, [1])

        stepped, completed, reset = dc.update(2)
        self.assertFalse(dc.completed)
        self.assertEqual(dc.current_seq, [1, 2])

        # Branch [1, 2, 4, 5] is the only one consistent with token 4.
        stepped, completed, reset = dc.update(4)
        self.assertFalse(dc.completed)
        self.assertEqual(dc.current_seq, [1, 2, 4])

        stepped, completed, reset = dc.update(5)
        self.assertTrue(dc.completed)  # Completed!
        self.assertEqual(dc.current_seq, [1, 2, 4, 5])

        # After reset, the constraint starts over and `remaining()` counts down again.
        dc.reset()

        stepped, completed, reset = dc.update(1)
        self.assertFalse(dc.completed)
        self.assertEqual(dc.remaining(), 3)
        self.assertEqual(dc.current_seq, [1])

        stepped, completed, reset = dc.update(2)
        self.assertFalse(dc.completed)
        self.assertEqual(dc.remaining(), 2)
        self.assertEqual(dc.current_seq, [1, 2])

        # The shortest branch [1, 2, 5] finishes here.
        stepped, completed, reset = dc.update(5)
        self.assertTrue(dc.completed)  # Completed!
        self.assertEqual(dc.remaining(), 0)
        self.assertEqual(dc.current_seq, [1, 2, 5])
|
||||
339
transformers/tests/generation/test_candidate_generator.py
Normal file
339
transformers/tests/generation/test_candidate_generator.py
Normal file
@@ -0,0 +1,339 @@
|
||||
import gc
|
||||
import unittest
|
||||
import weakref
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline
|
||||
from transformers.generation.candidate_generator import (
|
||||
AssistantToTargetTranslator,
|
||||
AssistantVocabTranslatorCache,
|
||||
UniversalSpeculativeDecodingGenerator,
|
||||
)
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
|
||||
|
||||
@require_torch
class TestAssistantToTargetTranslator(unittest.TestCase):
    """Unit tests for `AssistantToTargetTranslator`, built entirely on mocked tokenizers/model."""

    def setUp(self):
        # Create mock tokenizers with predefined vocabularies
        self.target_tokenizer = MagicMock()
        self.assistant_tokenizer = MagicMock()
        self.assistant_model = MagicMock(device=torch_device)

        # Define mock vocabularies for the tokenizers.
        # Assistant id 3 is unused and "baz" (assistant id 4) has no target counterpart,
        # so both are expected to be suppressed by the translator (see tests below).
        self.target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3}
        self.assistant_vocab = {"hello": 0, "world": 1, "foo": 2, "baz": 4}

        # get_vocab() must be configured *before* the translator is constructed below,
        # since the translator reads the vocabularies at construction time.
        self.target_tokenizer.get_vocab.return_value = self.target_vocab
        self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab
        # Deliberately larger than len(target_vocab): the logits tensor is padded to this size.
        self.target_vocab_size = 6

        # Instantiate the class under test
        self.translator = AssistantToTargetTranslator(
            target_tokenizer=self.target_tokenizer,
            assistant_tokenizer=self.assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )

    def test_get_assistant_to_target_input_ids(self):
        """Test the mapping from assistant tokens to target tokens."""
        # Assistant ids 0-2 map one-to-one; ids 3 (unused) and 4 ("baz") are suppressed.
        expected_mapping = [0, 1, 2, self.translator.SUPPRESS_TOKEN_ID, self.translator.SUPPRESS_TOKEN_ID]
        actual_mapping = self.translator._assistant_to_target_input_ids.tolist()
        self.assertEqual(actual_mapping, expected_mapping)

    def test_get_suppress_input_ids(self):
        """Test the suppression of assistant input IDs not present in the target vocabulary."""
        expected_suppress_ids = [3, 4]
        actual_suppress_ids = self.translator._get_suppress_input_ids().tolist()
        self.assertEqual(actual_suppress_ids, expected_suppress_ids)

    def test_get_target_ids(self):
        """Test the translation of assistant candidate IDs to target candidate IDs."""
        assistant_input_ids = torch.LongTensor([[0, 1, 2]]).to(
            self.assistant_model.device
        )  # 'hello world foo' in assistant tokenizer
        target_input_ids = torch.LongTensor([[0, 1, 2]]).to(
            self.assistant_model.device
        )  # 'hello world foo' in target tokenizer
        assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]]).to(
            self.assistant_model.device
        )  # 'hello world foo baz' in assistant tokenizer

        expected_target_ids = torch.LongTensor(
            [[0, 1, 2, self.translator.SUPPRESS_TOKEN_ID]]
        ).to(
            self.assistant_model.device
        )  # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.SUPPRESS_TOKEN_ID since it does not exist in target vocab)

        actual_target_ids = self.translator.get_target_ids(
            assistant_input_ids, target_input_ids, assistant_candidate_ids
        )
        self.assertTrue(torch.equal(actual_target_ids, expected_target_ids))

    def test_get_target_logits(self):
        """Test the conversion of assistant logits to target logits."""
        # Assistant logits for IDs 0, 1, 2 (ids 3 and 4 carry throwaway values -- they are
        # suppressed and never copied into the target logits).
        assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.FILTER_VALUE]]]).to(
            self.assistant_model.device
        )  # Shape (1, 1, 5)

        # Expected target logits: everything starts at FILTER_VALUE, then the three shared
        # tokens are filled in from the assistant logits.
        expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.FILTER_VALUE).to(
            self.assistant_model.device
        )
        expected_target_logits[0, 0, 0] = 0.1  # 'hello'
        expected_target_logits[0, 0, 1] = 0.2  # 'world'
        expected_target_logits[0, 0, 2] = 0.3  # 'foo'
        # The 'bar' token in target vocab remains at the filter value (no assistant counterpart).

        actual_target_logits = self.translator.get_target_logits(assistant_logits)
        self.assertTrue(torch.equal(actual_target_logits, expected_target_logits))
|
||||
|
||||
|
||||
class MockTokenizer:
    """Minimal tokenizer stand-in; a plain class (unlike MagicMock) so it can be weak-referenced."""

    def __init__(self, vocab=None):
        # `vocab or {}` semantics preserved: any falsy argument (None, {}) yields a fresh dict.
        if vocab:
            self._vocab = vocab
        else:
            self._vocab = {}

    def get_vocab(self):
        """Return the token -> id mapping."""
        return self._vocab

    def __call__(self, text, add_special_tokens=True):
        """Whitespace-tokenize `text`; unknown tokens map to id 0."""
        ids = []
        for piece in text.split():
            ids.append(self._vocab.get(piece, 0))
        return {"input_ids": ids}
|
||||
|
||||
|
||||
@require_torch
class TestAssistantVocabTranslatorCache(unittest.TestCase):
    """Tests for `AssistantVocabTranslatorCache` identity-caching and weakref-based cleanup.

    NOTE(review): the lifetime assertions below are sensitive to exactly which strong
    references exist at gc time -- keep statement order intact when editing.
    """

    def setUp(self):
        # Clear the cache before each test
        AssistantVocabTranslatorCache._cache.clear()
        # Create mock tokenizers with different vocabularies
        self.target_tokenizer = MockTokenizer({"hello": 0, "world": 1})
        self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2})
        self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3})
        self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5})
        self.assistant_model = MagicMock(device=torch_device)

        self.target_vocab_size = 6

    def test_same_instance_for_same_tokenizers(self):
        """Test that the same translator is returned for the same tokenizers."""
        translator1 = AssistantVocabTranslatorCache.get_translator(
            self.target_tokenizer,
            self.assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )
        translator2 = AssistantVocabTranslatorCache.get_translator(
            self.target_tokenizer,
            self.assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )
        self.assertIs(translator1, translator2, "Translators should be cached and identical")

    def test_different_instances_for_different_tokenizers(self):
        """Test that different tokenizers produce different translators."""
        translator1 = AssistantVocabTranslatorCache.get_translator(
            self.target_tokenizer,
            self.assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )
        translator2 = AssistantVocabTranslatorCache.get_translator(
            self.other_target_tokenizer,
            self.other_assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )
        self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers")

    def test_cache_with_weakref_key(self):
        """Ensure that the cache uses weak references as keys."""
        initial_cache_size = len(AssistantVocabTranslatorCache._cache)
        target_tokenizer = MockTokenizer({"hello": 0})
        assistant_tokenizer = MockTokenizer({"hello": 0})

        # Store translator in a local variable to avoid it being kept alive
        translator = AssistantVocabTranslatorCache.get_translator(
            target_tokenizer,
            assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )
        self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)

        # Delete all strong references
        del target_tokenizer
        del assistant_tokenizer
        del translator

        # Force garbage collection
        gc.collect()

        # Call cleanup to remove dead entries
        AssistantVocabTranslatorCache.cleanup()

        # The cache size remains increased due to strong references
        # (presumably something inside the translator/cache keeps the tokenizers
        # alive -- see the companion test below; verify against implementation).
        self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)

    def test_weakref_cache_cleanup(self):
        """Test that the cache cleans up translators when tokenizers are garbage collected."""

        def create_translator():
            # Build everything inside a function so that, on return, the only
            # remaining references are the weakrefs we hand back.
            target_tokenizer = MockTokenizer({"hello": 0})
            assistant_tokenizer = MockTokenizer({"hello": 0})
            translator = AssistantVocabTranslatorCache.get_translator(
                target_tokenizer,
                assistant_tokenizer,
                target_vocab_size=self.target_vocab_size,
                assistant_model=self.assistant_model,
                assistant_prune_lm_head=False,
            )
            # Create weak references before returning
            refs = (weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer))
            # Remove strong references inside the function
            del target_tokenizer
            del assistant_tokenizer
            del translator
            return refs

        translator_ref, target_ref, assistant_ref = create_translator()

        # Force garbage collection
        gc.collect()

        # Call cleanup to remove dead entries
        AssistantVocabTranslatorCache.cleanup()

        # The tokenizers and translator are not garbage collected due to strong references
        self.assertIsNotNone(target_ref(), "Target tokenizer should still be alive due to strong references")
        self.assertIsNotNone(assistant_ref(), "Assistant tokenizer should still be alive due to strong references")
        self.assertIsNotNone(translator_ref(), "Translator should still be alive due to strong references")
|
||||
|
||||
|
||||
@require_torch
class TestUniversalSpeculativeDecoding(unittest.TestCase):
    """Integration tests for `UniversalSpeculativeDecodingGenerator` using two tiny real
    checkpoints with *different* tokenizers (Llama target, Phi assistant)."""

    @classmethod
    def setUpClass(cls):
        cls.target_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
        cls.assistant_name = "hf-internal-testing/tiny-random-PhiForCausalLM"

    def setUp(self):
        self.target_tokenizer = AutoTokenizer.from_pretrained(self.target_name)
        self.target_config = AutoConfig.from_pretrained(self.target_name)
        self.assistant_model = AutoModelForCausalLM.from_pretrained(self.assistant_name).to(torch_device)
        self.assistant_tokenizer = AutoTokenizer.from_pretrained(self.assistant_name)

        self.generation_config = GenerationConfig()

        # Ensure required tokens exist (tiny test checkpoints may lack pad/bos tokens)
        if self.target_tokenizer.pad_token_id is None:
            self.target_tokenizer.pad_token_id = self.target_tokenizer.eos_token_id
        if self.target_tokenizer.bos_token_id is None:
            self.target_tokenizer.bos_token_id = self.target_tokenizer.eos_token_id
        if self.assistant_tokenizer.pad_token_id is None:
            self.assistant_tokenizer.pad_token_id = self.assistant_tokenizer.eos_token_id
        if self.assistant_tokenizer.bos_token_id is None:
            self.assistant_tokenizer.bos_token_id = self.assistant_tokenizer.eos_token_id

        self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
        self.model_kwargs = {
            "attention_mask": torch.ones_like(self.input_ids).to(torch_device),
        }
        atm_translator = AssistantVocabTranslatorCache.get_translator(
            target_tokenizer=self.target_tokenizer,
            assistant_tokenizer=self.assistant_tokenizer,
            assistant_model=self.assistant_model,
            target_vocab_size=self.target_config.vocab_size,
        )
        self.generator = UniversalSpeculativeDecodingGenerator(
            input_ids=self.input_ids,
            assistant_model=self.assistant_model,
            target_tokenizer=self.target_tokenizer,
            assistant_tokenizer=self.assistant_tokenizer,
            generation_config=self.generation_config,
            model_kwargs=self.model_kwargs,
            atm_translator=atm_translator,
        )

    def test_basic_generation(self):
        """Test basic speculative decoding works"""
        input_text = "The quick brown fox"
        input_ids = self.target_tokenizer.encode(input_text, return_tensors="pt")
        self.generator.input_ids = input_ids
        candidates, scores = self.generator.get_candidates(input_ids)

        self.assertIsNotNone(candidates)
        self.assertIsNotNone(scores)
        self.assertTrue(torch.is_tensor(candidates))
        self.assertTrue(torch.is_tensor(scores))

    def test_mismatched_vocabularies(self):
        """Test handling of mismatched vocabularies between models"""
        # Create input with tokens present in main but not assistant vocab
        # Find a token that is not in the assistant tokenizer but in
        # the main tokenizer (skipping special and reserved placeholder tokens).
        missing_token = next(
            token
            for token in self.target_tokenizer.get_vocab()
            if token not in self.assistant_tokenizer.get_vocab()
            and token not in self.target_tokenizer.all_special_tokens
            and "reserved_" not in token
        )
        input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]])
        self.generator.input_ids = input_ids
        candidates, _ = self.generator.get_candidates(input_ids)
        self.assertIsNotNone(candidates)

    def test_speculation_depth(self):
        """Test different speculation depths"""
        input_ids = self.target_tokenizer.encode("Test text", return_tensors="pt")
        self.generator.input_ids = input_ids

        for depth in [1, 8, 17]:
            self.generator.num_assistant_tokens = depth
            candidates, _ = self.generator.get_candidates(input_ids)
            # At most `depth` new tokens may be proposed (fewer if e.g. EOS is hit).
            self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth)

    def test_device_consistency(self):
        """Test handling of inputs on different devices"""
        input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
        self.generator.input_ids = input_ids
        candidates, _ = self.generator.get_candidates(input_ids)
        self.assertEqual(candidates.device, input_ids.device)

    def test_usd_vs_vanilla_sampling(self):
        """Test that USD matches vanilla sampling with temperature set to nearly 0"""
        # NOTE: this is an ordinary instance method -- the first parameter was
        # previously misnamed `cls`, which only worked by accident.
        prompt = "Test text"

        pipe_vanilla = pipeline(
            "text-generation",
            model=self.target_name,
        )
        pipe_vanilla_output = pipe_vanilla(prompt, max_new_tokens=5, do_sample=False)
        vanilla_text = pipe_vanilla_output[0]["generated_text"]

        pipe_usd = pipeline(
            "text-generation",
            model=self.target_name,
            assistant_model=self.assistant_name,
        )
        pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9)  # Nearly 0 temperature
        usd_text = pipe_usd_output[0]["generated_text"]

        # Assert that the outputs match
        self.assertEqual(usd_text, vanilla_text)
|
||||
770
transformers/tests/generation/test_configuration_utils.py
Normal file
770
transformers/tests/generation/test_configuration_utils.py
Normal file
@@ -0,0 +1,770 @@
|
||||
# Copyright 2022 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
from huggingface_hub import create_pull_request
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.generation import (
|
||||
ClassifierFreeGuidanceLogitsProcessor,
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
EncoderRepetitionPenaltyLogitsProcessor,
|
||||
EpsilonLogitsWarper,
|
||||
EtaLogitsWarper,
|
||||
ExponentialDecayLengthPenalty,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
GenerationMode,
|
||||
MinLengthLogitsProcessor,
|
||||
MinNewTokensLengthLogitsProcessor,
|
||||
MinPLogitsWarper,
|
||||
NoBadWordsLogitsProcessor,
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
PrefixConstrainedLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
SequenceBiasLogitsProcessor,
|
||||
SuppressTokensAtBeginLogitsProcessor,
|
||||
SuppressTokensLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WatermarkLogitsProcessor,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
CaptureLogger,
|
||||
LoggingLevel,
|
||||
TemporaryHubRepo,
|
||||
is_staging_test,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
class GenerationConfigTest(unittest.TestCase):
|
||||
@parameterized.expand([(None,), ("foo.json",)])
|
||||
def test_save_load_config(self, config_name):
|
||||
config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
length_penalty=1.0,
|
||||
bad_words_ids=[[1, 2, 3], [4, 5]],
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir, config_name=config_name)
|
||||
loaded_config = GenerationConfig.from_pretrained(tmp_dir, config_name=config_name)
|
||||
|
||||
# Checks parameters that were specified
|
||||
self.assertEqual(loaded_config.do_sample, True)
|
||||
self.assertEqual(loaded_config.temperature, 0.7)
|
||||
self.assertEqual(loaded_config.length_penalty, 1.0)
|
||||
self.assertEqual(loaded_config.bad_words_ids, [[1, 2, 3], [4, 5]])
|
||||
|
||||
# Checks parameters that were not specified (defaults)
|
||||
self.assertEqual(loaded_config.top_k, 50)
|
||||
self.assertEqual(loaded_config.max_length, 20)
|
||||
self.assertEqual(loaded_config.max_time, None)
|
||||
|
||||
def test_from_model_config(self):
|
||||
model_config = AutoConfig.from_pretrained("openai-community/gpt2")
|
||||
generation_config_from_model = GenerationConfig.from_model_config(model_config)
|
||||
default_generation_config = GenerationConfig()
|
||||
|
||||
# The generation config has loaded a few non-default parameters from the model config
|
||||
self.assertNotEqual(generation_config_from_model, default_generation_config)
|
||||
|
||||
# One of those parameters is eos_token_id -- check if it matches
|
||||
self.assertNotEqual(generation_config_from_model.eos_token_id, default_generation_config.eos_token_id)
|
||||
self.assertEqual(generation_config_from_model.eos_token_id, model_config.eos_token_id)
|
||||
|
||||
def test_update(self):
|
||||
generation_config = GenerationConfig()
|
||||
update_kwargs = {
|
||||
"max_new_tokens": 1024,
|
||||
"foo": "bar",
|
||||
}
|
||||
update_kwargs_copy = copy.deepcopy(update_kwargs)
|
||||
unused_kwargs = generation_config.update(**update_kwargs)
|
||||
|
||||
# update_kwargs was not modified (no side effects)
|
||||
self.assertEqual(update_kwargs, update_kwargs_copy)
|
||||
|
||||
# update_kwargs was used to update the config on valid attributes
|
||||
self.assertEqual(generation_config.max_new_tokens, 1024)
|
||||
|
||||
# `.update()` returns a dictionary of unused kwargs
|
||||
self.assertEqual(unused_kwargs, {"foo": "bar"})
|
||||
|
||||
def test_kwarg_init(self):
|
||||
"""Tests that we can overwrite attributes at `from_pretrained` time."""
|
||||
default_config = GenerationConfig()
|
||||
self.assertEqual(default_config.temperature, 1.0)
|
||||
self.assertEqual(default_config.do_sample, False)
|
||||
self.assertEqual(default_config.num_beams, 1)
|
||||
|
||||
config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
length_penalty=1.0,
|
||||
bad_words_ids=[[1, 2, 3], [4, 5]],
|
||||
)
|
||||
self.assertEqual(config.temperature, 0.7)
|
||||
self.assertEqual(config.do_sample, True)
|
||||
self.assertEqual(config.num_beams, 1)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir)
|
||||
loaded_config = GenerationConfig.from_pretrained(tmp_dir, temperature=1.0)
|
||||
|
||||
self.assertEqual(loaded_config.temperature, 1.0)
|
||||
self.assertEqual(loaded_config.do_sample, True)
|
||||
self.assertEqual(loaded_config.num_beams, 1) # default value
|
||||
|
||||
def test_validate(self):
|
||||
"""
|
||||
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
|
||||
"""
|
||||
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
|
||||
|
||||
# A correct configuration will not throw any warning
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig()
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
|
||||
# parameters with `do_sample=False`). May be escalated to an error in the future.
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(return_dict_in_generate=False, output_scores=True)
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5) # store for later
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
|
||||
# that is done by unsetting the parameter (i.e. setting it to None)
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
# BAD - 0.9 means it is still set, we should warn
|
||||
generation_config_bad_temperature.update(temperature=0.9)
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
|
||||
generation_config_bad_temperature.update(temperature=1.0)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
# OK - None means it is unset, nothing to warn about
|
||||
generation_config_bad_temperature.update(temperature=None)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Impossible sets of parameters will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)
|
||||
|
||||
# Passing `generate()`-only flags to `validate` will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(logits_processor="foo")
|
||||
|
||||
# Model-specific parameters will NOT raise an exception or a warning
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(foo="bar")
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# By default we throw a short warning. However, we log with INFO level the details.
|
||||
# Default: we don't log the incorrect input values, only a short summary. We explain how to get more details.
|
||||
logger.warning_once.cache_clear()
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(do_sample=False, temperature=0.5)
|
||||
self.assertNotIn("0.5", captured_logs.out)
|
||||
self.assertTrue(len(captured_logs.out) < 150) # short log
|
||||
self.assertIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
|
||||
|
||||
# INFO level: we share the full deets
|
||||
logger.warning_once.cache_clear()
|
||||
logger.info_once.cache_clear()
|
||||
with LoggingLevel(logging.INFO):
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(do_sample=False, temperature=0.5)
|
||||
self.assertIn("0.5", captured_logs.out)
|
||||
self.assertTrue(len(captured_logs.out) > 400) # long log
|
||||
self.assertNotIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
|
||||
|
||||
# Finally, we can set `strict=True` to raise an exception on what would otherwise be a warning.
|
||||
generation_config = GenerationConfig()
|
||||
generation_config.temperature = 0.5
|
||||
generation_config.do_sample = False
|
||||
with self.assertRaises(ValueError):
|
||||
generation_config.validate(strict=True)
|
||||
|
||||
def test_refuse_to_save(self):
|
||||
"""Tests that we refuse to save a generation config that fails validation."""
|
||||
|
||||
# setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that
|
||||
# is caught, doesn't save, and raises an exception
|
||||
config = GenerationConfig()
|
||||
config.temperature = 0.5
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertRaises(ValueError) as exc:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
|
||||
self.assertTrue("`temperature` is set to `0.5`" in str(exc.exception))
|
||||
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
||||
|
||||
# greedy decoding throws an exception if we try to return multiple sequences -> throws an exception that is
|
||||
# caught, doesn't save, and raises a warning
|
||||
config = GenerationConfig()
|
||||
config.num_return_sequences = 2
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertRaises(ValueError) as exc:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
|
||||
self.assertTrue(
|
||||
"Greedy methods without beam search do not support `num_return_sequences` different than 1"
|
||||
in str(exc.exception)
|
||||
)
|
||||
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
||||
|
||||
# Final check: no logs at warning level/warnings/exceptions thrown if it is correct, and file is saved.
|
||||
config = GenerationConfig()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Catch warnings
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
# Catch logs (up to WARNING level, the default level)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
self.assertEqual(len(os.listdir(tmp_dir)), 1)
|
||||
|
||||
    def test_generation_mode(self):
        """Tests that the `get_generation_mode` method is working as expected."""
        # Default config -> greedy search
        config = GenerationConfig()
        self.assertEqual(config.get_generation_mode(), GenerationMode.GREEDY_SEARCH)

        # Sampling enabled -> multinomial sampling
        config = GenerationConfig(do_sample=True)
        self.assertEqual(config.get_generation_mode(), GenerationMode.SAMPLE)

        # More than one beam -> beam search
        config = GenerationConfig(num_beams=2)
        self.assertEqual(config.get_generation_mode(), GenerationMode.BEAM_SEARCH)

        # `penalty_alpha` + `top_k` without sampling -> contrastive search (scheduled for removal, see TODO)
        # TODO joao, manuel: remove this in v4.62.0
        config = GenerationConfig(top_k=10, do_sample=False, penalty_alpha=0.6)
        self.assertEqual(config.get_generation_mode(), GenerationMode.CONTRASTIVE_SEARCH)

        # Passing an assistant model switches to assisted generation regardless of the config contents
        config = GenerationConfig()
        self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION)
||||
def test_static_cache_without_cache_config(self):
|
||||
"""Regression test for #35026 -- static cache should work without a cache config."""
|
||||
config = GenerationConfig(cache_implementation="static")
|
||||
self.assertEqual(config.cache_implementation, "static")
|
||||
self.assertEqual(config.cache_config, None)
|
||||
|
||||
|
||||
class GenerationConfigSerializationTest(unittest.TestCase):
    """Checks that generation-controlling attributes survive a save/load round trip and can then be
    used to initialize the matching logits processor or warper.

    Every test follows the same pattern: build a `GenerationConfig` with one attribute set, round-trip
    it through `save_pretrained`/`from_pretrained` (via `_save_and_reload`), assert the attribute
    survived, then build the corresponding processor from the reloaded value.
    """

    def _save_and_reload(self, generation_config):
        """Serialize `generation_config` to a temporary directory and return the reloaded copy."""
        with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
            generation_config.save_pretrained(tmp_dir)
            return GenerationConfig.from_pretrained(tmp_dir)

    def test_serialize_generation_sequence_bias(self):
        """Tests that GenerationConfig is serialized and SequenceBiasLogitsProcessor is initialized with sequence_bias parameter"""
        sequence_bias = [[[45, 67], -0.6], [[89], 1.2]]
        generation_config = GenerationConfig()
        generation_config.sequence_bias = sequence_bias
        new_config = self._save_and_reload(generation_config)
        self.assertSequenceEqual(new_config.sequence_bias, sequence_bias)

        # JSON round trip keeps lists; the processor converts token id lists to tuple keys
        expected_sequence_bias = {(45, 67): -0.6, (89,): 1.2}
        bias_logits_processor = SequenceBiasLogitsProcessor(new_config.sequence_bias)
        self.assertDictEqual(bias_logits_processor.sequence_bias, expected_sequence_bias)

    def test_serialize_generation_min_length_eos_token(self):
        """Tests that GenerationConfig is serialized and MinLengthLogitsProcessor is initialized with min_length and eos_token_id"""
        eos_token_id = 0
        min_length = 10

        generation_config = GenerationConfig(min_length=min_length, eos_token_id=eos_token_id)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.min_length, min_length)
        self.assertEqual(new_config.eos_token_id, eos_token_id)

        min_dist_processor = MinLengthLogitsProcessor(
            min_length=new_config.min_length, eos_token_id=new_config.eos_token_id
        )
        self.assertEqual(min_dist_processor.min_length, min_length)
        self.assertEqual(min_dist_processor.eos_token_id, eos_token_id)

    def test_serialize_generation_min_new_tokens(self):
        """Tests that GenerationConfig is serialized and MinNewTokensLengthLogitsProcessor is initialized with min_new_tokens"""
        eos_token_id = 0
        min_new_tokens = 5
        prompt_length_to_skip = 2

        generation_config = GenerationConfig(min_new_tokens=min_new_tokens)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.min_new_tokens, min_new_tokens)

        min_new_tokens_processor = MinNewTokensLengthLogitsProcessor(
            prompt_length_to_skip=prompt_length_to_skip,
            min_new_tokens=new_config.min_new_tokens,
            eos_token_id=eos_token_id,
        )
        self.assertEqual(min_new_tokens_processor.min_new_tokens, min_new_tokens)

    def test_serialize_generation_temperature(self):
        """Tests that GenerationConfig is serialized and TemperatureLogitsWarper is initialized with temperature"""
        temperature = 2.0

        # do_sample=True: temperature only validates alongside sampling
        generation_config = GenerationConfig(temperature=temperature, do_sample=True)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.temperature, temperature)

        temperature_logits_warper = TemperatureLogitsWarper(temperature=new_config.temperature)
        self.assertEqual(temperature_logits_warper.temperature, temperature)

    def test_serialize_generation_repetition_penalty(self):
        """Tests that GenerationConfig is serialized and RepetitionPenaltyLogitsProcessor is initialized with repetition_penalty"""
        penalty = 2.0

        generation_config = GenerationConfig(repetition_penalty=penalty)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.repetition_penalty, penalty)

        rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=new_config.repetition_penalty)
        self.assertEqual(rep_penalty_proc.penalty, penalty)

    def test_serialize_generation_encoder_repetition_penalty(self):
        """Tests that GenerationConfig is serialized and EncoderRepetitionPenaltyLogitsProcessor is initialized with penalty and input_ids"""
        penalty = 2.0
        input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)

        generation_config = GenerationConfig(encoder_repetition_penalty=penalty)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.encoder_repetition_penalty, penalty)

        rep_penalty_proc = EncoderRepetitionPenaltyLogitsProcessor(
            penalty=new_config.encoder_repetition_penalty, encoder_input_ids=input_ids
        )
        # The processor stores the inverse penalty internally
        self.assertEqual(rep_penalty_proc.penalty, 1 / penalty)
        torch.testing.assert_close(rep_penalty_proc.encoder_input_ids, input_ids)

    def test_serialize_generation_top_p(self):
        """Tests that GenerationConfig is serialized and TopPLogitsWarper is initialized with top_p"""
        top_p = 0.8

        generation_config = GenerationConfig(top_p=top_p, do_sample=True)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.top_p, top_p)

        rep_penalty_proc = TopPLogitsWarper(top_p=new_config.top_p)
        self.assertEqual(rep_penalty_proc.top_p, top_p)

    def test_serialize_generation_top_k(self):
        """Tests that GenerationConfig is serialized and TopKLogitsWarper is initialized with top_k"""
        top_k = 2

        generation_config = GenerationConfig(top_k=top_k, do_sample=True)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.top_k, top_k)

        top_k_logits_wrap = TopKLogitsWarper(top_k=new_config.top_k)
        self.assertEqual(top_k_logits_wrap.top_k, top_k)

    def test_serialize_generation_min_p(self):
        """Tests that GenerationConfig is serialized and MinPLogitsWarper is initialized with min_p"""
        min_p = 0.8

        generation_config = GenerationConfig(min_p=min_p, do_sample=True)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.min_p, min_p)

        min_k_logits_wrap = MinPLogitsWarper(min_p=new_config.min_p)
        self.assertEqual(min_k_logits_wrap.min_p, min_p)

    def test_serialize_generation_typical_p(self):
        """Tests that GenerationConfig is serialized and TypicalLogitsWarper is initialized with mass"""
        mass = 0.8

        generation_config = GenerationConfig(typical_p=mass, do_sample=True)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.typical_p, mass)

        typical_p_logits_wrap = TypicalLogitsWarper(mass=new_config.typical_p)
        self.assertEqual(typical_p_logits_wrap.mass, mass)

    def test_serialize_generation_epsilon_cutoff(self):
        """Tests that GenerationConfig is serialized and EpsilonLogitsWarper is initialized with epsilon"""
        epsilon = 0.8

        generation_config = GenerationConfig(epsilon_cutoff=epsilon, do_sample=True)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.epsilon_cutoff, epsilon)

        epsilon_logits_wrap = EpsilonLogitsWarper(epsilon=new_config.epsilon_cutoff)
        self.assertEqual(epsilon_logits_wrap.epsilon, epsilon)

    def test_serialize_generation_eta_cutoff(self):
        """Tests that GenerationConfig is serialized and EtaLogitsWarper is initialized with epsilon"""
        epsilon = 0.8

        generation_config = GenerationConfig(eta_cutoff=epsilon, do_sample=True)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.eta_cutoff, epsilon)

        eta_logits_wrap = EtaLogitsWarper(epsilon=new_config.eta_cutoff)
        self.assertEqual(eta_logits_wrap.epsilon, epsilon)

    def test_serialize_generation_ngram_size(self):
        """Tests that GenerationConfig is serialized and NoRepeatNGramLogitsProcessor is initialized with ngram_size"""
        ngram_size = 2

        generation_config = GenerationConfig(no_repeat_ngram_size=ngram_size, do_sample=True)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.no_repeat_ngram_size, ngram_size)

        no_repeat_ngram_proc = NoRepeatNGramLogitsProcessor(ngram_size=new_config.no_repeat_ngram_size)
        self.assertEqual(no_repeat_ngram_proc.ngram_size, ngram_size)

    def test_serialize_generation_encoder_ngram_size(self):
        """Tests that GenerationConfig is serialized and EncoderNoRepeatNGramLogitsProcessor is initialized with ngram_size"""
        ngram_size = 2
        input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)

        generation_config = GenerationConfig(encoder_no_repeat_ngram_size=ngram_size, do_sample=True)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.encoder_no_repeat_ngram_size, ngram_size)

        encoder_no_repeat_ngram_proc = EncoderNoRepeatNGramLogitsProcessor(
            encoder_ngram_size=new_config.encoder_no_repeat_ngram_size, encoder_input_ids=input_ids
        )
        self.assertEqual(encoder_no_repeat_ngram_proc.ngram_size, ngram_size)

    def test_serialize_generation_bad_words_ids(self):
        """Tests that GenerationConfig is serialized and NoBadWordsLogitsProcessor is initialized with bad_words_ids"""
        bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]

        generation_config = GenerationConfig(bad_words_ids=bad_word_tokens)
        new_config = self._save_and_reload(generation_config)
        self.assertSequenceEqual(new_config.bad_words_ids, bad_word_tokens)

        no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=new_config.bad_words_ids)
        self.assertSequenceEqual(no_bad_words_dist_proc.bad_word_ids, bad_word_tokens)

    def test_serialize_generation_num_beams(self):
        """Tests that GenerationConfig is serialized and PrefixConstrainedLogitsProcessor is initialized with num_beams"""
        num_beams = 1

        def prefix_allowed_tokens_fn(batch_id, inputs_ids):
            # Toy constraint function: the allowed tokens depend only on the batch index
            return [[0, 1], [2, 3]][batch_id]

        generation_config = GenerationConfig(num_beams=num_beams)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.num_beams, num_beams)

        prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(
            prefix_allowed_tokens_fn, num_beams=new_config.num_beams
        )
        self.assertEqual(prefix_constrained_logits_proc._num_beams, num_beams)

    def test_serialize_generation_bos_token_id(self):
        """Tests that GenerationConfig is serialized and ForcedBOSTokenLogitsProcessor is initialized with bos_token_id"""
        bos_token_id = 0

        generation_config = GenerationConfig(bos_token_id=bos_token_id)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.bos_token_id, bos_token_id)

        logits_processor = ForcedBOSTokenLogitsProcessor(bos_token_id=new_config.bos_token_id)
        self.assertEqual(logits_processor.bos_token_id, bos_token_id)

    def test_serialize_generation_eos_token_id(self):
        """Tests that GenerationConfig is serialized and ForcedEOSTokenLogitsProcessor is initialized with eos_token_id"""
        eos_token_id = 0
        max_length = 5

        generation_config = GenerationConfig(eos_token_id=eos_token_id)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.eos_token_id, eos_token_id)

        logits_processor = ForcedEOSTokenLogitsProcessor(
            max_length=max_length, eos_token_id=new_config.eos_token_id, device=torch_device
        )
        self.assertEqual(logits_processor.eos_token_id, eos_token_id)

    def test_serialize_generation_exponential_decay_length_penalty(self):
        """Tests that GenerationConfig is serialized and ExponentialDecayLengthPenalty is initialized with regulation_start and regulation_factor"""
        eos_token_id = 0
        penalty_start = 5
        penalty_factor = 1.1
        input_ids_seq_length = 10
        exponential_decay_length_penalty = (penalty_start, penalty_factor)

        generation_config = GenerationConfig(exponential_decay_length_penalty=exponential_decay_length_penalty)
        new_config = self._save_and_reload(generation_config)
        # The JSON round trip turns the tuple into a list
        self.assertEqual(new_config.exponential_decay_length_penalty, [penalty_start, penalty_factor])

        exponential_decay_processor = ExponentialDecayLengthPenalty(
            exponential_decay_length_penalty=new_config.exponential_decay_length_penalty,
            eos_token_id=eos_token_id,
            input_ids_seq_length=input_ids_seq_length,
        )
        # The processor offsets the penalty start by the prompt length
        self.assertEqual(
            exponential_decay_processor.regulation_start, exponential_decay_length_penalty[0] + input_ids_seq_length
        )
        self.assertEqual(exponential_decay_processor.regulation_factor, exponential_decay_length_penalty[1])

    def test_serialize_generation_begin_suppress_tokens(self):
        """Tests that GenerationConfig is serialized and SuppressTokensAtBeginLogitsProcessor is initialized with begin_suppress_token and begin_index"""
        begin_suppress_tokens = [220, 50256]
        begin_index = 0

        generation_config = GenerationConfig(begin_suppress_tokens=begin_suppress_tokens)
        new_config = self._save_and_reload(generation_config)
        self.assertSequenceEqual(new_config.begin_suppress_tokens, begin_suppress_tokens)

        suppress_processor = SuppressTokensAtBeginLogitsProcessor(
            begin_suppress_tokens=new_config.begin_suppress_tokens, begin_index=begin_index
        )
        self.assertSequenceEqual(suppress_processor.begin_suppress_tokens, begin_suppress_tokens)
        self.assertEqual(suppress_processor.begin_index, begin_index)

    def test_serialize_generation_suppress_tokens(self):
        """Tests that GenerationConfig is serialized and SuppressTokensLogitsProcessor is initialized with suppress_token"""
        suppress_tokens = [220, 50256]

        generation_config = GenerationConfig(suppress_tokens=suppress_tokens)
        new_config = self._save_and_reload(generation_config)
        self.assertSequenceEqual(new_config.suppress_tokens, suppress_tokens)

        suppress_processor = SuppressTokensLogitsProcessor(suppress_tokens=new_config.suppress_tokens)
        self.assertSequenceEqual(suppress_processor.suppress_tokens, suppress_tokens)

    def test_serialize_generation_guidance_scale(self):
        """Tests that GenerationConfig is serialized and ClassifierFreeGuidanceLogitsProcessor is initialized with guidance_scale"""
        guidance_scale = 2.0

        generation_config = GenerationConfig(guidance_scale=guidance_scale)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.guidance_scale, guidance_scale)

        classifier_processor = ClassifierFreeGuidanceLogitsProcessor(guidance_scale=new_config.guidance_scale)
        self.assertEqual(classifier_processor.guidance_scale, guidance_scale)

    def test_serialize_generation_guidance_scale_unbatched(self):
        """Tests that GenerationConfig is serialized and UnbatchedClassifierFreeGuidanceLogitsProcessor is initialized with guidance_scale"""
        guidance_scale = 2.0
        input_ids = torch.LongTensor([[0]])

        generation_config = GenerationConfig(guidance_scale=guidance_scale)
        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.guidance_scale, guidance_scale)

        cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(new_config.guidance_scale, {}, input_ids)
        self.assertEqual(cfg.guidance_scale, guidance_scale)

    def test_serialize_generation_watermarking_config(self):
        """Tests that GenerationConfig is serialized and WatermarkLogitsProcessor is initialized with WatermarkingConfig parameters"""
        vocab_size = 20
        bias = 2.0
        greenlist_ratio = 0.5
        hashing_key = 10
        seeding_scheme = "lefthash"
        context_width = 10
        watermarking_config = WatermarkingConfig(
            bias=bias,
            greenlist_ratio=greenlist_ratio,
            hashing_key=hashing_key,
            seeding_scheme=seeding_scheme,
            context_width=context_width,
        )
        generation_config = GenerationConfig(watermarking_config=watermarking_config)

        new_config = self._save_and_reload(generation_config)
        self.assertEqual(new_config.watermarking_config.bias, bias)
        self.assertEqual(new_config.watermarking_config.greenlist_ratio, greenlist_ratio)
        self.assertEqual(new_config.watermarking_config.hashing_key, hashing_key)
        self.assertEqual(new_config.watermarking_config.seeding_scheme, seeding_scheme)
        self.assertEqual(new_config.watermarking_config.context_width, context_width)

        watermark = WatermarkLogitsProcessor(
            vocab_size=vocab_size,
            device=torch_device,
            greenlist_ratio=new_config.watermarking_config.greenlist_ratio,
            bias=new_config.watermarking_config.bias,
            hashing_key=new_config.watermarking_config.hashing_key,
            seeding_scheme=new_config.watermarking_config.seeding_scheme,
            context_width=new_config.watermarking_config.context_width,
        )
        self.assertEqual(watermark.bias, bias)
        self.assertEqual(watermark.greenlist_size, int(vocab_size * greenlist_ratio))
        self.assertEqual(watermark.hash_key, hashing_key)
        self.assertEqual(watermark.seeding_scheme, seeding_scheme)
        self.assertEqual(watermark.context_width, context_width)
||||
|
||||
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
    """Round-trip tests for pushing a `GenerationConfig` to the Hugging Face Hub.

    Covers direct `push_to_hub`, pushing through `save_pretrained(..., push_to_hub=True)`,
    organization namespaces, and pushing to a PR revision.
    """

    @classmethod
    def setUpClass(cls):
        # Shared write token for all Hub operations in this class
        cls._token = TOKEN

    @staticmethod
    def _dummy_config():
        """Build a small config with non-default values so the round trip is meaningful."""
        return GenerationConfig(
            do_sample=True,
            temperature=0.7,
            length_penalty=1.0,
        )

    def _assert_round_trip_matches(self, pushed_config, reloaded_config):
        """Every serialized field (except the version stamp) must survive the Hub round trip."""
        for key, value in pushed_config.to_dict().items():
            if key != "transformers_version":
                self.assertEqual(value, getattr(reloaded_config, key))

    def test_push_to_hub(self):
        with TemporaryHubRepo(token=self._token) as tmp_repo:
            config = self._dummy_config()
            config.push_to_hub(tmp_repo.repo_id, token=self._token)

            new_config = GenerationConfig.from_pretrained(tmp_repo.repo_id)
            self._assert_round_trip_matches(config, new_config)

    def test_push_to_hub_via_save_pretrained(self):
        with TemporaryHubRepo(token=self._token) as tmp_repo:
            config = self._dummy_config()
            # Push to hub via save_pretrained
            with tempfile.TemporaryDirectory() as tmp_dir:
                config.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)

            new_config = GenerationConfig.from_pretrained(tmp_repo.repo_id)
            self._assert_round_trip_matches(config, new_config)

    def test_push_to_hub_in_organization(self):
        with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
            config = self._dummy_config()
            config.push_to_hub(tmp_repo.repo_id, token=self._token)

            new_config = GenerationConfig.from_pretrained(tmp_repo.repo_id)
            self._assert_round_trip_matches(config, new_config)

    def test_push_to_hub_in_organization_via_save_pretrained(self):
        with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
            config = self._dummy_config()
            # Push to hub via save_pretrained
            with tempfile.TemporaryDirectory() as tmp_dir:
                config.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)

            new_config = GenerationConfig.from_pretrained(tmp_repo.repo_id)
            self._assert_round_trip_matches(config, new_config)

    def test_push_to_hub_on_pr_revision(self):
        with TemporaryHubRepo(token=self._token) as tmp_repo:
            # create a PR
            pr = create_pull_request(repo_id=tmp_repo.repo_id, title="Test PR", token=self._token)
            revision = f"refs/pr/{pr.num}"

            # push to PR ref
            config = self._dummy_config()
            config.push_to_hub(tmp_repo.repo_id, token=self._token, revision=revision)

            # load from PR ref
            new_config = GenerationConfig.from_pretrained(tmp_repo.repo_id, revision=revision)
            self._assert_round_trip_matches(config, new_config)
||||
290
transformers/tests/generation/test_continuous_batching.py
Normal file
290
transformers/tests/generation/test_continuous_batching.py
Normal file
@@ -0,0 +1,290 @@
|
||||
# Copyright 2025 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.generation.continuous_batching.cache import group_layers_by_attn_type
|
||||
from transformers.testing_utils import Expectations, require_kernels, require_torch_gpu, slow
|
||||
|
||||
|
||||
ALLOW_EXPECTED_OUTPUTS = True # this is a debug flag when you want to measure deviation between CB and non-CB gen
|
||||
|
||||
|
||||
class ContinuousBatchingTest(unittest.TestCase):
|
||||
@parameterized.expand(
|
||||
[
|
||||
(None, None, "0"),
|
||||
(None, 4096, "0"),
|
||||
("f", None, "0"),
|
||||
("ffff", None, "0000"),
|
||||
("sssss", 4096, "00000"),
|
||||
("fs", 4096, "01"),
|
||||
("ssfssf", 4096, "001221"),
|
||||
("ssssf", 4096, "01234"),
|
||||
("fffsffs", 4096, "0123456"),
|
||||
]
|
||||
)
|
||||
def test_group_layers(
|
||||
self,
|
||||
layer_types_str: Optional[str],
|
||||
sliding_window: Optional[int],
|
||||
expected_groups: str,
|
||||
) -> None:
|
||||
# Take a config and change the layer_types attribute to the mix we want
|
||||
config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B")
|
||||
|
||||
if layer_types_str is not None:
|
||||
layer_types = [{"f": "full_attention", "s": "sliding_window"}[char] for char in layer_types_str]
|
||||
else:
|
||||
layer_types = None
|
||||
config.num_hidden_layers = len(expected_groups)
|
||||
|
||||
config.layer_types = layer_types
|
||||
config.sliding_window = sliding_window
|
||||
|
||||
expected_lg = {}
|
||||
for i, group in enumerate(expected_groups):
|
||||
group = int(group)
|
||||
expected_lg[group] = expected_lg.get(group, []) + [i]
|
||||
expected_layer_groups = [expected_lg[i] for i in sorted(expected_lg.keys())]
|
||||
|
||||
# Test layer groups formation
|
||||
layer_groups, group_types = group_layers_by_attn_type(config)
|
||||
self.assertEqual(
|
||||
sorted(expected_layer_groups),
|
||||
sorted(layer_groups),
|
||||
f"Test failed for: {layer_types_str = }, {sliding_window = }, {expected_layer_groups = }, {layer_groups = }",
|
||||
)
|
||||
|
||||
# If layer_types is provided, check that group_types matches the type of the all layers in each group
|
||||
if layer_types is not None:
|
||||
for layer_group, group_type in zip(layer_groups, group_types):
|
||||
layer_types = [config.layer_types[i] for i in layer_group]
|
||||
self.assertEqual(layer_types, [group_type] * len(layer_types))
|
||||
# If layer_types is None, all groups should be of the same type
|
||||
else:
|
||||
for group_type in group_types:
|
||||
sliding_window = getattr(config, "sliding_window", None)
|
||||
expected_group_type = "sliding_attention" if sliding_window is not None else "full_attention"
|
||||
self.assertEqual(
|
||||
group_type,
|
||||
expected_group_type,
|
||||
f"Test failed for: {layer_types_str = }, {sliding_window = }, {group_types = }",
|
||||
)
|
||||
|
||||
    def _continuous_batching_parity(
        self, model_id: str, attn_implementation: str, expected_outputs: dict[str, str]
    ) -> None:
        """Shared driver: generate with continuous batching (CB) and without, then compare outputs.

        For each request, the CB output must equal the non-CB output, or — when they deviate for
        numerical reasons — it must equal the entry for that request id in `expected_outputs`.
        Requires a CUDA-capable device and Hub access.
        """
        # Prepare common elements
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
        prompts = [
            "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her "
            "friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh "
            "duck egg. How much in dollars does she make every day at the farmers' market? The answer is:",
            "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take? "
            "The answer is:",
            "Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. "
            "This increased the value of the house by 150%. How much profit did he make? The answer is:",
        ]  # fmt: skip
        batched_inputs = [tokenizer.encode(prompt) for prompt in prompts]

        # Generation with continuous batching
        model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation, dtype="auto")
        model = model.cuda().eval()
        # Greedy decoding with a fixed budget so both code paths are deterministic and comparable
        model.generation_config.max_new_tokens = 40
        model.generation_config.do_sample = False
        model.generation_config.use_cuda_graph = False

        cb_outputs = model.generate_batch(inputs=batched_inputs, generation_config=model.generation_config)

        # Generation without continuous batching
        # Map each paged implementation to its non-paged counterpart for the reference run
        if attn_implementation == "sdpa_paged":
            non_cb_attn_implementation = "sdpa"
        elif attn_implementation == "eager_paged":
            non_cb_attn_implementation = "eager"
        elif attn_implementation == "paged_attention|kernels-community/flash-attn":
            non_cb_attn_implementation = "eager"
        else:
            raise ValueError(f"Invalid attention implementation: {attn_implementation}")

        # We regenerate the model because just changing the attn_implementation does not work
        model = AutoModelForCausalLM.from_pretrained(
            model_id, attn_implementation=non_cb_attn_implementation, dtype="auto"
        )
        model = model.cuda().eval()
        model.generation_config.max_new_tokens = 40
        model.generation_config.do_sample = False
        model.generation_config.use_cuda_graph = False

        for request_id, request in cb_outputs.items():
            # Generate without continuous batching
            input_ids = torch.tensor([request.prompt_ids]).cuda()
            attention_mask = torch.ones_like(input_ids)
            outputs = model.generate(
                input_ids, attention_mask=attention_mask, generation_config=model.generation_config
            )
            # Strip the prompt: keep only the newly generated tail
            generated_tokens = outputs[0][input_ids.shape[1] :]
            non_cb_decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
            input_ids = input_ids.tolist()[0]

            # Check that the generated output with and without CB match
            cb_decoded_output = tokenizer.decode(request.generated_tokens, skip_special_tokens=True)
            outputs_match = non_cb_decoded_output == cb_decoded_output

            # If they dont, that might be expected: the outputs can differ slightly due to numerical differences
            # If that's the case, there is an expected output ready
            if not outputs_match:
                # ALLOW_EXPECTED_OUTPUTS=False forces a hard failure, useful to measure CB deviation
                expected_output = expected_outputs.get(request_id) if ALLOW_EXPECTED_OUTPUTS else None

                if expected_output is None:
                    self.fail(
                        f"Test {request_id = } failed, no expected output was provided.\nRef:"
                        f"{repr(non_cb_decoded_output)}\nOut:{repr(cb_decoded_output)}"
                    )
                else:
                    self.assertEqual(
                        expected_output,
                        cb_decoded_output,
                        msg=f"Test {request_id = } failed, expected output did not match.\n"
                        f"Exp:{repr(expected_output)}\nOut:{repr(cb_decoded_output)}",
                    )
|
||||
# Eager tests
|
||||
    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_llama_eager(self) -> None:
        """Continuous batching vs. regular generate parity for Llama 3.1 8B with the eager paged backend."""
        # Per-device fallback outputs: these requests are known to diverge numerically from
        # non-CB generation, so their CB output is accepted when it matches the recorded text.
        expected_outputs = Expectations({
            ("rocm", (9, 4)): {
                "req_0": " $16. How did I get that answer? I used the following equation: 16 - 3 - 4 = 9. 9 x $2 = $18. $18 -"
            },
            ("cuda", (9, 0)): {
                "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5. The total number of bolts is 4.5. The total",
                "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "eager_paged", expected_outputs)
|
||||
|
||||
    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_gemma_eager(self) -> None:
        """Continuous batching vs. regular generate parity for Gemma 2 2B with the eager paged backend."""
        # Per-device fallback outputs for requests whose CB generation is known to diverge numerically.
        expected_outputs = Expectations({
            ("rocm", (9, 4)): {
                "req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 ="
            },
            ("cuda", (9, 0)): {
                "req_0": "\n\n**$12**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 13",
                "req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n "
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity("google/gemma-2-2b-it", "eager_paged", expected_outputs)
|
||||
|
||||
    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_qwen_eager(self) -> None:
        """Continuous batching vs. regular generate parity for Qwen3 4B with the eager paged backend."""
        # No known numerical divergences for this model: any mismatch fails the test outright.
        expected_outputs = {}
        self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "eager_paged", expected_outputs)
|
||||
|
||||
    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_gpt_oss_eager(self) -> None:
        """Continuous batching vs. regular generate parity for GPT-OSS 20B with the eager paged backend."""
        # Per-device fallback outputs for requests whose CB generation is known to diverge numerically.
        expected_outputs = Expectations({
            ("cuda", (9, 0)): {
                "req_1": " 2.5 bolts. The question: \"What is the name of the puzzle that involves a robe taking 2 bolts of blue fiber and half that much white fiber?\" The answer: \"The",
                "req_2": " 50%.\"\n\nWe need to parse: He buys a house for $80,000. He puts in $50,000 in repairs. This increased the value of the house by 150%."
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity("openai/gpt-oss-20b", "eager_paged", expected_outputs)
|
||||
|
||||
# SDPA tests
|
||||
    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_llama_sdpa(self) -> None:
        """Continuous batching vs. regular generate parity for Llama 3.1 8B with the SDPA paged backend."""
        # Per-device fallback outputs for requests whose CB generation is known to diverge numerically.
        expected_outputs = Expectations({
            ("rocm", (9, 4)): {
                "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "sdpa_paged", expected_outputs)
|
||||
|
||||
    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_gemma_sdpa(self) -> None:
        """Continuous batching vs. regular generate parity for Gemma 2 2B with the SDPA paged backend."""
        # Per-device fallback outputs for requests whose CB generation is known to diverge numerically.
        expected_outputs = Expectations({
            ("cuda", (9, 0)): {
                "req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity("google/gemma-2-2b-it", "sdpa_paged", expected_outputs)
|
||||
|
||||
    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_qwen_sdpa(self) -> None:
        """Continuous batching vs. regular generate parity for Qwen3 4B with the SDPA paged backend."""
        # No known numerical divergences for this model: any mismatch fails the test outright.
        expected_outputs = {}
        self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "sdpa_paged", expected_outputs)
|
||||
|
||||
# GPT-OSS is not compatible with SDPA because it has an attention sink. TODO: is this fixable?
|
||||
|
||||
# Flash attention test
|
||||
    @require_torch_gpu
    @require_kernels
    @slow
    def test_continuous_batching_parity_llama_flash(self) -> None:
        """Continuous batching vs. regular generate parity for Llama 3.1 8B with the flash-attn kernel backend."""
        # Per-device fallback outputs for requests whose CB generation is known to diverge numerically.
        expected_outputs = Expectations({
            ("cuda", (9, 0)): {
                "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5 bolts. The total number of bolts is 4.5 bolts.",
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity(
            "meta-llama/Llama-3.1-8B", "paged_attention|kernels-community/flash-attn", expected_outputs
        )
|
||||
|
||||
    @require_torch_gpu
    @require_kernels
    @slow
    def test_continuous_batching_parity_gemma_flash(self) -> None:
        """Continuous batching vs. regular generate parity for Gemma 2 2B with the flash-attn kernel backend."""
        # Per-device fallback outputs for requests whose CB generation is known to diverge numerically.
        expected_outputs = Expectations({
            ("cuda", (9, 0)): {
                "req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n ",
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity(
            "google/gemma-2-2b-it", "paged_attention|kernels-community/flash-attn", expected_outputs
        )
|
||||
|
||||
    @require_torch_gpu
    @require_kernels
    @slow
    def test_continuous_batching_parity_qwen_flash(self) -> None:
        """Continuous batching vs. regular generate parity for Qwen3 4B with the flash-attn kernel backend."""
        # No known numerical divergences for this model: any mismatch fails the test outright.
        expected_outputs = {}
        self._continuous_batching_parity(
            "Qwen/Qwen3-4B-Instruct-2507", "paged_attention|kernels-community/flash-attn", expected_outputs
        )
|
||||
|
||||
    @require_torch_gpu
    @require_kernels
    @slow
    def test_continuous_batching_parity_gpt_oss_flash(self) -> None:
        """Continuous batching vs. regular generate parity for GPT-OSS 20B with the flash-attn kernel backend."""
        # No known numerical divergences for this model: any mismatch fails the test outright.
        expected_outputs = {}
        self._continuous_batching_parity(
            "openai/gpt-oss-20b", "paged_attention|kernels-community/flash-attn", expected_outputs
        )
|
||||
|
||||
|
||||
# FIXME: the gemma tests seem broken, there is a message about cuda graphs, and the sdpa and flash expectations are
# inverted on CUDA. On AMD they do fine.
|
||||
144
transformers/tests/generation/test_flash_attention_parity.py
Normal file
144
transformers/tests/generation/test_flash_attention_parity.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright 2025 Eduard Durech and SGLang team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Usage:
|
||||
# RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.testing_utils import require_flash_attn, require_flash_attn_3, require_torch_gpu, slow
|
||||
|
||||
|
||||
class FlashAttentionParityTest(unittest.TestCase):
|
||||
# From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
|
||||
def _lcs(self, X, Y):
|
||||
m = len(X)
|
||||
n = len(Y)
|
||||
L = [[0] * (n + 1) for _ in range(m + 1)]
|
||||
|
||||
for i in range(m + 1):
|
||||
for j in range(n + 1):
|
||||
if i == 0 or j == 0:
|
||||
L[i][j] = 0
|
||||
elif X[i - 1] == Y[j - 1]:
|
||||
L[i][j] = L[i - 1][j - 1] + 1
|
||||
else:
|
||||
L[i][j] = max(L[i - 1][j], L[i][j - 1])
|
||||
|
||||
return L[m][n]
|
||||
|
||||
# From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
|
||||
def _calculate_rouge_l(self, output_strs_list1, output_strs_list2):
|
||||
rouge_l_scores = []
|
||||
|
||||
for s1, s2 in zip(output_strs_list1, output_strs_list2):
|
||||
lcs_len = self._lcs(s1, s2)
|
||||
precision = lcs_len / len(s1) if len(s1) > 0 else 0
|
||||
recall = lcs_len / len(s2) if len(s2) > 0 else 0
|
||||
if precision + recall > 0:
|
||||
fmeasure = (2 * precision * recall) / (precision + recall)
|
||||
else:
|
||||
fmeasure = 0.0
|
||||
rouge_l_scores.append(fmeasure)
|
||||
|
||||
return rouge_l_scores
|
||||
|
||||
    def _benchmark_generation(self, model, inputs, n_warmup=3, n_runs=5):
        """Return the mean latency in milliseconds of ``model.generate`` over ``n_runs`` timed runs.

        ``n_warmup`` untimed runs are executed first so compilation/caching does not
        skew the measurement.
        """
        for _ in range(n_warmup):
            model.generate(**inputs, max_new_tokens=20, do_sample=False)
        torch.cuda.synchronize()

        # CUDA events measure device-side time; elapsed_time() reports milliseconds.
        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)

        start_time.record()
        for _ in range(n_runs):
            model.generate(**inputs, max_new_tokens=20, do_sample=False)
        end_time.record()
        # Events are recorded asynchronously; synchronize before reading the timing.
        torch.cuda.synchronize()

        return start_time.elapsed_time(end_time) / n_runs
|
||||
|
||||
    @pytest.mark.flash_attn_3_test
    @require_torch_gpu
    @require_flash_attn
    @require_flash_attn_3
    @slow
    def test_flash_attention_2_3_parity(self):
        """Check that Flash Attention 2 and 3 agree on logits and generated text, and report latency.

        Skips (rather than fails) when the FA3 model cannot be loaded, since FA3
        availability depends on the installed kernels.
        """
        model_id = "meta-llama/Llama-3.2-1B-Instruct"
        prompt = "The ETH AI Center is"

        # 1. Load FA2 model and tokenizer
        model_2 = AutoModelForCausalLM.from_pretrained(
            model_id,
            dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        ).to("cuda")
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # 2. Load FA3 model
        try:
            model_3 = AutoModelForCausalLM.from_pretrained(
                model_id,
                dtype=torch.bfloat16,
                attn_implementation="flash_attention_3",
            ).to("cuda")
        except (ValueError, ImportError) as e:
            pytest.skip(f"Could not load Flash Attention 3 model, skipping test. Error: {e}")

        # 3. Generate with both models (greedy, with per-step scores for the logit comparison)
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

        with torch.no_grad():
            output_2 = model_2.generate(
                **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
            )
            output_3 = model_3.generate(
                **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
            )

        # 4. Correctness check
        # 4a. Logits: per-step scores must match within a small tolerance
        logits_2 = torch.stack(output_2.scores)
        logits_3 = torch.stack(output_3.scores)
        torch.testing.assert_close(logits_2, logits_3, atol=1e-3, rtol=1e-3)
        logprobs_2 = torch.nn.functional.log_softmax(logits_2, dim=-1)
        logprobs_3 = torch.nn.functional.log_softmax(logits_3, dim=-1)
        max_logprob_diff = torch.max(torch.abs(logprobs_2 - logprobs_3)).item()

        # 4b. Generated text: near-identical by ROUGE-L (LCS-based F-measure)
        text_2 = tokenizer.decode(output_2.sequences[0], skip_special_tokens=True)
        text_3 = tokenizer.decode(output_3.sequences[0], skip_special_tokens=True)
        rouge_score = self._calculate_rouge_l([text_2], [text_3])[0]
        assert rouge_score > 0.99, f"Generated texts do not match (ROUGE-L: {rouge_score})"

        # 5. Performance check (informational only — printed, not asserted)
        with torch.no_grad():
            time_2 = self._benchmark_generation(model_2, inputs)
            time_3 = self._benchmark_generation(model_3, inputs)

        print(f"\n--- Flash Attention {2, 3} Parity Test on {model_id} ---")
        print(f"Prompt: '{prompt}'")
        print(f"Generated text with Flash Attention 2: {text_2}")
        print(f"Generated text with Flash Attention 3: {text_3}")
        print(f"ROUGE-L: {rouge_score}")
        print(f"Max absolute difference in logprobs: {max_logprob_diff:.5e}")
        print(f"Flash Attention 2 latency: {time_2:.2f} ms")
        print(f"Flash Attention 3 latency: {time_3:.2f} ms")
        print(f"Speed-up: {time_2 / time_3:.2f}x")
        print("---")
|
||||
193
transformers/tests/generation/test_fsdp.py
Normal file
193
transformers/tests/generation/test_fsdp.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import textwrap
|
||||
from typing import Any, Callable
|
||||
|
||||
from transformers import is_torch_available, is_torch_xpu_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
backend_device_count,
|
||||
backend_torch_accelerator_module,
|
||||
execute_subprocess_async,
|
||||
get_torch_dist_unique_port,
|
||||
require_torch_multi_accelerator,
|
||||
torch_device,
|
||||
torchrun,
|
||||
)
|
||||
from transformers.utils import is_ccl_available, is_ipex_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import functools
|
||||
|
||||
import torch
|
||||
|
||||
if is_torch_xpu_available():
|
||||
if is_ipex_available():
|
||||
import intel_extension_for_pytorch # noqa: F401
|
||||
if is_ccl_available():
|
||||
import oneccl_bindings_for_pytorch # noqa: F401
|
||||
import torch.distributed
|
||||
from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel
|
||||
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
|
||||
|
||||
data = 4 * [
|
||||
"Hello world!",
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
]
|
||||
|
||||
def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]:
    """Manage the creation and destruction of the distributed process group for the wrapped function.

    The group is initialized with one rank per available accelerator, and destroyed
    in a ``finally`` block so it is torn down even when the wrapped function raises.
    """

    # functools.wraps preserves the wrapped function's __name__/__doc__ so logs and
    # debuggers show the real test entry point instead of "wrapped".
    @functools.wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        device_count = backend_device_count(torch_device)
        torch.distributed.init_process_group(world_size=device_count)
        try:
            return func(*args, **kwargs)
        finally:
            torch.distributed.destroy_process_group()

    return wrapped
|
||||
|
||||
@manage_process_group
def fsdp_generate():
    """Run ``generate`` on a tiny GPT-2 wrapped in FSDP1 (``FullyShardedDataParallel``), one rank per device."""
    torch_accelerator_module = backend_torch_accelerator_module(torch_device)
    torch_accelerator_module.set_device(device := torch.device(rank := torch.distributed.get_rank()))

    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)

    # Auto-wrap at the transformer-block level so each GPT2Block becomes its own FSDP unit.
    fsdp_model = FullyShardedDataParallel(
        model,
        auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={GPT2Block}),
        limit_all_gathers=True,
        use_orig_params=True,
    )

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    # Each rank generates on its own slice of the shared `data` list.
    batch = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device)

    # summon_full_params gathers the unsharded weights so generate() can run on the inner module.
    with FullyShardedDataParallel.summon_full_params(fsdp_model):
        _ = fsdp_model.module.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_length=30,
        )
|
||||
|
||||
@manage_process_group
def fsdp2_generate():
    """Run ``generate`` on a tiny GPT-2 sharded with FSDP2 (``fully_shard``), one rank per device."""
    torch_accelerator_module = backend_torch_accelerator_module(torch_device)
    torch_accelerator_module.set_device(device := torch.device(rank := torch.distributed.get_rank()))

    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)

    # Shard each transformer block individually, then the root model, over a 1-D mesh.
    mesh = init_device_mesh(device.type, (torch.distributed.get_world_size(),))
    for submodule in model.modules():
        if isinstance(submodule, GPT2Block):
            fully_shard(submodule, mesh=mesh)
    fully_shard(model, mesh=mesh)

    # FSDP2 only hooks forward() by default; register generate() so it also triggers unsharding.
    register_fsdp_forward_method(model, "generate")

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    # Each rank generates on its own slice of the shared `data` list.
    batch = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device)

    _ = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_length=30,
    )
|
||||
|
||||
|
||||
class TestFSDPGeneration(TestCasePlus):
    """Launches this very file under torchrun to exercise FSDP1/FSDP2 generation across accelerators."""

    def _launch_fsdp_script(self, flag: str) -> None:
        """Run test_fsdp.py under torchrun with one process per accelerator and the given mode flag.

        A successful return == success — any error in the subprocess raises here.
        """
        device_count = backend_device_count(torch_device)
        distributed_args = f"""--nproc_per_node={device_count}
            --master_port={get_torch_dist_unique_port()}
            {self.test_file_dir}/test_fsdp.py
        """.split()
        cmd = ["torchrun"] + distributed_args + [flag]
        execute_subprocess_async(cmd, env=self.get_env())

    @require_torch_multi_accelerator
    def test_fsdp_generate(self):
        """FSDP1 (FullyShardedDataParallel wrapper) generation works under torchrun."""
        self._launch_fsdp_script("--fsdp")

    @require_torch_multi_accelerator
    def test_fsdp2_generate(self):
        """FSDP2 (fully_shard) generation works under torchrun."""
        self._launch_fsdp_script("--fsdp2")
|
||||
|
||||
|
||||
class TestFSDPGenericTaskModel(TestCasePlus):
    """Checks that a generic task model can be sharded with FSDP2's ``fully_shard``."""

    # Number of processes torchrun spawns for the inline script below.
    nproc_per_node = 2

    def test_generic_task_model_can_be_sharded(self):
        """Run an inline script under torchrun that fully_shards a token-classification model."""
        script_to_run = textwrap.dedent(
            """
            import torch
            from torch.distributed.fsdp import fully_shard
            from transformers import AutoModelForTokenClassification

            torch.distributed.init_process_group(
                backend="nccl" if torch.cuda.is_available() else "gloo", init_method="env://"
            )
            rank = torch.distributed.get_rank()
            if torch.cuda.is_available():
                torch.cuda.set_device(rank)

            # Make sure it works
            model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B")
            module = fully_shard(model)

            torch.distributed.destroy_process_group()
            """
        )
        torchrun(script_to_run, self.nproc_per_node, env=self.get_env())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
    #
    # PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 --output_dir output_dir ./tests/generation/test_fsdp.py --fsdp

    class CLIArgs(argparse.Namespace):
        # Typed view over the parsed flags; --fsdp and --fsdp2 are mutually exclusive.
        fsdp: bool
        fsdp2: bool

    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--fsdp", action="store_true")
    group.add_argument("--fsdp2", action="store_true")
    args = parser.parse_args(namespace=CLIArgs())

    # Dispatch to the selected FSDP flavour; exactly one flag must be given.
    if args.fsdp:
        fsdp_generate()
    elif args.fsdp2:
        fsdp2_generate()
    else:
        raise ValueError("Missing test selection")
|
||||
1362
transformers/tests/generation/test_logits_process.py
Normal file
1362
transformers/tests/generation/test_logits_process.py
Normal file
File diff suppressed because it is too large
Load Diff
149
transformers/tests/generation/test_paged_attention.py
Normal file
149
transformers/tests/generation/test_paged_attention.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
from transformers.testing_utils import require_flash_attn, require_torch_gpu, slow
|
||||
|
||||
|
||||
_TEST_PROMPTS = [
|
||||
"A man is a walking his dog down the street, and a the turn he sees",
|
||||
"Describe a fruit that is of orange color and round. It is a sweet fruit and a great source of Vitamine C. The fruit I'm thinking of is an",
|
||||
"A plane is flying high in the sky, out of the window are clouds and mountains. Where could the plane be located?",
|
||||
"Please fill in the form to",
|
||||
"For safety reasons, the train is stopped in the middle of the",
|
||||
]
|
||||
|
||||
_EXPECTED_OUTPUTS = [
|
||||
"a woman standing on the sidewalk, looking at him. He is immediately drawn to her and feels a strong attraction. He walks up to her and strikes up a conversation, and they quickly discover that they have a lot in common. They exchange numbers and",
|
||||
"orange.\n\n## Step 1: Identify the key characteristics of the fruit\nThe fruit is described as being orange in color and round in shape.\n\n## Step 2: Determine the taste and nutritional value of the fruit\nThe fruit is described as sweet",
|
||||
"This riddle is a classic example of a lateral thinking puzzle, which requires the test-taker to think creatively and consider multiple possibilities. The answer is not a straightforward one, and it requires some lateral thinking to arrive at the correct solution.",
|
||||
"get in touch with us. We will respond to your message as soon as possible.\n\n[Your Name]\n[Your Email]\n[Your Phone Number]\n[Your Message]\n\nWe are looking forward to hearing from you!\n\n[Insert Contact Information]\n\nNote:",
|
||||
"track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. How many kilometers does the train travel in 30 minutes?\n## Step 1: Convert the speed from km/h to km/min",
|
||||
]
|
||||
|
||||
|
||||
@slow
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
class TestBatchGeneration(unittest.TestCase):
|
||||
    @classmethod
    def setUpClass(cls):
        """Load the shared Llama model and left-padding tokenizer once for all batch-generation tests."""
        cls.model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-3b-Instruct", dtype="bfloat16", device_map="auto"
        ).eval()

        cls.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3b-Instruct", padding_side="left")

        # Fall back to EOS as the padding token when the checkpoint defines none.
        if cls.tokenizer.pad_token is None:
            cls.tokenizer.pad_token = cls.tokenizer.eos_token
            cls.model.config.pad_token_id = cls.model.config.eos_token_id

        # Paged/batch generation manages its own KV blocks, so the regular cache is disabled.
        cls.model.use_cache = False
|
||||
|
||||
    @parameterized.expand(
        [
            ("eager_paged", 64, 128, 64),
            ("sdpa_paged", 32, 256, 128),
            ("paged_attention", 16, 512, 256),
            ("flex_paged", 64, 128, 64),
        ]
    )
    def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens):
        """Greedy generate_batch must reproduce the recorded outputs for each paged attention backend."""
        self.model.config.attn_implementation = attn_impl

        # top_k=0 with default do_sample=False -> deterministic greedy decoding.
        generation_config = GenerationConfig(
            max_new_tokens=50,
            top_k=0,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=False,
            num_blocks=num_blocks,
            block_size=block_size,
            max_batch_tokens=max_batch_tokens,
        )

        tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)
        batch_inputs = list(tokenized["input_ids"])

        start = time.time()
        batch_outputs = self.model.generate_batch(
            inputs=batch_inputs,
            generation_config=generation_config,
        )
        end = time.time()
        print(
            f"\n[{attn_impl}] Batch took {end - start:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
        )

        # NOTE(review): assumes generate_batch returns requests in submission order so index i
        # lines up with _EXPECTED_OUTPUTS[i] — confirm against generate_batch's contract.
        for i, req_id in enumerate(batch_outputs):
            generated = self.tokenizer.decode(
                batch_outputs[req_id].generated_tokens, skip_special_tokens=False
            ).strip()
            expected = _EXPECTED_OUTPUTS[i].strip()
            # startswith (not equality) tolerates early EOS / trailing continuation differences.
            self.assertTrue(
                generated.startswith(expected),
                msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected start: {expected}\nGot: {generated}",
            )
|
||||
|
||||
    @parameterized.expand(
        [
            ("eager_paged", 64, 128, 64),
            ("sdpa_paged", 32, 256, 128),
            ("paged_attention", 16, 512, 256),
            ("flex_paged", 64, 128, 64),
        ]
    )
    def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, max_batch_tokens):
        """Test batch generation with do_sampling=True to verify sampling works correctly."""
        self.model.config.attn_implementation = attn_impl

        generation_config = GenerationConfig(
            max_new_tokens=30,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            temperature=0.8,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=False,
            num_blocks=num_blocks,
            block_size=block_size,
            max_batch_tokens=max_batch_tokens,
        )

        tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)  # all shared test prompts
        batch_inputs = list(tokenized["input_ids"])

        start = time.time()
        batch_outputs = self.model.generate_batch(
            inputs=batch_inputs,
            generation_config=generation_config,
        )
        end = time.time()
        print(
            f"\n[{attn_impl}] Sampling batch took {end - start:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
        )

        # With sampling enabled, we can't check exact outputs, but we should verify:
        # 1. All requests completed successfully
        # 2. Generated text is non-empty
        # 3. Generated text is different from greedy (demonstrating sampling is working)
        self.assertEqual(len(batch_outputs), len(batch_inputs), f"[{attn_impl}] Not all requests completed")

        for i, req_id in enumerate(batch_outputs):
            generated = self.tokenizer.decode(
                batch_outputs[req_id].generated_tokens, skip_special_tokens=False
            ).strip()
            self.assertTrue(
                len(generated) > 0,
                msg=f"[{attn_impl}] Empty output for request {i}",
            )
            # Check that we got at least some tokens generated
            generated_tokens = batch_outputs[req_id].generated_tokens
            self.assertGreater(
                len(generated_tokens),
                0,
                msg=f"[{attn_impl}] No tokens generated for request {i}",
            )
|
||||
289
transformers/tests/generation/test_stopping_criteria.py
Normal file
289
transformers/tests/generation/test_stopping_criteria.py
Normal file
@@ -0,0 +1,289 @@
|
||||
# Copyright 2020 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from transformers import AutoTokenizer, is_torch_available
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
|
||||
from ..test_modeling_common import ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.generation import (
|
||||
ConfidenceCriteria,
|
||||
EosTokenCriteria,
|
||||
MaxLengthCriteria,
|
||||
MaxTimeCriteria,
|
||||
StoppingCriteriaList,
|
||||
StopStringCriteria,
|
||||
validate_stopping_criteria,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
    def _get_tensors(self, length):
        """Return random (input_ids, scores) of shape (3, length); scores are uniform 1/length."""
        batch_size = 3
        vocab_size = 250

        input_ids = ids_tensor((batch_size, length), vocab_size)
        scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
        return input_ids, scores
|
||||
|
||||
    def test_list_criteria(self):
        """A StoppingCriteriaList fires once any member is met — here the max-length of 10."""
        input_ids, scores = self._get_tensors(5)

        criteria = StoppingCriteriaList(
            [
                MaxLengthCriteria(max_length=10),
                MaxTimeCriteria(max_time=0.1),
            ]
        )

        self.assertFalse(all(criteria(input_ids, scores)))

        input_ids, scores = self._get_tensors(9)
        self.assertFalse(all(criteria(input_ids, scores)))

        # Exactly at max_length -> all sequences are flagged as done.
        input_ids, scores = self._get_tensors(10)
        self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
    def test_max_length_criteria(self):
        """MaxLengthCriteria fires exactly when the sequence length reaches max_length."""
        criteria = MaxLengthCriteria(max_length=10)

        input_ids, scores = self._get_tensors(5)
        self.assertFalse(all(criteria(input_ids, scores)))

        input_ids, scores = self._get_tensors(9)
        self.assertFalse(all(criteria(input_ids, scores)))

        input_ids, scores = self._get_tensors(10)
        self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
    def test_max_time_criteria(self):
        """MaxTimeCriteria fires once the elapsed wall time exceeds max_time."""
        input_ids, scores = self._get_tensors(5)

        criteria = MaxTimeCriteria(max_time=0.1)
        self.assertFalse(all(criteria(input_ids, scores)))

        # Backdating the start timestamp makes the budget already exhausted.
        criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
        self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
    def test_eos_token_criteria(self):
        """EosTokenCriteria flags each row whose last token equals the EOS id (0 here)."""
        criteria = EosTokenCriteria(eos_token_id=0)

        input_ids, scores = self._get_tensors(5)
        input_ids[:, -1] = 0
        self.assertTrue(all(criteria(input_ids, scores)))

        # Mixed batch: only the rows ending in EOS are marked done.
        input_ids, scores = self._get_tensors(5)
        input_ids[:2, -1] = 0
        input_ids[2, -1] = 1
        self.assertListEqual(criteria(input_ids, scores).tolist(), [True, True, False])

        input_ids, scores = self._get_tensors(5)
        input_ids[:, -1] = 1
        self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])
|
||||
|
||||
    def test_confidence_criteria(self):
        """ConfidenceCriteria stops when the last token's probability falls below the threshold."""
        criteria = ConfidenceCriteria(assistant_confidence_threshold=0.5)

        vocab_size = 250
        length = 5

        input_ids = ids_tensor((1, length), vocab_size)
        scores = (torch.randn((1, vocab_size)),)

        # Simulate high confidence by setting the probability of the last token to be high
        scores[0][0, input_ids[0, -1]] = 10.0  # Logits before softmax
        self.assertFalse(criteria(input_ids, scores))

        # Simulate low confidence by setting the probability of the last token to be low
        scores[0][0, input_ids[0, -1]] = -10.0  # Logits before softmax
        self.assertTrue(criteria(input_ids, scores))
|
||||
|
||||
    def test_validate_stopping_criteria(self):
        """validate_stopping_criteria warns on a max-length mismatch and injects a criterion when missing."""
        validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)

        # Mismatched lengths (10 vs 11) should produce a UserWarning.
        with self.assertWarns(UserWarning):
            validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)

        # An empty list gets a MaxLengthCriteria added for it.
        stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)

        self.assertEqual(len(stopping_criteria), 1)
|
||||
|
||||
    def test_stop_string_criteria(self):
        """StopStringCriteria matches exact stop strings (and only those) across two tokenizations.

        The same fixture strings are tokenized with GPT-2 (stop strings split into ordinary
        subwords) and with a tokenizer where they are genuine special tokens; the criterion
        must behave identically in both cases.
        """
        true_strings = [
            "<|im_start|><|im_end|>",
            "<|im_start|><|im_end|<|im_end|>",
            ">><|im_start|>>stop",
            "stop",
            "e nd",
        ]
        false_strings = [
            "<|im_start|><|im_end|",
            "<|im_start|><|im_end|<|im_end|",
            "<|im_end|><|im_start|>",
            "<|im_end|<>stop<|im_end|",
            "end",
            "en d",
            "eNd",
            "<|im_end|",
            "|im_end|>",
            "s",
        ]
        stop_strings = ["<|im_end|>", "stop", "e nd"]

        # Use a tokenizer that won't actually have special tokens for these
        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.padding_side = "left"
        true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
        false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)

        # StopStringCriteria ignores scores; None is sufficient here.
        scores = None
        criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
        for i in range(len(true_strings)):
            self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
        for i in range(len(false_strings)):
            self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))

        # Now try it with a tokenizer where those are actually special tokens
        tokenizer = AutoTokenizer.from_pretrained("cognitivecomputations/dolphin-2.5-mixtral-8x7b")
        tokenizer.padding_side = "left"
        true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
        false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)

        criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
        for i in range(len(true_strings)):
            self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
        for i in range(len(false_strings)):
            self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
|
||||
|
||||
def test_stop_string_criteria_vocab_size_mismatch(self):
|
||||
"""Test that StopStringCriteria handles tokens above len(tokenizer) correctly."""
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
|
||||
# Create input_ids with tokens above len(tokenizer)
|
||||
input_ids = torch.tensor([[len(tokenizer) + 1024, 1, 2]], device=torch_device)
|
||||
scores = None
|
||||
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["test"])
|
||||
|
||||
# This should not raise an error and should return False since no stop string is matched
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
def test_stop_string_matching_positions(self):
|
||||
stop_string = "stop"
|
||||
token_list = ["last", "top", "topper", "s", "p"]
|
||||
token_indices = list(range(len(token_list)))
|
||||
all_token_valid_positions, all_token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions(
|
||||
token_list=token_list, token_indices=token_indices, stop_strings=[stop_string]
|
||||
)
|
||||
valid_positions = {
|
||||
token_list[idx]: positions for idx, positions in all_token_valid_positions[stop_string].items()
|
||||
}
|
||||
end_overlaps = {token_list[idx]: overlaps for idx, overlaps in all_token_end_overlaps[stop_string].items()}
|
||||
self.assertEqual(valid_positions, {"s": [3], "last": [2]})
|
||||
self.assertEqual(end_overlaps, {"top": [3], "topper": [3], "p": [1]})
|
||||
|
||||
def test_stop_string_embedding_vecs(self):
|
||||
stop_string = "stop"
|
||||
token_list = ["last", "top", "topper", "s", "p"]
|
||||
token_indices = list(range(len(token_list)))
|
||||
embedding_vec, max_valid_positions, max_valid_end_lens = StopStringCriteria._stop_string_create_embedding_vec(
|
||||
token_list=token_list, token_indices=token_indices, stop_strings=[stop_string]
|
||||
)
|
||||
|
||||
# Positions inside the stop string where the token matches (excluding end overlaps)
|
||||
valid_positions = embedding_vec[:, 0].tolist()
|
||||
self.assertEqual(valid_positions, [2, -1, -1, 3, -1, -1])
|
||||
|
||||
# Overlap lengths between end of stop string and start of token
|
||||
end_overlaps = embedding_vec[:, 1].tolist()
|
||||
self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1, -1])
|
||||
|
||||
# Length of each token
|
||||
token_lengths = embedding_vec[:-1, 2].tolist()
|
||||
self.assertEqual(token_lengths, [len(token) for token in token_list])
|
||||
|
||||
def test_single_letter_stop_string(self):
|
||||
true_strings = ["a", "baa", "abc"] # "abc" is a single token
|
||||
false_strings = ["abbbbbbb", "b"] # "abbbbbbb" is split into multiple tokens
|
||||
stop_strings = ["a"]
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||
false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||
|
||||
scores = None
|
||||
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
|
||||
for input_ids in true_input_ids["input_ids"]:
|
||||
self.assertTrue(criteria(input_ids.unsqueeze(0), scores))
|
||||
for input_ids in false_input_ids["input_ids"]:
|
||||
self.assertFalse(criteria(input_ids.unsqueeze(0), scores))
|
||||
|
||||
def test_criteria_per_row(self):
|
||||
text = "They completed the challenging puzzle, revealing the hidden image at the end"
|
||||
stop_strings = ["end"]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
|
||||
|
||||
scores = None
|
||||
criteria = StoppingCriteriaList(
|
||||
[
|
||||
MaxLengthCriteria(max_length=20),
|
||||
StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings),
|
||||
]
|
||||
)
|
||||
|
||||
# trigger stopping when at least one criteria is satisfied, one value per batch
|
||||
self.assertTrue(criteria(inputs["input_ids"], scores))
|
||||
|
||||
# return False when neither is satisfied
|
||||
self.assertFalse(criteria(inputs["input_ids"][:, :-1], scores))
|
||||
|
||||
def test_criteria_per_row_batched(self):
|
||||
text = [
|
||||
"They completed the challenging puzzle, revealing the hidden image at the end",
|
||||
"Today a dragon flew over France",
|
||||
"The aroma of freshly baked pizza filled the kitchen",
|
||||
]
|
||||
stop_strings = ["end"]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
tokenizer.padding_side = "left"
|
||||
inputs = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||
|
||||
scores = None
|
||||
criteria = StoppingCriteriaList(
|
||||
[
|
||||
MaxLengthCriteria(max_length=20),
|
||||
StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings),
|
||||
]
|
||||
)
|
||||
|
||||
# trigger stopping when at least one criteria is satisfied
|
||||
self.assertListEqual(criteria(inputs["input_ids"], scores).tolist(), [True, False, False])
|
||||
|
||||
# False when neither is satisfied
|
||||
self.assertListEqual(criteria(inputs["input_ids"][:, :-1], scores).tolist(), [False, False, False])
|
||||
174
transformers/tests/generation/test_streamers.py
Normal file
174
transformers/tests/generation/test_streamers.py
Normal file
@@ -0,0 +1,174 @@
|
||||
# Copyright 2023 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from queue import Empty
|
||||
from threading import Thread
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import (
|
||||
AsyncTextIteratorStreamer,
|
||||
AutoTokenizer,
|
||||
TextIteratorStreamer,
|
||||
TextStreamer,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import CaptureStdout, require_torch, torch_device
|
||||
from transformers.utils.logging import _get_library_root_logger
|
||||
|
||||
from ..test_modeling_common import ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
@require_torch
class StreamerTester(unittest.TestCase):
    """Tests for the synchronous streamers (TextStreamer, TextIteratorStreamer)."""

    def test_text_streamer_matches_non_streaming(self):
        """Text printed by TextStreamer must equal a plain (non-streamed) greedy decode."""
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        model.config.eos_token_id = -1  # disable EOS so generation always runs all 10 new tokens

        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
        greedy_text = tokenizer.decode(greedy_ids[0])

        with CaptureStdout() as cs:
            streamer = TextStreamer(tokenizer)
            model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)
        # The greedy text should be printed to stdout, except for the final "\n" in the streamer
        streamer_text = cs.out[:-1]

        self.assertEqual(streamer_text, greedy_text)

    def test_iterator_streamer_matches_non_streaming(self):
        """Text yielded by TextIteratorStreamer must equal a plain greedy decode."""
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        model.config.eos_token_id = -1  # disable EOS so generation always runs all 10 new tokens

        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
        greedy_text = tokenizer.decode(greedy_ids[0])

        streamer = TextIteratorStreamer(tokenizer)
        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
        # Generation runs in a background thread; the main thread consumes the streamer.
        # NOTE: the thread must be started before iterating, or iteration would block forever.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        streamer_text = ""
        for new_text in streamer:
            streamer_text += new_text

        self.assertEqual(streamer_text, greedy_text)

    def test_text_streamer_skip_prompt(self):
        """With skip_prompt=True, only the newly generated tokens are printed."""
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        model.config.eos_token_id = -1  # disable EOS so generation always runs all 10 new tokens

        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
        # Slice off the prompt to get the expected (generation-only) reference text.
        new_greedy_ids = greedy_ids[:, input_ids.shape[1] :]
        new_greedy_text = tokenizer.decode(new_greedy_ids[0])

        with CaptureStdout() as cs:
            streamer = TextStreamer(tokenizer, skip_prompt=True)
            model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)
        # The greedy text should be printed to stdout, except for the final "\n" in the streamer
        streamer_text = cs.out[:-1]

        self.assertEqual(streamer_text, new_greedy_text)

    def test_text_streamer_decode_kwargs(self):
        # Tests that we can pass `decode_kwargs` to the streamer to control how the tokens are decoded. Must be tested
        # with actual models -- the dummy models' tokenizers are not aligned with their models, and
        # `skip_special_tokens=True` has no effect on them
        tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
        model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2").to(torch_device)
        model.config.eos_token_id = -1  # disable EOS so generation is not cut short

        # Prompt made entirely of BOS (special) tokens, which skip_special_tokens should drop.
        input_ids = torch.ones((1, 5), device=torch_device).long() * model.config.bos_token_id

        # Silence library log propagation so captured stdout contains only streamer output.
        root = _get_library_root_logger()
        with patch.object(root, "propagate", False):
            with CaptureStdout() as cs:
                streamer = TextStreamer(tokenizer, skip_special_tokens=True)
                model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)

        # The prompt contains a special token, so the streamer should not print it. As such, the output text, when
        # re-tokenized, must only contain one token
        streamer_text = cs.out[:-1]  # Remove the final "\n"
        streamer_text_tokenized = tokenizer(streamer_text, return_tensors="pt")
        self.assertEqual(streamer_text_tokenized.input_ids.shape, (1, 1))

    def test_iterator_streamer_timeout(self):
        """A tiny timeout must surface as queue.Empty when the consumer outpaces generation."""
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        model.config.eos_token_id = -1  # disable EOS so generation always runs all 10 new tokens

        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
        streamer = TextIteratorStreamer(tokenizer, timeout=0.001)
        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # The streamer will timeout after 0.001 seconds, so an exception will be raised
        with self.assertRaises(Empty):
            streamer_text = ""
            for new_text in streamer:
                streamer_text += new_text
|
||||
|
||||
|
||||
@require_torch
@pytest.mark.asyncio(loop_scope="class")
class AsyncStreamerTester(unittest.IsolatedAsyncioTestCase):
    """Tests for AsyncTextIteratorStreamer (async-for consumption of streamed text)."""

    async def test_async_iterator_streamer_matches_non_streaming(self):
        """Text yielded by AsyncTextIteratorStreamer must equal a plain greedy decode."""
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        model.config.eos_token_id = -1  # disable EOS so generation always runs all 10 new tokens

        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
        greedy_text = tokenizer.decode(greedy_ids[0])

        streamer = AsyncTextIteratorStreamer(tokenizer)
        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
        # Generation runs in a plain thread while the async consumer drains the streamer.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        streamer_text = ""
        async for new_text in streamer:
            streamer_text += new_text

        self.assertEqual(streamer_text, greedy_text)

    async def test_async_iterator_streamer_timeout(self):
        """A tiny timeout must surface as TimeoutError (async flavor of queue.Empty)."""
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        model.config.eos_token_id = -1  # disable EOS so generation always runs all 10 new tokens

        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
        streamer = AsyncTextIteratorStreamer(tokenizer, timeout=0.001)
        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # The streamer will timeout after 0.001 seconds, so TimeoutError will be raised
        with self.assertRaises(TimeoutError):
            streamer_text = ""
            async for new_text in streamer:
                streamer_text += new_text
|
||||
5049
transformers/tests/generation/test_utils.py
Normal file
5049
transformers/tests/generation/test_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user