Add a new accuracy test case for aclgraph (#3390)

### What this PR does / why we need it?
Add a new aclgraph accuracy test case that uses the DeepSeek-V2-Lite-W8A8 model.

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
Unit tests (UT).

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
lilinsiman
2025-10-20 20:04:04 +08:00
committed by GitHub
parent b9e2896eb1
commit 70bef33f13
3 changed files with 83 additions and 28 deletions

View File

@@ -177,6 +177,7 @@ jobs:
pytest -sv tests/e2e/multicard/test_data_parallel.py pytest -sv tests/e2e/multicard/test_data_parallel.py
pytest -sv tests/e2e/multicard/test_expert_parallel.py pytest -sv tests/e2e/multicard/test_expert_parallel.py
pytest -sv tests/e2e/multicard/test_external_launcher.py pytest -sv tests/e2e/multicard/test_external_launcher.py
pytest -sv tests/e2e/multicard/test_single_request_aclgraph.py
pytest -sv tests/e2e/multicard/test_fused_moe_allgather_ep.py pytest -sv tests/e2e/multicard/test_fused_moe_allgather_ep.py
pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py

View File

@@ -21,6 +21,8 @@ Run `pytest tests/compile/test_aclgraph.py`.
""" """
import os import os
import random
import string
import pytest import pytest
from vllm import SamplingParams from vllm import SamplingParams
@@ -30,6 +32,7 @@ from tests.e2e.model_utils import check_outputs_equal
MODELS = [ MODELS = [
"Qwen/Qwen3-0.6B", "Qwen/Qwen3-0.6B",
"vllm-ascend/DeepSeek-V2-Lite-W8A8",
] ]
@@ -45,20 +48,40 @@ def test_models_with_aclgraph(
] ]
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0) sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8":
with VllmRunner(
model,
max_model_len=1024,
enforce_eager=False,
quantization="ascend",
) as runner:
vllm_aclgraph_outputs = runner.model.generate(
prompts, sampling_params)
with VllmRunner(
model,
max_model_len=1024,
enforce_eager=True,
quantization="ascend",
) as runner:
vllm_eager_outputs = runner.model.generate(prompts,
sampling_params)
else:
with VllmRunner( with VllmRunner(
model, model,
max_model_len=1024, max_model_len=1024,
enforce_eager=False, enforce_eager=False,
) as runner: ) as runner:
vllm_aclgraph_outputs = runner.model.generate(prompts, sampling_params) vllm_aclgraph_outputs = runner.model.generate(
prompts, sampling_params)
with VllmRunner( with VllmRunner(
model, model,
max_model_len=1024, max_model_len=1024,
enforce_eager=True, enforce_eager=True,
) as runner: ) as runner:
vllm_eager_outputs = runner.model.generate(prompts, sampling_params) vllm_eager_outputs = runner.model.generate(prompts,
sampling_params)
vllm_aclgraph_outputs_list = [] vllm_aclgraph_outputs_list = []
for output in vllm_aclgraph_outputs: for output in vllm_aclgraph_outputs:
vllm_aclgraph_outputs_list.append( vllm_aclgraph_outputs_list.append(
@@ -85,6 +108,9 @@ def test_models_with_aclgraph_full_decode_only(
) -> None: ) -> None:
if 'HCCL_OP_EXPANSION_MODE' in os.environ: if 'HCCL_OP_EXPANSION_MODE' in os.environ:
del os.environ['HCCL_OP_EXPANSION_MODE'] del os.environ['HCCL_OP_EXPANSION_MODE']
# NOTE: Append a random number of filler prompts so that the request count
# matches a specified capture shape, preventing accuracy issues caused by padding
random_number = random.choice(list(range(6, 47, 8)))
prompts = [ prompts = [
('Solve the following math problem step by step.' ('Solve the following math problem step by step.'
'The last line of your response should be of the form Answer: ' 'The last line of your response should be of the form Answer: '
@@ -110,6 +136,9 @@ def test_models_with_aclgraph_full_decode_only(
'and $x^2 + bx + c = 0$ have a common real root, and the equations $x^2 + x + a = 0$' 'and $x^2 + bx + c = 0$ have a common real root, and the equations $x^2 + x + a = 0$'
'and $x^2 + cx + b = 0$ also have a common real root.' 'and $x^2 + cx + b = 0$ also have a common real root.'
'Compute the sum $a + b + c$.') 'Compute the sum $a + b + c$.')
] + [
''.join(random.choices(string.ascii_lowercase, k=random.randint(
1, 25))) for _ in range(random_number)
] ]
sampling_params = SamplingParams(max_tokens=5, sampling_params = SamplingParams(max_tokens=5,
@@ -117,20 +146,42 @@ def test_models_with_aclgraph_full_decode_only(
temperature=0.0, temperature=0.0,
top_p=1.0, top_p=1.0,
top_k=1) top_k=1)
if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8":
with VllmRunner(
model,
max_model_len=1024,
enforce_eager=False,
compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
quantization="ascend",
) as runner:
vllm_aclgraph_outputs = runner.model.generate(
prompts, sampling_params)
with VllmRunner(
model,
max_model_len=1024,
enforce_eager=True,
quantization="ascend",
) as runner:
vllm_eager_outputs = runner.model.generate(prompts,
sampling_params)
else:
with VllmRunner( with VllmRunner(
model, model,
max_model_len=1024, max_model_len=1024,
enforce_eager=False, enforce_eager=False,
compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"}, compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
) as runner: ) as runner:
vllm_aclgraph_outputs = runner.model.generate(prompts, sampling_params) vllm_aclgraph_outputs = runner.model.generate(
prompts, sampling_params)
with VllmRunner( with VllmRunner(
model, model,
max_model_len=1024, max_model_len=1024,
enforce_eager=True, enforce_eager=True,
) as runner: ) as runner:
vllm_eager_outputs = runner.model.generate(prompts, sampling_params) vllm_eager_outputs = runner.model.generate(prompts,
sampling_params)
vllm_aclgraph_outputs_list = [] vllm_aclgraph_outputs_list = []
for output in vllm_aclgraph_outputs: for output in vllm_aclgraph_outputs:

View File

@@ -976,17 +976,20 @@ class AscendMLAImpl(MLAAttentionImpl):
# Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
input_layout = "TND" input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim] # [bs * q_seq_len, num_heads_per_rank, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1) # TODO: If the driver is upgraded later, the contiguous function can be deleted.
q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous()
q_pe = q_pe.view(num_tokens, self.num_heads, -1) q_pe = q_pe.view(num_tokens, self.num_heads, -1)
sparse_mode = 3 sparse_mode = 3
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
actual_seq_lengths = decode_meta.actual_seq_lengths_q actual_seq_lengths = decode_meta.actual_seq_lengths_q
else: else:
if self.enable_kv_nz: if self.enable_kv_nz:
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1) q_nope = q_nope.view(num_tokens, 1, self.num_heads,
-1).contiguous()
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
else: else:
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1) q_nope = q_nope.view(num_tokens, self.num_heads, 1,
-1).contiguous()
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
sparse_mode = 0 sparse_mode = 0
spec_attn_mask = None spec_attn_mask = None