### What this PR does / why we need it?
| File Path |
| :--- |
| `tests/e2e/singlecard/compile/backend.py` |
| `tests/e2e/singlecard/compile/test_graphex_norm_quant_fusion.py` |
| `tests/e2e/singlecard/compile/test_graphex_qknorm_rope_fusion.py` |
| `tests/e2e/singlecard/compile/test_norm_quant_fusion.py` |
| `tests/e2e/singlecard/model_runner_v2/test_basic.py` |
| `tests/e2e/singlecard/test_aclgraph_accuracy.py` |
| `tests/e2e/singlecard/test_aclgraph_batch_invariant.py` |
| `tests/e2e/singlecard/test_aclgraph_mem.py` |
| `tests/e2e/singlecard/test_async_scheduling.py` |
| `tests/e2e/singlecard/test_auto_fit_max_mode_len.py` |
| `tests/e2e/singlecard/test_batch_invariant.py` |
| `tests/e2e/singlecard/test_camem.py` |
| `tests/e2e/singlecard/test_completion_with_prompt_embeds.py` |
| `tests/e2e/singlecard/test_cpu_offloading.py` |
| `tests/e2e/singlecard/test_guided_decoding.py` |
| `tests/e2e/singlecard/test_ilama_lora.py` |
| `tests/e2e/singlecard/test_llama32_lora.py` |
| `tests/e2e/singlecard/test_models.py` |
| `tests/e2e/singlecard/test_multistream_overlap_shared_expert.py` |
| `tests/e2e/singlecard/test_quantization.py` |
| `tests/e2e/singlecard/test_qwen3_multi_loras.py` |
| `tests/e2e/singlecard/test_sampler.py` |
| `tests/e2e/singlecard/test_vlm.py` |
| `tests/e2e/singlecard/test_xlite.py` |
| `tests/e2e/singlecard/utils.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main: 9562912cea
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -22,6 +22,7 @@ import random
|
||||
import pytest
|
||||
import torch
|
||||
from vllm import SamplingParams
|
||||
|
||||
from tests.e2e.conftest import VllmRunner
|
||||
|
||||
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
|
||||
@@ -69,9 +70,7 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
|
||||
if target_words > 50:
|
||||
# For longer prompts, repeat context
|
||||
padding_text = (
|
||||
" This is an interesting topic that deserves more explanation. " *
|
||||
(target_words // 50))
|
||||
padding_text = " This is an interesting topic that deserves more explanation. " * (target_words // 50)
|
||||
base_prompt = base_prompt + padding_text
|
||||
|
||||
return base_prompt
|
||||
@@ -107,8 +106,7 @@ def _extract_step_logprobs(generate_output):
|
||||
|
||||
|
||||
@pytest.mark.timeout(1000)
|
||||
def test_aclgraph_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
def test_aclgraph_v1_generation_is_deterministic_across_batch_sizes_with_needle(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Ensures that the same request (the 'needle' prompt) yields identical output
|
||||
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
|
||||
@@ -162,20 +160,16 @@ def test_aclgraph_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||
needle_prompt = "There once was a "
|
||||
|
||||
with VllmRunner(
|
||||
model_name=model,
|
||||
max_num_seqs=max_batch_size,
|
||||
gpu_memory_utilization=gpu_mem_util,
|
||||
max_model_len=max_model_len,
|
||||
dtype="bfloat16",
|
||||
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
|
||||
enable_prefix_caching=False,
|
||||
distributed_executor_backend="mp",
|
||||
compilation_config={
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||
"cudagraph_capture_sizes": [1, 32, 64]
|
||||
}
|
||||
model_name=model,
|
||||
max_num_seqs=max_batch_size,
|
||||
gpu_memory_utilization=gpu_mem_util,
|
||||
max_model_len=max_model_len,
|
||||
dtype="bfloat16",
|
||||
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
|
||||
enable_prefix_caching=False,
|
||||
distributed_executor_backend="mp",
|
||||
compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1, 32, 64]},
|
||||
) as vllm_model:
|
||||
|
||||
# Baseline generation for the needle prompt alone.
|
||||
baseline_out = vllm_model.generate([needle_prompt], sampling)
|
||||
assert len(baseline_out) == 1
|
||||
@@ -194,8 +188,7 @@ def test_aclgraph_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||
if i == needle_pos:
|
||||
prompts.append(needle_prompt)
|
||||
else:
|
||||
prompts.append(
|
||||
_random_prompt(min_random_prompt, max_random_prompt))
|
||||
prompts.append(_random_prompt(min_random_prompt, max_random_prompt))
|
||||
|
||||
# Generate with the larger-batch engine
|
||||
outputs = vllm_model.generate(prompts, sampling)
|
||||
@@ -204,24 +197,23 @@ def test_aclgraph_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||
text = needle_output[0]
|
||||
|
||||
if text != baseline_text:
|
||||
print(
|
||||
f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
|
||||
print(f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
|
||||
mismatches += 1
|
||||
|
||||
passes = num_trials - mismatches
|
||||
# Dump how many passed vs failed
|
||||
print(f"[determinism] total={num_trials}, passed={passes}, "
|
||||
f"failed={mismatches}, max_batch_size={max_batch_size}")
|
||||
print(
|
||||
f"[determinism] total={num_trials}, passed={passes}, failed={mismatches}, max_batch_size={max_batch_size}"
|
||||
)
|
||||
|
||||
if mismatches > 0:
|
||||
pytest.fail(
|
||||
f"Nondeterministic outputs detected: {mismatches} failed out "
|
||||
f"of {num_trials} trials (max_batch_size={max_batch_size}).")
|
||||
f"of {num_trials} trials (max_batch_size={max_batch_size})."
|
||||
)
|
||||
|
||||
|
||||
|
||||
def test_aclgraph_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
def test_aclgraph_logprobs_bitwise_batch_invariance_bs1_vs_bsN(monkeypatch: pytest.MonkeyPatch):
|
||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||
random.seed(seed)
|
||||
model_name = DEFAULT_MODEL
|
||||
@@ -235,24 +227,19 @@ def test_aclgraph_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
|
||||
if disable_custom_ar:
|
||||
print(f"\n{'=' * 80}")
|
||||
print(
|
||||
f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})"
|
||||
)
|
||||
print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
with VllmRunner(
|
||||
model_name=model_name,
|
||||
tensor_parallel_size=tp_size,
|
||||
enable_prefix_caching=False,
|
||||
max_num_seqs=32,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16",
|
||||
gpu_memory_utilization=0.9,
|
||||
distributed_executor_backend="mp",
|
||||
compilation_config={
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||
"cudagraph_capture_sizes": [1, 32, 64]
|
||||
}
|
||||
model_name=model_name,
|
||||
tensor_parallel_size=tp_size,
|
||||
enable_prefix_caching=False,
|
||||
max_num_seqs=32,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16",
|
||||
gpu_memory_utilization=0.9,
|
||||
distributed_executor_backend="mp",
|
||||
compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1, 32, 64]},
|
||||
) as vllm_model:
|
||||
# Use more realistic prompts for better token generation
|
||||
prompts = [_random_prompt(10, 50) for i in range(32)]
|
||||
@@ -273,16 +260,13 @@ def test_aclgraph_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
bs1_logprobs_per_prompt = []
|
||||
bs1_tokens_per_prompt = []
|
||||
for idx, p in enumerate(prompts):
|
||||
print(
|
||||
f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
|
||||
)
|
||||
print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
|
||||
outs = vllm_model.generate_w_logprobs([p], sp, use_tqdm=False)
|
||||
assert len(outs) == 1
|
||||
# print(outs)
|
||||
step_logprobs, token_ids = _extract_step_logprobs(outs[0])
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
pytest.skip("Logits are not available on RequestOutput; enable logprobs return to run this test.")
|
||||
bs1_logprobs_per_prompt.append(step_logprobs)
|
||||
bs1_tokens_per_prompt.append(token_ids)
|
||||
print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
|
||||
@@ -304,108 +288,91 @@ def test_aclgraph_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
|
||||
step_logprobs, token_ids = _extract_step_logprobs(o)
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
pytest.skip("Logits are not available on RequestOutput; enable logprobs return to run this test.")
|
||||
bsN_logprobs_per_prompt.append(step_logprobs)
|
||||
bsN_tokens_per_prompt.append(token_ids)
|
||||
|
||||
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
|
||||
failed_prompts = []
|
||||
for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
|
||||
zip(
|
||||
bs1_logprobs_per_prompt,
|
||||
bsN_logprobs_per_prompt,
|
||||
bs1_tokens_per_prompt,
|
||||
bsN_tokens_per_prompt,
|
||||
)):
|
||||
zip(
|
||||
bs1_logprobs_per_prompt,
|
||||
bsN_logprobs_per_prompt,
|
||||
bs1_tokens_per_prompt,
|
||||
bsN_tokens_per_prompt,
|
||||
)
|
||||
):
|
||||
if len(logprobs_bs1) != len(logprobs_bsN):
|
||||
reason = (f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
|
||||
f"vs {len(logprobs_bsN)} (BS=N)")
|
||||
failed_prompts.append({
|
||||
"prompt_idx": i,
|
||||
"step": "all",
|
||||
"reason": reason,
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
})
|
||||
reason = f"Different number of steps: {len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
|
||||
failed_prompts.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": "all",
|
||||
"reason": reason,
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if tokens match first
|
||||
if tokens_bs1 != tokens_bsN:
|
||||
failed_prompts.append({
|
||||
"prompt_idx":
|
||||
i,
|
||||
"step":
|
||||
"sampling",
|
||||
"reason":
|
||||
"Different tokens sampled",
|
||||
"prompt_preview":
|
||||
prompts[i][:100],
|
||||
"bs1_tokens":
|
||||
tokens_bs1,
|
||||
"bsN_tokens":
|
||||
tokens_bsN,
|
||||
"bs1_all_logprobs":
|
||||
[logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
|
||||
"bsN_all_logprobs":
|
||||
[logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
|
||||
})
|
||||
failed_prompts.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": "sampling",
|
||||
"reason": "Different tokens sampled",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
"bs1_all_logprobs": [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
|
||||
"bsN_all_logprobs": [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
|
||||
if a.shape != b.shape:
|
||||
failed_prompts.append({
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
})
|
||||
failed_prompts.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
if not torch.equal(a, b):
|
||||
max_diff = torch.abs(a - b).max().item()
|
||||
# Print which token failed
|
||||
print(
|
||||
f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}"
|
||||
)
|
||||
print(f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}")
|
||||
bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
|
||||
bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
|
||||
print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
|
||||
print(f" BS=1 logprob: {a.tolist()}")
|
||||
print(f" BS=N logprob: {b.tolist()}")
|
||||
failed_prompts.append({
|
||||
"prompt_idx":
|
||||
i,
|
||||
"step":
|
||||
t,
|
||||
"reason":
|
||||
f"Bitwise mismatch (max_diff={max_diff:.6e})",
|
||||
"prompt_preview":
|
||||
prompts[i][:100],
|
||||
"bs1_tokens":
|
||||
tokens_bs1,
|
||||
"bsN_tokens":
|
||||
tokens_bsN,
|
||||
"bs1_all_logprobs": [
|
||||
logprobs_bs1[s].tolist()
|
||||
for s in range(len(logprobs_bs1))
|
||||
],
|
||||
"bsN_all_logprobs": [
|
||||
logprobs_bsN[s].tolist()
|
||||
for s in range(len(logprobs_bsN))
|
||||
],
|
||||
})
|
||||
failed_prompts.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
"bs1_all_logprobs": [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
|
||||
"bsN_all_logprobs": [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
# Print summary of all failures
|
||||
if failed_prompts:
|
||||
print(f"\n{'=' * 80}")
|
||||
fail_msg = (f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/"
|
||||
f"{len(prompts)} prompts failed")
|
||||
fail_msg = f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/{len(prompts)} prompts failed"
|
||||
print(fail_msg)
|
||||
print(f"{'=' * 80}")
|
||||
for fail in failed_prompts:
|
||||
@@ -420,21 +387,18 @@ def test_aclgraph_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
print(f" BS=N tokens: {fail['bsN_tokens']}")
|
||||
|
||||
if "bs1_all_logprobs" in fail:
|
||||
print(
|
||||
f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:"
|
||||
)
|
||||
print(f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:")
|
||||
for step_idx, logprobs in enumerate(fail["bs1_all_logprobs"]):
|
||||
print(f" Step {step_idx}: {logprobs}")
|
||||
print(
|
||||
f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:"
|
||||
)
|
||||
print(f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:")
|
||||
for step_idx, logprobs in enumerate(fail["bsN_all_logprobs"]):
|
||||
print(f" Step {step_idx}: {logprobs}")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
# Fail the test with summary
|
||||
msg = (f"Batch invariance violated in {len(failed_prompts)}/"
|
||||
f"{len(prompts)} prompts. See output above for details.")
|
||||
msg = (
|
||||
f"Batch invariance violated in {len(failed_prompts)}/{len(prompts)} prompts. See output above for details."
|
||||
)
|
||||
pytest.fail(msg)
|
||||
|
||||
|
||||
@@ -446,18 +410,15 @@ def test_aclgraph_simple_generation(monkeypatch: pytest.MonkeyPatch):
|
||||
model = DEFAULT_MODEL
|
||||
|
||||
with VllmRunner(
|
||||
model_name=model,
|
||||
max_num_seqs=1,
|
||||
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
|
||||
gpu_memory_utilization=0.9,
|
||||
max_model_len=2048,
|
||||
dtype="float16",
|
||||
enable_prefix_caching=False,
|
||||
compilation_config={
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||
"cudagraph_capture_sizes": [1, 32, 64]
|
||||
},
|
||||
distributed_executor_backend="mp",
|
||||
model_name=model,
|
||||
max_num_seqs=1,
|
||||
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
|
||||
gpu_memory_utilization=0.9,
|
||||
max_model_len=2048,
|
||||
dtype="float16",
|
||||
enable_prefix_caching=False,
|
||||
compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1, 32, 64]},
|
||||
distributed_executor_backend="mp",
|
||||
) as vllm_model:
|
||||
prompt = "The capital of France is"
|
||||
sampling_params = SamplingParams(
|
||||
@@ -479,11 +440,7 @@ def test_aclgraph_simple_generation(monkeypatch: pytest.MonkeyPatch):
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def test_aclgraph_logprobs_without_batch_invariance_should_fail(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
def test_aclgraph_logprobs_without_batch_invariance_should_fail(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
|
||||
It DISABLES batch invariance mode and expects to see non-deterministic behavior
|
||||
@@ -505,19 +462,15 @@ def test_aclgraph_logprobs_without_batch_invariance_should_fail(
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
with VllmRunner(
|
||||
model_name=model_name,
|
||||
tensor_parallel_size=tp_size,
|
||||
enable_prefix_caching=False,
|
||||
max_num_seqs=32,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16",
|
||||
compilation_config={
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||
"cudagraph_capture_sizes": [1, 32, 64]
|
||||
},
|
||||
distributed_executor_backend="mp",
|
||||
model_name=model_name,
|
||||
tensor_parallel_size=tp_size,
|
||||
enable_prefix_caching=False,
|
||||
max_num_seqs=32,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16",
|
||||
compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1, 32, 64]},
|
||||
distributed_executor_backend="mp",
|
||||
) as vllm_model:
|
||||
|
||||
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
|
||||
long_min = int(os.getenv("VLLM_MIN_PROMPT", "768"))
|
||||
long_max = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
|
||||
@@ -549,16 +502,13 @@ def test_aclgraph_logprobs_without_batch_invariance_should_fail(
|
||||
bs1_logprobs_per_prompt = []
|
||||
bs1_tokens_per_prompt = []
|
||||
for idx, p in enumerate(prompts):
|
||||
print(
|
||||
f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
|
||||
)
|
||||
print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
|
||||
outs = vllm_model.generate_w_logprobs([p], sp, use_tqdm=False)
|
||||
|
||||
assert len(outs) == 1
|
||||
step_logprobs, token_ids = _extract_step_logprobs(outs[0])
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
pytest.skip("Logits are not available on RequestOutput; enable logprobs return to run this test.")
|
||||
bs1_logprobs_per_prompt.append(step_logprobs)
|
||||
bs1_tokens_per_prompt.append(token_ids)
|
||||
print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
|
||||
@@ -579,84 +529,90 @@ def test_aclgraph_logprobs_without_batch_invariance_should_fail(
|
||||
print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
|
||||
step_logprobs, token_ids = _extract_step_logprobs(o)
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
pytest.skip("Logits are not available on RequestOutput; enable logprobs return to run this test.")
|
||||
bsN_logprobs_per_prompt.append(step_logprobs)
|
||||
bsN_tokens_per_prompt.append(token_ids)
|
||||
|
||||
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
|
||||
differences_found = []
|
||||
for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
|
||||
zip(
|
||||
bs1_logprobs_per_prompt,
|
||||
bsN_logprobs_per_prompt,
|
||||
bs1_tokens_per_prompt,
|
||||
bsN_tokens_per_prompt,
|
||||
)):
|
||||
zip(
|
||||
bs1_logprobs_per_prompt,
|
||||
bsN_logprobs_per_prompt,
|
||||
bs1_tokens_per_prompt,
|
||||
bsN_tokens_per_prompt,
|
||||
)
|
||||
):
|
||||
if len(logprobs_bs1) != len(logprobs_bsN):
|
||||
reason = (f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
|
||||
f"vs {len(logprobs_bsN)} (BS=N)")
|
||||
differences_found.append({
|
||||
"prompt_idx": i,
|
||||
"step": "all",
|
||||
"reason": reason,
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
})
|
||||
reason = f"Different number of steps: {len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
|
||||
differences_found.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": "all",
|
||||
"reason": reason,
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if tokens match first
|
||||
if tokens_bs1 != tokens_bsN:
|
||||
differences_found.append({
|
||||
"prompt_idx": i,
|
||||
"step": "sampling",
|
||||
"reason": "Different tokens sampled",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
})
|
||||
differences_found.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": "sampling",
|
||||
"reason": "Different tokens sampled",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
|
||||
if a.shape != b.shape:
|
||||
differences_found.append({
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
})
|
||||
differences_found.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
if not torch.equal(a, b):
|
||||
max_diff = torch.abs(a - b).max().item()
|
||||
print(f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, "
|
||||
f"Token {t}: max_diff={max_diff:.6e}")
|
||||
print(f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, Token {t}: max_diff={max_diff:.6e}")
|
||||
bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
|
||||
bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
|
||||
print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
|
||||
print(f" BS=1 logprob: {a.tolist()}")
|
||||
print(f" BS=N logprob: {b.tolist()}")
|
||||
differences_found.append({
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
})
|
||||
differences_found.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'=' * 80}")
|
||||
if differences_found:
|
||||
success_msg = (
|
||||
f"✓ SUCCESS: Batch invariance is doing something! "
|
||||
f"Found {len(differences_found)}/{len(prompts)} prompts "
|
||||
f"with differences when batch invariance was DISABLED.")
|
||||
f"with differences when batch invariance was DISABLED."
|
||||
)
|
||||
print(success_msg)
|
||||
print(f"{'=' * 80}")
|
||||
for diff in differences_found:
|
||||
@@ -676,7 +632,8 @@ def test_aclgraph_logprobs_without_batch_invariance_should_fail(
|
||||
f"✗ UNEXPECTED: All {len(prompts)} prompts matched "
|
||||
f"between BS=1 and BS=N even with batch invariance DISABLED. "
|
||||
f"This suggests batch invariance might not be necessary, "
|
||||
f"or the test needs more sensitive prompts.")
|
||||
f"or the test needs more sensitive prompts."
|
||||
)
|
||||
print(fail_msg)
|
||||
print(f"{'=' * 80}\n")
|
||||
pytest.fail(fail_msg)
|
||||
|
||||
Reference in New Issue
Block a user