Files
enginex-mlu370-any2any/transformers/tests/generation/test_continuous_batching.py
2025-10-09 16:47:16 +08:00

291 lines
13 KiB
Python

# Copyright 2025 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import Optional
import torch
from parameterized import parameterized
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation.continuous_batching.cache import group_layers_by_attn_type
from transformers.testing_utils import Expectations, require_kernels, require_torch_gpu, slow
# Debug flag: when True, a CB vs non-CB mismatch falls back to the per-hardware expected outputs
# instead of failing outright, which lets you measure the deviation between CB and non-CB generation
ALLOW_EXPECTED_OUTPUTS = True  # this is a debug flag when you want to measure deviation between CB and non-CB gen
class ContinuousBatchingTest(unittest.TestCase):
    """Tests for continuous batching (CB) generation.

    Covers two areas:
    - `group_layers_by_attn_type`: grouping of layers by attention type for the paged cache.
    - Parity between `generate_batch` (continuous batching) and regular `generate`, with
      per-hardware expected outputs as a fallback for known numerical deviations.
    """

    # Maps each paged attention implementation to the implementation used for the reference
    # (non-continuous-batching) generation run
    _NON_CB_IMPLEMENTATIONS = {
        "sdpa_paged": "sdpa",
        "eager_paged": "eager",
        "paged_attention|kernels-community/flash-attn": "eager",
    }

    @parameterized.expand(
        [
            # (layer_types_str, sliding_window, expected_groups)
            # layer_types_str: per-layer attention spec, "f" = full_attention, "s" = sliding_window;
            #                  None leaves config.layer_types unset.
            # expected_groups: the i-th digit is the group index expected for layer i.
            (None, None, "0"),
            (None, 4096, "0"),
            ("f", None, "0"),
            ("ffff", None, "0000"),
            ("sssss", 4096, "00000"),
            ("fs", 4096, "01"),
            ("ssfssf", 4096, "001221"),
            ("ssssf", 4096, "01234"),
            ("fffsffs", 4096, "0123456"),
        ]
    )
    def test_group_layers(
        self,
        layer_types_str: Optional[str],
        sliding_window: Optional[int],
        expected_groups: str,
    ) -> None:
        """Check that `group_layers_by_attn_type` forms the expected layer groups and group types."""
        # Take a config and change the layer_types attribute to the mix we want
        config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B")
        if layer_types_str is not None:
            layer_types = [{"f": "full_attention", "s": "sliding_window"}[char] for char in layer_types_str]
        else:
            layer_types = None
        config.num_hidden_layers = len(expected_groups)
        config.layer_types = layer_types
        config.sliding_window = sliding_window

        # Build the expected layer groups: group index -> list of layer indices, in group order
        expected_lg: dict[int, list[int]] = {}
        for layer_index, group_digit in enumerate(expected_groups):
            expected_lg.setdefault(int(group_digit), []).append(layer_index)
        expected_layer_groups = [expected_lg[i] for i in sorted(expected_lg.keys())]

        # Test layer groups formation
        layer_groups, group_types = group_layers_by_attn_type(config)
        self.assertEqual(
            sorted(expected_layer_groups),
            sorted(layer_groups),
            f"Test failed for: {layer_types_str = }, {sliding_window = }, {expected_layer_groups = }, {layer_groups = }",
        )

        # If layer_types is provided, check that group_types matches the type of all the layers in each group
        if layer_types is not None:
            for layer_group, group_type in zip(layer_groups, group_types):
                # Distinct name on purpose: reassigning `layer_types` here would shadow the
                # variable tested in the enclosing `if`
                types_in_group = [config.layer_types[i] for i in layer_group]
                self.assertEqual(types_in_group, [group_type] * len(types_in_group))
        # If layer_types is None, all groups should be of the same type
        else:
            # Loop-invariant: the expected type depends only on the config, so compute it once
            expected_group_type = (
                "sliding_attention" if getattr(config, "sliding_window", None) is not None else "full_attention"
            )
            for group_type in group_types:
                self.assertEqual(
                    group_type,
                    expected_group_type,
                    f"Test failed for: {layer_types_str = }, {sliding_window = }, {group_types = }",
                )

    def _load_greedy_model(self, model_id: str, attn_implementation: str):
        """Load `model_id` on GPU in eval mode, configured for greedy decoding of 40 new tokens."""
        model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation, dtype="auto")
        model = model.cuda().eval()
        model.generation_config.max_new_tokens = 40
        model.generation_config.do_sample = False
        model.generation_config.use_cuda_graph = False
        return model

    def _continuous_batching_parity(
        self, model_id: str, attn_implementation: str, expected_outputs: dict[str, str]
    ) -> None:
        """Assert that continuous batching and regular generation decode the same text.

        If the outputs differ (which can legitimately happen due to numerical differences), the CB
        output is compared against the `expected_outputs` entry for that request when
        `ALLOW_EXPECTED_OUTPUTS` is set; otherwise the test fails.

        Args:
            model_id: Hub id of the model under test.
            attn_implementation: paged attention implementation used for the CB run.
            expected_outputs: request id -> expected CB output, consulted only on mismatch.

        Raises:
            ValueError: if `attn_implementation` is not a known paged implementation.
        """
        # Prepare common elements
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
        prompts = [
            "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her "
            "friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh "
            "duck egg. How much in dollars does she make every day at the farmers' market? The answer is:",
            "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take? "
            "The answer is:",
            "Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. "
            "This increased the value of the house by 150%. How much profit did he make? The answer is:",
        ]  # fmt: skip
        batched_inputs = [tokenizer.encode(prompt) for prompt in prompts]

        # Generation with continuous batching
        model = self._load_greedy_model(model_id, attn_implementation)
        cb_outputs = model.generate_batch(inputs=batched_inputs, generation_config=model.generation_config)

        # Generation without continuous batching
        if attn_implementation not in self._NON_CB_IMPLEMENTATIONS:
            raise ValueError(f"Invalid attention implementation: {attn_implementation}")
        non_cb_attn_implementation = self._NON_CB_IMPLEMENTATIONS[attn_implementation]
        # We regenerate the model because just changing the attn_implementation does not work
        model = self._load_greedy_model(model_id, non_cb_attn_implementation)

        for request_id, request in cb_outputs.items():
            # Generate without continuous batching
            input_ids = torch.tensor([request.prompt_ids]).cuda()
            attention_mask = torch.ones_like(input_ids)
            outputs = model.generate(
                input_ids, attention_mask=attention_mask, generation_config=model.generation_config
            )
            generated_tokens = outputs[0][input_ids.shape[1] :]
            non_cb_decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
            # Check that the generated output with and without CB match
            cb_decoded_output = tokenizer.decode(request.generated_tokens, skip_special_tokens=True)
            outputs_match = non_cb_decoded_output == cb_decoded_output
            # If they dont, that might be expected: the outputs can differ slightly due to numerical differences
            # If that's the case, there is an expected output ready
            if not outputs_match:
                expected_output = expected_outputs.get(request_id) if ALLOW_EXPECTED_OUTPUTS else None
                if expected_output is None:
                    self.fail(
                        f"Test {request_id = } failed, no expected output was provided.\nRef:"
                        f"{repr(non_cb_decoded_output)}\nOut:{repr(cb_decoded_output)}"
                    )
                else:
                    self.assertEqual(
                        expected_output,
                        cb_decoded_output,
                        msg=f"Test {request_id = } failed, expected output did not match.\n"
                        f"Exp:{repr(expected_output)}\nOut:{repr(cb_decoded_output)}",
                    )

    # Eager tests
    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_llama_eager(self) -> None:
        expected_outputs = Expectations({
            ("rocm", (9, 4)): {
                "req_0": " $16. How did I get that answer? I used the following equation: 16 - 3 - 4 = 9. 9 x $2 = $18. $18 -"
            },
            ("cuda", (9, 0)): {
                "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5. The total number of bolts is 4.5. The total",
                "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "eager_paged", expected_outputs)

    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_gemma_eager(self) -> None:
        expected_outputs = Expectations({
            ("rocm", (9, 4)): {
                "req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 ="
            },
            ("cuda", (9, 0)): {
                "req_0": "\n\n**$12**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 13",
                "req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n "
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity("google/gemma-2-2b-it", "eager_paged", expected_outputs)

    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_qwen_eager(self) -> None:
        expected_outputs = {}
        self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "eager_paged", expected_outputs)

    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_gpt_oss_eager(self) -> None:
        expected_outputs = Expectations({
            ("cuda", (9, 0)): {
                "req_1": " 2.5 bolts. The question: \"What is the name of the puzzle that involves a robe taking 2 bolts of blue fiber and half that much white fiber?\" The answer: \"The",
                "req_2": " 50%.\"\n\nWe need to parse: He buys a house for $80,000. He puts in $50,000 in repairs. This increased the value of the house by 150%."
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity("openai/gpt-oss-20b", "eager_paged", expected_outputs)

    # SDPA tests
    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_llama_sdpa(self) -> None:
        expected_outputs = Expectations({
            ("rocm", (9, 4)): {
                "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "sdpa_paged", expected_outputs)

    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_gemma_sdpa(self) -> None:
        expected_outputs = Expectations({
            ("cuda", (9, 0)): {
                "req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity("google/gemma-2-2b-it", "sdpa_paged", expected_outputs)

    @require_torch_gpu
    @slow
    def test_continuous_batching_parity_qwen_sdpa(self) -> None:
        expected_outputs = {}
        self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "sdpa_paged", expected_outputs)

    # GPT-OSS is not compatible with SDPA because it has an attention sink. TODO: is this fixable?

    # Flash attention test
    @require_torch_gpu
    @require_kernels
    @slow
    def test_continuous_batching_parity_llama_flash(self) -> None:
        expected_outputs = Expectations({
            ("cuda", (9, 0)): {
                "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5 bolts. The total number of bolts is 4.5 bolts.",
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity(
            "meta-llama/Llama-3.1-8B", "paged_attention|kernels-community/flash-attn", expected_outputs
        )

    @require_torch_gpu
    @require_kernels
    @slow
    def test_continuous_batching_parity_gemma_flash(self) -> None:
        expected_outputs = Expectations({
            ("cuda", (9, 0)): {
                "req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n ",
            }
        }).get_expectation()  # fmt: skip
        self._continuous_batching_parity(
            "google/gemma-2-2b-it", "paged_attention|kernels-community/flash-attn", expected_outputs
        )

    @require_torch_gpu
    @require_kernels
    @slow
    def test_continuous_batching_parity_qwen_flash(self) -> None:
        expected_outputs = {}
        self._continuous_batching_parity(
            "Qwen/Qwen3-4B-Instruct-2507", "paged_attention|kernels-community/flash-attn", expected_outputs
        )

    @require_torch_gpu
    @require_kernels
    @slow
    def test_continuous_batching_parity_gpt_oss_flash(self) -> None:
        expected_outputs = {}
        self._continuous_batching_parity(
            "openai/gpt-oss-20b", "paged_attention|kernels-community/flash-attn", expected_outputs
        )
# FIXME: the gemma tests seem broken: there is a message about cuda graphs, and the sdpa and flash expectations are
# inverted on CUDA. On AMD they do fine.