[Misc] Refactor aclgraph accuracy test to use logprob-based comparison (#7455)

### What this PR does / why we need it?

Replace text-match assertions with a two-tier logprob accuracy check:

- Prefill (token 0): assert token ID is identical between eager baseline
and compiled mode, then verify logprob matches within `atol`.
- Decode (tokens 1-2): if chosen tokens match, compare logprobs
directly; if they differ, cross-lookup the baseline token in the
compiled model's top-20 distribution and assert the assigned logprob is
within `decode_atol` (defaults to 2x atol). This tolerates minor argmax
drift caused by floating-point differences while still catching
distribution divergence.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.17.0
- vLLM main:
8a680463fa

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
Li Wang
2026-03-23 09:08:21 +08:00
committed by GitHub
parent 9bf9b4b267
commit 75fae619d5
5 changed files with 228 additions and 145 deletions

View File

@@ -1,9 +1,8 @@
from dataclasses import dataclass, field
from dataclasses import dataclass
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
PROMPTS_SHORT = [
"Hello, my name is",
@@ -51,31 +50,143 @@ PROMPTS_LONG = [
class LLMTestCase:
model: str
prompts: list[str]
golden_answers: list[str]
golden_answers: list[str] | None = None
quantization: str | None = None
sampling_params: SamplingParams = field(
default_factory=lambda: SamplingParams(
max_tokens=32,
temperature=0.0,
top_p=1.0,
top_k=0,
n=1,
# Keys that are specific to compilation/graph capture and should not be passed
# to the eager baseline runner.
_COMPILATION_KEYS = {"compilation_config", "additional_config", "cudagraph_capture_sizes"}
# Top-K logprobs to fetch per token; used for decode-phase cross-lookup.
_DECODE_TOPK = 20
_LOGPROB_SAMPLING_PARAMS = SamplingParams(
max_tokens=3,
temperature=0.0,
top_p=1.0,
top_k=0,
logprobs=_DECODE_TOPK,
)
def _check_prefill_token(
base_seq,
comp_seq,
prompt_idx: int,
atol: float,
) -> None:
"""Token 0 is produced by the prefill pass; both models see identical input,
so the chosen token *must* be the same and its logprob must match within atol."""
base_token_id = base_seq.token_ids[0]
comp_token_id = comp_seq.token_ids[0]
assert base_token_id == comp_token_id, (
f"Prefill token mismatch at prompt {prompt_idx}: baseline={base_token_id}, compiled={comp_token_id}"
)
base_logprob = base_seq.logprobs[0][base_token_id].logprob
comp_logprob = comp_seq.logprobs[0][comp_token_id].logprob
assert abs(base_logprob - comp_logprob) <= atol, (
f"Prefill logprob mismatch at prompt {prompt_idx}: "
f"baseline={base_logprob:.4f}, compiled={comp_logprob:.4f}, "
f"diff={abs(base_logprob - comp_logprob):.4f} > atol={atol}"
)
def _check_decode_token(
base_seq,
comp_seq,
token_idx: int,
prompt_idx: int,
decode_atol: float,
) -> None:
"""Tokens 1-2 come from decode passes. When the two models pick different
tokens the context has already diverged, so we cannot compare logprobs of
the chosen tokens directly. Instead we do a cross-lookup: find the
baseline's chosen token inside compiled's top-K distribution (and vice
versa) and assert that the assigned log-probability is close. This
confirms that the compiled model's distribution is numerically consistent
with the baseline's even when the argmax differs by a tiny margin.
"""
base_token_id = base_seq.token_ids[token_idx]
comp_token_id = comp_seq.token_ids[token_idx]
base_topk = base_seq.logprobs[token_idx] # dict[token_id, Logprob]
comp_topk = comp_seq.logprobs[token_idx]
if base_token_id == comp_token_id:
# Happy path: same token, direct logprob comparison.
diff = abs(base_topk[base_token_id].logprob - comp_topk[comp_token_id].logprob)
assert diff <= decode_atol, (
f"Decode logprob mismatch at prompt {prompt_idx}, token {token_idx}: "
f"baseline={base_topk[base_token_id].logprob:.4f}, "
f"compiled={comp_topk[comp_token_id].logprob:.4f}, "
f"diff={diff:.4f} > decode_atol={decode_atol}"
)
return
# Tokens differ cross-lookup in each model's top-K distribution.
base_logprob = base_topk[base_token_id].logprob
comp_logprob = comp_topk[comp_token_id].logprob
# Check: what log-probability did compiled assign to baseline's token?
assert base_token_id in comp_topk, (
f"Decode token mismatch at prompt {prompt_idx}, token {token_idx}: "
f"baseline chose token {base_token_id} (logprob={base_logprob:.4f}) but "
f"compiled chose token {comp_token_id} (logprob={comp_logprob:.4f}) and "
f"baseline's token does not appear in compiled's top-{_DECODE_TOPK} distribution"
)
comp_logprob_of_base_token = comp_topk[base_token_id].logprob
diff = abs(base_logprob - comp_logprob_of_base_token)
assert diff <= decode_atol, (
f"Decode distribution mismatch at prompt {prompt_idx}, token {token_idx}: "
f"baseline chose token {base_token_id} with logprob={base_logprob:.4f}; "
f"compiled assigned logprob={comp_logprob_of_base_token:.4f} to that token, "
f"diff={diff:.4f} > decode_atol={decode_atol} "
f"(compiled chose token {comp_token_id} with logprob={comp_logprob:.4f})"
)
def gen_and_valid(runner_kwargs: dict, prompts: list[str], sampling_params: SamplingParams, golden_answers: list[str]):
def compare_logprobs(
runner_kwargs: dict,
prompts: list[str],
atol: float = 0.0689,
decode_atol: float | None = None,
) -> None:
"""Run the model in eager baseline mode and in the configured compilation
mode, generate 3 tokens per prompt, then verify numerical accuracy:
* Token 0 (prefill pass): chosen token must be identical; logprob must
match within *atol*.
* Tokens 1-2 (decode passes): if chosen tokens match, logprob must be
within *decode_atol*; if they differ, the baseline token must appear in
the compiled model's top-K distribution with a logprob within
*decode_atol* of the baseline value.
*decode_atol* defaults to ``2 * atol`` when not supplied.
"""
if decode_atol is None:
decode_atol = 2 * atol
baseline_kwargs = {k: v for k, v in runner_kwargs.items() if k not in _COMPILATION_KEYS}
baseline_kwargs["enforce_eager"] = True
with VllmRunner(**baseline_kwargs) as runner:
baseline_outputs = runner.model.generate(prompts=prompts, sampling_params=_LOGPROB_SAMPLING_PARAMS)
with VllmRunner(**runner_kwargs) as runner:
vllm_aclgraph_outputs = runner.model.generate(prompts=prompts, sampling_params=sampling_params)
outputs_gen = []
for output in vllm_aclgraph_outputs:
outputs_gen.append(([output.outputs[0].index], output.outputs[0].text))
compiled_outputs = runner.model.generate(prompts=prompts, sampling_params=_LOGPROB_SAMPLING_PARAMS)
output_origin = [([0], answer) for answer in golden_answers]
for prompt_idx, (base_out, comp_out) in enumerate(zip(baseline_outputs, compiled_outputs)):
base_seq = base_out.outputs[0]
comp_seq = comp_out.outputs[0]
check_outputs_equal(
outputs_0_lst=output_origin,
outputs_1_lst=outputs_gen,
name_0="output_origin",
name_1="outputs_gen",
)
assert base_seq.logprobs is not None and comp_seq.logprobs is not None, (
f"logprobs not returned for prompt {prompt_idx}"
)
assert len(base_seq.token_ids) == len(comp_seq.token_ids) == 3, (
f"Expected 3 tokens for prompt {prompt_idx}, "
f"got baseline={len(base_seq.token_ids)}, compiled={len(comp_seq.token_ids)}"
)
_check_prefill_token(base_seq, comp_seq, prompt_idx, atol)
for token_idx in range(1, 3):
_check_decode_token(base_seq, comp_seq, token_idx, prompt_idx, decode_atol)