[Lint] Style: Convert test/ to ruff format (Batch #5) (#6747)

### What this PR does / why we need it?

This PR is batch #5 of the effort to convert `tests/` to the `ruff` formatter. It reformats the following files, with no functional changes:

| File Path |
| :--- |
| `tests/e2e/singlecard/compile/backend.py` |
| `tests/e2e/singlecard/compile/test_graphex_norm_quant_fusion.py` |
| `tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py` |
| `tests/e2e/singlecard/compile/test_norm_quant_fusion.py` |
| `tests/e2e/singlecard/model_runner_v2/test_basic.py` |
| `tests/e2e/singlecard/test_aclgraph_accuracy.py` |
| `tests/e2e/singlecard/test_aclgraph_batch_invariant.py` |
| `tests/e2e/singlecard/test_aclgraph_mem.py` |
| `tests/e2e/singlecard/test_async_scheduling.py` |
| `tests/e2e/singlecard/test_auto_fit_max_mode_len.py` |
| `tests/e2e/singlecard/test_batch_invariant.py` |
| `tests/e2e/singlecard/test_camem.py` |
| `tests/e2e/singlecard/test_completion_with_prompt_embeds.py` |
| `tests/e2e/singlecard/test_cpu_offloading.py` |
| `tests/e2e/singlecard/test_guided_decoding.py` |
| `tests/e2e/singlecard/test_ilama_lora.py` |
| `tests/e2e/singlecard/test_llama32_lora.py` |
| `tests/e2e/singlecard/test_models.py` |
| `tests/e2e/singlecard/test_multistream_overlap_shared_expert.py` |
| `tests/e2e/singlecard/test_quantization.py` |
| `tests/e2e/singlecard/test_qwen3_multi_loras.py` |
| `tests/e2e/singlecard/test_sampler.py` |
| `tests/e2e/singlecard/test_vlm.py` |
| `tests/e2e/singlecard/test_xlite.py` |
| `tests/e2e/singlecard/utils.py` |
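
For reference, a batch conversion like this is applied by running the `ruff` formatter over the target paths. The exact invocation is not recorded in the PR, so the command below is an assumed sketch using the standard ruff CLI; a real run would pick up line-length and import-sorting settings from the repository's `pyproject.toml`:

```bash
# Assumed reproduction of this batch (paths taken from the table above);
# ruff reads its formatting configuration from the repo's pyproject.toml.
ruff format tests/e2e/singlecard/
```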

### Does this PR introduce _any_ user-facing change?

No. This is a style-only change to test files; no functional behavior is modified.
### How was this patch tested?
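
Not stated in the PR. A hedged guess at how a style-only batch is typically verified, locally or in the lint CI job:

```bash
# Exits non-zero if any file still deviates from ruff's formatting...
ruff format --check tests/e2e/singlecard/
# ...and the linter confirms the reformat introduced no new violations.
ruff check tests/e2e/singlecard/
```

Since the change is formatting-only, the existing e2e suites continue to exercise the touched files unchanged.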

- vLLM version: v0.15.0
- vLLM main: 9562912cea

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
- Author: SILONG ZENG
- Committed: 2026-02-24 15:50:00 +08:00 (committed by GitHub)
- Parent: 747484cb64
- Commit: 62ea664aa7
- 26 changed files with 859 additions and 1052 deletions

The diff below shows one of the reformatted files, `tests/e2e/singlecard/test_aclgraph_batch_invariant.py`:

```diff
@@ -22,6 +22,7 @@ import random
 import pytest
 import torch
 from vllm import SamplingParams
+
 from tests.e2e.conftest import VllmRunner
 
 DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
@@ -69,9 +70,7 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
     if target_words > 50:
         # For longer prompts, repeat context
-        padding_text = (
-            " This is an interesting topic that deserves more explanation. " *
-            (target_words // 50))
+        padding_text = " This is an interesting topic that deserves more explanation. " * (target_words // 50)
         base_prompt = base_prompt + padding_text
 
     return base_prompt
@@ -107,8 +106,7 @@ def _extract_step_logprobs(generate_output):
 @pytest.mark.timeout(1000)
-def test_aclgraph_v1_generation_is_deterministic_across_batch_sizes_with_needle(
-        monkeypatch: pytest.MonkeyPatch):
+def test_aclgraph_v1_generation_is_deterministic_across_batch_sizes_with_needle(monkeypatch: pytest.MonkeyPatch):
     """
     Ensures that the same request (the 'needle' prompt) yields identical output
     whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
@@ -162,20 +160,16 @@ def test_aclgraph_v1_generation_is_deterministic_across_batch_sizes_with_needle(
     needle_prompt = "There once was a "
 
     with VllmRunner(
-            model_name=model,
-            max_num_seqs=max_batch_size,
-            gpu_memory_utilization=gpu_mem_util,
-            max_model_len=max_model_len,
-            dtype="bfloat16",
-            tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-            enable_prefix_caching=False,
-            distributed_executor_backend="mp",
-            compilation_config={
-                "cudagraph_mode": "FULL_DECODE_ONLY",
-                "cudagraph_capture_sizes": [1, 32, 64]
-            }
+        model_name=model,
+        max_num_seqs=max_batch_size,
+        gpu_memory_utilization=gpu_mem_util,
+        max_model_len=max_model_len,
+        dtype="bfloat16",
+        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
+        enable_prefix_caching=False,
+        distributed_executor_backend="mp",
+        compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1, 32, 64]},
     ) as vllm_model:
         # Baseline generation for the needle prompt alone.
         baseline_out = vllm_model.generate([needle_prompt], sampling)
         assert len(baseline_out) == 1
@@ -194,8 +188,7 @@ def test_aclgraph_v1_generation_is_deterministic_across_batch_sizes_with_needle(
                 if i == needle_pos:
                     prompts.append(needle_prompt)
                 else:
-                    prompts.append(
-                        _random_prompt(min_random_prompt, max_random_prompt))
+                    prompts.append(_random_prompt(min_random_prompt, max_random_prompt))
 
             # Generate with the larger-batch engine
             outputs = vllm_model.generate(prompts, sampling)
@@ -204,24 +197,23 @@ def test_aclgraph_v1_generation_is_deterministic_across_batch_sizes_with_needle(
             text = needle_output[0]
 
             if text != baseline_text:
-                print(
-                    f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
+                print(f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
                 mismatches += 1
 
         passes = num_trials - mismatches
 
         # Dump how many passed vs failed
-        print(f"[determinism] total={num_trials}, passed={passes}, "
-              f"failed={mismatches}, max_batch_size={max_batch_size}")
+        print(
+            f"[determinism] total={num_trials}, passed={passes}, failed={mismatches}, max_batch_size={max_batch_size}"
+        )
         if mismatches > 0:
             pytest.fail(
                 f"Nondeterministic outputs detected: {mismatches} failed out "
-                f"of {num_trials} trials (max_batch_size={max_batch_size}).")
+                f"of {num_trials} trials (max_batch_size={max_batch_size})."
+            )
 
 
-def test_aclgraph_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
-        monkeypatch: pytest.MonkeyPatch):
+def test_aclgraph_logprobs_bitwise_batch_invariance_bs1_vs_bsN(monkeypatch: pytest.MonkeyPatch):
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
     model_name = DEFAULT_MODEL
@@ -235,24 +227,19 @@ def test_aclgraph_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     if disable_custom_ar:
         print(f"\n{'=' * 80}")
-        print(
-            f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})"
-        )
+        print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
         print(f"{'=' * 80}\n")
 
     with VllmRunner(
-            model_name=model_name,
-            tensor_parallel_size=tp_size,
-            enable_prefix_caching=False,
-            max_num_seqs=32,
-            max_model_len=8192,
-            dtype="bfloat16",
-            gpu_memory_utilization=0.9,
-            distributed_executor_backend="mp",
-            compilation_config={
-                "cudagraph_mode": "FULL_DECODE_ONLY",
-                "cudagraph_capture_sizes": [1, 32, 64]
-            }
+        model_name=model_name,
+        tensor_parallel_size=tp_size,
+        enable_prefix_caching=False,
+        max_num_seqs=32,
+        max_model_len=8192,
+        dtype="bfloat16",
+        gpu_memory_utilization=0.9,
+        distributed_executor_backend="mp",
+        compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1, 32, 64]},
     ) as vllm_model:
         # Use more realistic prompts for better token generation
         prompts = [_random_prompt(10, 50) for i in range(32)]
@@ -273,16 +260,13 @@ def test_aclgraph_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
         bs1_logprobs_per_prompt = []
         bs1_tokens_per_prompt = []
         for idx, p in enumerate(prompts):
-            print(
-                f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
-            )
+            print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
             outs = vllm_model.generate_w_logprobs([p], sp, use_tqdm=False)
             assert len(outs) == 1
             # print(outs)
             step_logprobs, token_ids = _extract_step_logprobs(outs[0])
             if step_logprobs is None:
-                pytest.skip("Logits are not available on RequestOutput; "
-                            "enable logprobs return to run this test.")
+                pytest.skip("Logits are not available on RequestOutput; enable logprobs return to run this test.")
             bs1_logprobs_per_prompt.append(step_logprobs)
             bs1_tokens_per_prompt.append(token_ids)
             print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
@@ -304,108 +288,91 @@ def test_aclgraph_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
             print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
             step_logprobs, token_ids = _extract_step_logprobs(o)
             if step_logprobs is None:
-                pytest.skip("Logits are not available on RequestOutput; "
-                            "enable logprobs return to run this test.")
+                pytest.skip("Logits are not available on RequestOutput; enable logprobs return to run this test.")
             bsN_logprobs_per_prompt.append(step_logprobs)
             bsN_tokens_per_prompt.append(token_ids)
 
         # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
         failed_prompts = []
         for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
-                zip(
-                    bs1_logprobs_per_prompt,
-                    bsN_logprobs_per_prompt,
-                    bs1_tokens_per_prompt,
-                    bsN_tokens_per_prompt,
-                )):
+            zip(
+                bs1_logprobs_per_prompt,
+                bsN_logprobs_per_prompt,
+                bs1_tokens_per_prompt,
+                bsN_tokens_per_prompt,
+            )
+        ):
             if len(logprobs_bs1) != len(logprobs_bsN):
-                reason = (f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
-                          f"vs {len(logprobs_bsN)} (BS=N)")
-                failed_prompts.append({
-                    "prompt_idx": i,
-                    "step": "all",
-                    "reason": reason,
-                    "prompt_preview": prompts[i][:100],
-                    "bs1_tokens": tokens_bs1,
-                    "bsN_tokens": tokens_bsN,
-                })
+                reason = f"Different number of steps: {len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
+                failed_prompts.append(
+                    {
+                        "prompt_idx": i,
+                        "step": "all",
+                        "reason": reason,
+                        "prompt_preview": prompts[i][:100],
+                        "bs1_tokens": tokens_bs1,
+                        "bsN_tokens": tokens_bsN,
+                    }
+                )
                 continue
 
             # Check if tokens match first
             if tokens_bs1 != tokens_bsN:
-                failed_prompts.append({
-                    "prompt_idx":
-                    i,
-                    "step":
-                    "sampling",
-                    "reason":
-                    "Different tokens sampled",
-                    "prompt_preview":
-                    prompts[i][:100],
-                    "bs1_tokens":
-                    tokens_bs1,
-                    "bsN_tokens":
-                    tokens_bsN,
-                    "bs1_all_logprobs":
-                    [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
-                    "bsN_all_logprobs":
-                    [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
-                })
+                failed_prompts.append(
+                    {
+                        "prompt_idx": i,
+                        "step": "sampling",
+                        "reason": "Different tokens sampled",
+                        "prompt_preview": prompts[i][:100],
+                        "bs1_tokens": tokens_bs1,
+                        "bsN_tokens": tokens_bsN,
+                        "bs1_all_logprobs": [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
+                        "bsN_all_logprobs": [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
+                    }
+                )
                 continue
 
             for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
                 if a.shape != b.shape:
-                    failed_prompts.append({
-                        "prompt_idx": i,
-                        "step": t,
-                        "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
-                        "prompt_preview": prompts[i][:100],
-                        "bs1_tokens": tokens_bs1,
-                        "bsN_tokens": tokens_bsN,
-                    })
+                    failed_prompts.append(
+                        {
+                            "prompt_idx": i,
+                            "step": t,
+                            "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
+                            "prompt_preview": prompts[i][:100],
+                            "bs1_tokens": tokens_bs1,
+                            "bsN_tokens": tokens_bsN,
+                        }
+                    )
                     break
 
                 if not torch.equal(a, b):
                     max_diff = torch.abs(a - b).max().item()
                     # Print which token failed
-                    print(
-                        f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}"
-                    )
+                    print(f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}")
                     bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
                     bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
                     print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
                     print(f" BS=1 logprob: {a.tolist()}")
                     print(f" BS=N logprob: {b.tolist()}")
-                    failed_prompts.append({
-                        "prompt_idx":
-                        i,
-                        "step":
-                        t,
-                        "reason":
-                        f"Bitwise mismatch (max_diff={max_diff:.6e})",
-                        "prompt_preview":
-                        prompts[i][:100],
-                        "bs1_tokens":
-                        tokens_bs1,
-                        "bsN_tokens":
-                        tokens_bsN,
-                        "bs1_all_logprobs": [
-                            logprobs_bs1[s].tolist()
-                            for s in range(len(logprobs_bs1))
-                        ],
-                        "bsN_all_logprobs": [
-                            logprobs_bsN[s].tolist()
-                            for s in range(len(logprobs_bsN))
-                        ],
-                    })
+                    failed_prompts.append(
+                        {
+                            "prompt_idx": i,
+                            "step": t,
+                            "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
+                            "prompt_preview": prompts[i][:100],
+                            "bs1_tokens": tokens_bs1,
+                            "bsN_tokens": tokens_bsN,
+                            "bs1_all_logprobs": [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
+                            "bsN_all_logprobs": [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
+                        }
+                    )
                     break
 
         # Print summary of all failures
         if failed_prompts:
             print(f"\n{'=' * 80}")
-            fail_msg = (f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/"
-                        f"{len(prompts)} prompts failed")
+            fail_msg = f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/{len(prompts)} prompts failed"
            print(fail_msg)
             print(f"{'=' * 80}")
             for fail in failed_prompts:
@@ -420,21 +387,18 @@ def test_aclgraph_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
                 print(f" BS=N tokens: {fail['bsN_tokens']}")
                 if "bs1_all_logprobs" in fail:
-                    print(
-                        f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:"
-                    )
+                    print(f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:")
                     for step_idx, logprobs in enumerate(fail["bs1_all_logprobs"]):
                         print(f" Step {step_idx}: {logprobs}")
-                    print(
-                        f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:"
-                    )
+                    print(f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:")
                     for step_idx, logprobs in enumerate(fail["bsN_all_logprobs"]):
                         print(f" Step {step_idx}: {logprobs}")
                 print(f"{'=' * 80}\n")
 
             # Fail the test with summary
-            msg = (f"Batch invariance violated in {len(failed_prompts)}/"
-                   f"{len(prompts)} prompts. See output above for details.")
+            msg = (
+                f"Batch invariance violated in {len(failed_prompts)}/{len(prompts)} prompts. See output above for details."
+            )
            pytest.fail(msg)
@@ -446,18 +410,15 @@ def test_aclgraph_simple_generation(monkeypatch: pytest.MonkeyPatch):
     model = DEFAULT_MODEL
 
     with VllmRunner(
-            model_name=model,
-            max_num_seqs=1,
-            tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-            gpu_memory_utilization=0.9,
-            max_model_len=2048,
-            dtype="float16",
-            enable_prefix_caching=False,
-            compilation_config={
-                "cudagraph_mode": "FULL_DECODE_ONLY",
-                "cudagraph_capture_sizes": [1, 32, 64]
-            },
-            distributed_executor_backend="mp",
+        model_name=model,
+        max_num_seqs=1,
+        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
+        gpu_memory_utilization=0.9,
+        max_model_len=2048,
+        dtype="float16",
+        enable_prefix_caching=False,
+        compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1, 32, 64]},
+        distributed_executor_backend="mp",
     ) as vllm_model:
         prompt = "The capital of France is"
         sampling_params = SamplingParams(
@@ -479,11 +440,7 @@ def test_aclgraph_simple_generation(monkeypatch: pytest.MonkeyPatch):
     print(f"{'=' * 80}\n")
 
 
-def test_aclgraph_logprobs_without_batch_invariance_should_fail(
-        monkeypatch: pytest.MonkeyPatch):
+def test_aclgraph_logprobs_without_batch_invariance_should_fail(monkeypatch: pytest.MonkeyPatch):
     """
     This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
     It DISABLES batch invariance mode and expects to see non-deterministic behavior
@@ -505,19 +462,15 @@ def test_aclgraph_logprobs_without_batch_invariance_should_fail(
         print(f"{'=' * 80}\n")
 
     with VllmRunner(
-            model_name=model_name,
-            tensor_parallel_size=tp_size,
-            enable_prefix_caching=False,
-            max_num_seqs=32,
-            max_model_len=8192,
-            dtype="bfloat16",
-            compilation_config={
-                "cudagraph_mode": "FULL_DECODE_ONLY",
-                "cudagraph_capture_sizes": [1, 32, 64]
-            },
-            distributed_executor_backend="mp",
+        model_name=model_name,
+        tensor_parallel_size=tp_size,
+        enable_prefix_caching=False,
+        max_num_seqs=32,
+        max_model_len=8192,
+        dtype="bfloat16",
+        compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1, 32, 64]},
+        distributed_executor_backend="mp",
     ) as vllm_model:
         # build ragged prompts to change shapes significantly across BS=1 vs BS=N
         long_min = int(os.getenv("VLLM_MIN_PROMPT", "768"))
         long_max = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
@@ -549,16 +502,13 @@ def test_aclgraph_logprobs_without_batch_invariance_should_fail(
         bs1_logprobs_per_prompt = []
        bs1_tokens_per_prompt = []
         for idx, p in enumerate(prompts):
-            print(
-                f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
-            )
+            print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
             outs = vllm_model.generate_w_logprobs([p], sp, use_tqdm=False)
             assert len(outs) == 1
             step_logprobs, token_ids = _extract_step_logprobs(outs[0])
             if step_logprobs is None:
-                pytest.skip("Logits are not available on RequestOutput; "
-                            "enable logprobs return to run this test.")
+                pytest.skip("Logits are not available on RequestOutput; enable logprobs return to run this test.")
             bs1_logprobs_per_prompt.append(step_logprobs)
             bs1_tokens_per_prompt.append(token_ids)
             print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
@@ -579,84 +529,90 @@ def test_aclgraph_logprobs_without_batch_invariance_should_fail(
             print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
             step_logprobs, token_ids = _extract_step_logprobs(o)
             if step_logprobs is None:
-                pytest.skip("Logits are not available on RequestOutput; "
-                            "enable logprobs return to run this test.")
+                pytest.skip("Logits are not available on RequestOutput; enable logprobs return to run this test.")
             bsN_logprobs_per_prompt.append(step_logprobs)
             bsN_tokens_per_prompt.append(token_ids)
 
         # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
         differences_found = []
         for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
-                zip(
-                    bs1_logprobs_per_prompt,
-                    bsN_logprobs_per_prompt,
-                    bs1_tokens_per_prompt,
-                    bsN_tokens_per_prompt,
-                )):
+            zip(
+                bs1_logprobs_per_prompt,
+                bsN_logprobs_per_prompt,
+                bs1_tokens_per_prompt,
+                bsN_tokens_per_prompt,
+            )
+        ):
             if len(logprobs_bs1) != len(logprobs_bsN):
-                reason = (f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
-                          f"vs {len(logprobs_bsN)} (BS=N)")
-                differences_found.append({
-                    "prompt_idx": i,
-                    "step": "all",
-                    "reason": reason,
-                    "prompt_preview": prompts[i][:100],
-                    "bs1_tokens": tokens_bs1,
-                    "bsN_tokens": tokens_bsN,
-                })
+                reason = f"Different number of steps: {len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
+                differences_found.append(
+                    {
+                        "prompt_idx": i,
+                        "step": "all",
+                        "reason": reason,
+                        "prompt_preview": prompts[i][:100],
+                        "bs1_tokens": tokens_bs1,
+                        "bsN_tokens": tokens_bsN,
+                    }
+                )
                 continue
 
             # Check if tokens match first
             if tokens_bs1 != tokens_bsN:
-                differences_found.append({
-                    "prompt_idx": i,
-                    "step": "sampling",
-                    "reason": "Different tokens sampled",
-                    "prompt_preview": prompts[i][:100],
-                    "bs1_tokens": tokens_bs1,
-                    "bsN_tokens": tokens_bsN,
-                })
+                differences_found.append(
+                    {
+                        "prompt_idx": i,
+                        "step": "sampling",
+                        "reason": "Different tokens sampled",
+                        "prompt_preview": prompts[i][:100],
+                        "bs1_tokens": tokens_bs1,
+                        "bsN_tokens": tokens_bsN,
+                    }
+                )
                 continue
 
             for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
                 if a.shape != b.shape:
-                    differences_found.append({
-                        "prompt_idx": i,
-                        "step": t,
-                        "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
-                        "prompt_preview": prompts[i][:100],
-                        "bs1_tokens": tokens_bs1,
-                        "bsN_tokens": tokens_bsN,
-                    })
+                    differences_found.append(
+                        {
+                            "prompt_idx": i,
+                            "step": t,
+                            "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
+                            "prompt_preview": prompts[i][:100],
+                            "bs1_tokens": tokens_bs1,
+                            "bsN_tokens": tokens_bsN,
+                        }
+                    )
                     break
 
                 if not torch.equal(a, b):
                     max_diff = torch.abs(a - b).max().item()
-                    print(f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, "
-                          f"Token {t}: max_diff={max_diff:.6e}")
+                    print(f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, Token {t}: max_diff={max_diff:.6e}")
                     bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
                     bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
                     print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
                     print(f" BS=1 logprob: {a.tolist()}")
                     print(f" BS=N logprob: {b.tolist()}")
-                    differences_found.append({
-                        "prompt_idx": i,
-                        "step": t,
-                        "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
-                        "prompt_preview": prompts[i][:100],
-                        "bs1_tokens": tokens_bs1,
-                        "bsN_tokens": tokens_bsN,
-                    })
+                    differences_found.append(
+                        {
+                            "prompt_idx": i,
+                            "step": t,
+                            "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
+                            "prompt_preview": prompts[i][:100],
+                            "bs1_tokens": tokens_bs1,
+                            "bsN_tokens": tokens_bsN,
+                        }
+                    )
                    break
 
        # Print summary
        print(f"\n{'=' * 80}")
        if differences_found:
             success_msg = (
                 f"✓ SUCCESS: Batch invariance is doing something! "
                 f"Found {len(differences_found)}/{len(prompts)} prompts "
-                f"with differences when batch invariance was DISABLED.")
+                f"with differences when batch invariance was DISABLED."
+            )
             print(success_msg)
             print(f"{'=' * 80}")
             for diff in differences_found:
@@ -676,7 +632,8 @@ def test_aclgraph_logprobs_without_batch_invariance_should_fail(
             f"✗ UNEXPECTED: All {len(prompts)} prompts matched "
             f"between BS=1 and BS=N even with batch invariance DISABLED. "
             f"This suggests batch invariance might not be necessary, "
-            f"or the test needs more sensitive prompts.")
+            f"or the test needs more sensitive prompts."
+        )
         print(fail_msg)
         print(f"{'=' * 80}\n")
         pytest.fail(fail_msg)
```