support qwen3-next full_decode_only mode. (#3949)

### What this PR does / why we need it?
support qwen3-next full_decode_only mode. 
bs=1, max_token=1024
| branch| tps| e2e time|
| --- | --- | --- |
|piecewise  |3.06  | 8.15 |
|fulldecodeonly | 7.2 | 3.47 |

- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
XiaoxinWang
2025-11-05 08:46:05 +08:00
committed by GitHub
parent 5f08e07208
commit 738bf2b720
4 changed files with 66 additions and 9 deletions

View File

@@ -36,3 +36,21 @@ def test_models_distributed_Qwen3_NEXT_TP4():
distributed_executor_backend="mp",
enforce_eager=True) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
example_prompts = [
"Hello, my name is",
] * 4
max_tokens = 5
with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
tensor_parallel_size=4,
max_model_len=4096,
gpu_memory_utilization=0.8,
distributed_executor_backend="mp",
enforce_eager=False,
compilation_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
"cudagraph_capture_sizes": [1, 8, 24, 48, 60]
}) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -192,8 +192,10 @@ class ACLGraphWrapper:
def update_attn_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
# For Qwen3-next, since the kv_cache_config has already categorized
# linear_attn and self_attn, the attn_metadata is first arranged with
# self_attn followed by linear_attn. Therefore, using zip directly
# filters out the update operations for linear_attn.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
@@ -289,9 +291,9 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
graph_params = get_graph_params()
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,

View File

@@ -31,6 +31,7 @@ from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
import numpy as np
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.platform import NPUPlatform
@@ -178,6 +179,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
num_reqs: int,
num_tokens: int,
max_query_len: int,
num_scheduled_tokens: np.ndarray,
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False,
) -> Optional[dict[str, Any]]:
@@ -186,7 +188,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
if with_prefill or self.enable_shared_expert_dp:
attn_metadata = super()._build_dummy_attn_metadata(
with_prefill, num_reqs, num_tokens, max_query_len,
aclgraph_runtime_mode, force_attention)
num_scheduled_tokens, aclgraph_runtime_mode, force_attention)
else:
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=num_reqs,

View File

@@ -76,7 +76,8 @@ from vllm.utils import cdiv, length_from_prompt_token_ids_or_embeds
from vllm.utils.jsontree import json_map_leaves
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
AttentionCGSupport, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
@@ -107,7 +108,8 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import (MoECommType,
set_ascend_forward_context)
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.attention_v1 import (AscendAttentionMetadataBuilder,
AscendAttentionState)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
AscendPrefillContextParallelMetadata)
# yapf: disable
@@ -2644,6 +2646,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_reqs: int,
num_tokens: int,
max_query_len: int,
num_scheduled_tokens: np.ndarray,
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False,
) -> Optional[dict[str, Any]]:
@@ -2659,6 +2662,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens_np[:num_reqs] = seq_lens
self.seq_lens_np[num_reqs:] = 0
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
self.device).to(torch.int32)
self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
self.query_start_loc_cpu[1:num_reqs +
1] = torch.Tensor(cu_num_tokens)
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -2715,12 +2726,35 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.speculative_config.method == "deepseek_mtp":
attn_state = AscendAttentionState.SpecDecoding
common_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
1],
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
block_table_tensor=block_table_tensor[:num_reqs],
slot_mapping=slot_mapping,
num_computed_tokens_cpu=num_computed_tokens_cpu,
max_query_len=max_query_len,
max_seq_len=seq_lens)
for attn_group in self.attn_groups[kv_cache_group_id]:
builder = attn_group.get_metadata_builder()
attn_metadata_i = builder.build_for_graph_capture(
common_attn_metadata, attn_state, self.get_model())
if isinstance(builder, AscendAttentionMetadataBuilder):
attn_metadata_full_attention = builder.build_for_graph_capture(
common_attn_metadata, attn_state, self.get_model())
elif isinstance(builder, GDNAttentionMetadataBuilder):
attn_metadata_gdn_attention = builder.build_for_cudagraph_capture(
common_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
if "linear_attn" in layer_name:
attn_metadata[
layer_name] = attn_metadata_gdn_attention
else:
attn_metadata[
layer_name] = attn_metadata_full_attention
return attn_metadata
@@ -2895,6 +2929,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
max_query_len=max_query_len,
aclgraph_runtime_mode=aclgraph_runtime_mode,
force_attention=force_attention,
num_scheduled_tokens=num_scheduled_tokens,
)
need_dummy_logits = (not self.in_profile_run