[Misc] Refactor aclgraph accuracy test to use logprob-based comparison (#7455)
### What this PR does / why we need it?
Replace text-match assertions with a two-tier logprob accuracy check:
- Prefill (token 0): assert token ID is identical between eager baseline
and compiled mode, then verify logprob matches within `atol`.
- Decode (tokens 1-2): if chosen tokens match, compare logprobs
directly; if they differ, cross-lookup the baseline token in the
compiled model's top-20 distribution and assert the assigned logprob is
within `decode_atol` (defaults to 2x atol). This tolerates minor argmax
drift caused by floating-point differences while still catching
distribution divergence.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.17.0
- vLLM main:
8a680463fa
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
68
tests/e2e/singlecard/test_eager_mode_acc.py
Normal file
68
tests/e2e/singlecard/test_eager_mode_acc.py
Normal file
@@ -0,0 +1,68 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
This file test accuracy via LMEval.
|
||||
It uses local-completions, which interacts with vLLM
|
||||
through the OAI API with N concurrent connections.
|
||||
This simulates real work usage of the API and makes
|
||||
sure that the zmq frontend mp RPC message passing and
|
||||
AsyncLLMEngine are working correctly.
|
||||
"""
|
||||
|
||||
import lm_eval
|
||||
import pytest
|
||||
|
||||
# Models whose accuracy is checked; each must have an entry in EXPECTED_VALUES.
MODEL_NAMES = ["Qwen/Qwen3-0.6B", "vllm-ascend/DeepSeek-V2-Lite-W8A8"]
# Number of concurrent OAI API connections — presumably consumed by the
# local-completions harness described in the module docstring; it is not
# referenced elsewhere in this file. NOTE(review): confirm it is still needed.
NUM_CONCURRENT = 500
# LMEval task to evaluate.
TASK = "gsm8k"
# Metric key looked up inside the LMEval results dict for TASK.
FILTER = "exact_match,strict-match"
# Allowed deviation between measured and expected scores.
RTOL = 0.03
# Baseline scores per model; the measured value must fall within RTOL of these.
EXPECTED_VALUES = {"Qwen/Qwen3-0.6B": 0.414, "vllm-ascend/DeepSeek-V2-Lite-W8A8": 0.34}
|
||||
|
||||
|
||||
def run_test(model_name, more_args=None):
    """Run the end-to-end accuracy test for one model via LMEval.

    Args:
        model_name: Model identifier; must have an entry in EXPECTED_VALUES.
        more_args: Optional extra comma-separated fragment appended to the
            base ``model_args`` string (default: None, i.e. nothing appended).

    Raises:
        AssertionError: If the model has no expected value, or the measured
            score deviates from the expected one by RTOL or more.
    """
    # NOTE: Do not add any spaces to the string below, as this will cause
    # parameter parsing errors.
    model_args = f"pretrained={model_name},max_model_len=4096,enforce_eager=True"

    if more_args is not None:
        model_args = f"{model_args},{more_args}"

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        # Use the shared TASK constant instead of a hard-coded "gsm8k" so the
        # evaluated task and the results lookup below cannot drift apart.
        tasks=TASK,
        batch_size="auto",
    )

    measured_value = results["results"][TASK][FILTER]
    assert model_name in EXPECTED_VALUES, (
        f"Cannot find the expected value for the model {model_name=}"
    )
    expected_value = EXPECTED_VALUES[model_name]
    # Equivalent to the original two-sided check:
    # measured - RTOL < expected < measured + RTOL.
    assert abs(measured_value - expected_value) < RTOL, (
        f"Expected: {expected_value} | Measured: {measured_value}"
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODEL_NAMES)
|
||||
def test_lm_eval_accuracy(model):
|
||||
"""Run with the V1 Engine."""
|
||||
more_args = None
|
||||
run_test(model, more_args)
|
||||
Reference in New Issue
Block a user