[Feat][main] Supported to use full-graph with Qwen3-Next-MTP (#5477)

### What this PR does / why we need it?

Supported to use full-graph with Qwen3-Next-MTP.

In detail, we adatpted `AscendAttentionState.ChunkedPrefill` in main
model, and also adapted `AscendAttentionState.ChunkedPrefill` in mtp
model.

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

We changed the test of Qwen3-Next-MTP in
`tests/e2e/multicard/test_qwen3_next.py` to make it a test of
`FULL_DECODE_ONLY`. Then run `pytest -s
tests/e2e/multicard/test_qwen3_next.py::test_qwen3_next_distributed_mp_eager_mtp_similarity_tp4`.

And this test passed.

```text
.

================================================================================================================================= warnings summary =================================================================================================================================
<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================================== 1 passed, 2 warnings in 271.89s (0:04:31) =====================================================================================================================
```
- vLLM version: v0.13.0
- vLLM main:
5326c89803

Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
drslark
2026-01-04 12:03:21 +08:00
committed by GitHub
parent fd4b4fd06f
commit 363ac1b80f
4 changed files with 42 additions and 32 deletions

View File

@@ -30,6 +30,8 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
update_attn_dcp_pcp_params,
update_attn_params,
update_mla_attn_dcp_pcp_params,
update_mla_attn_params)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
@@ -66,6 +68,24 @@ class MtpProposer(EagleProposer):
# TODO: Find out why ModelRunner does not this explicit typing?
model: Union[nn.Module, ACLGraphWrapper]
# update full-graph params for one spec token
def _update_full_graph_params(self, forward_context, num_tokens):
if self.vllm_config.model_config.use_mla:
if self.pcp_size * self.dcp_size > 1:
update_mla_attn_dcp_pcp_params(self.update_stream,
forward_context, num_tokens)
else:
update_mla_attn_params(self.update_stream, forward_context,
num_tokens,
self.vllm_config.speculative_config)
else:
if self.pcp_size * self.dcp_size > 1:
update_attn_dcp_pcp_params(self.update_stream, forward_context,
num_tokens)
else:
update_attn_params(self.update_stream, forward_context,
num_tokens, self.vllm_config)
def load_model(self, model) -> None:
loader = get_model_loader(self.vllm_config.load_config)
@@ -141,7 +161,7 @@ class MtpProposer(EagleProposer):
num_tokens_across_dp,
with_prefill,
) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill)
if self.use_async_scheduling:
if not self.use_cuda_graph:
# there is synchronization between mtp steps when enabling aclgraph,
# disable aclgraph when use async scheduling to avoid the
# synchronization overhead.
@@ -185,8 +205,10 @@ class MtpProposer(EagleProposer):
:num_reqs * self.decode_threshold]
builder = self.runner.attn_groups[0][0].get_metadata_builder()
# `AscendAttentionState.SpecDecoding` is only designed for mla, `AscendAttentionState.ChunkedPrefill` is used in self-attention.
attn_state = AscendAttentionState.SpecDecoding if self.vllm_config.model_config.use_mla else AscendAttentionState.ChunkedPrefill
attn_metadata_mtp = builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.SpecDecoding)
common_attn_metadata, attn_state)
attn_metadata = {}
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_mtp
@@ -222,17 +244,9 @@ class MtpProposer(EagleProposer):
hidden_states=previous_hidden_states)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
not forward_context.capturing:
if self.vllm_config.model_config.use_mla and not self.use_sparse:
if self.pcp_size * self.dcp_size > 1:
update_mla_attn_dcp_pcp_params(
self.update_stream, forward_context,
num_tokens)
else:
update_mla_attn_params(
self.update_stream, forward_context,
num_tokens,
self.vllm_config.speculative_config)
not forward_context.capturing and not self.use_sparse:
self._update_full_graph_params(forward_context, num_tokens)
if self.enable_shared_expert_dp:
positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
positions, True)
@@ -654,7 +668,7 @@ class MtpProposer(EagleProposer):
has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
aclgraph_runtime_mode, batch_descriptor = \
self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
if self.use_async_scheduling:
if not self.use_cuda_graph:
# there is synchronization between mtp steps when enabling aclgraph,
# disable aclgraph when use async scheduling to avoid the
# synchronization overhead.
@@ -721,17 +735,9 @@ class MtpProposer(EagleProposer):
positions=positions,
hidden_states=hidden_states)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
if self.vllm_config.model_config.use_mla and not self.use_sparse:
if self.pcp_size * self.dcp_size > 1:
update_mla_attn_dcp_pcp_params(
self.update_stream, forward_context,
num_input_tokens)
else:
update_mla_attn_params(
self.update_stream, forward_context,
num_input_tokens,
self.vllm_config.speculative_config)
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not self.use_sparse:
self._update_full_graph_params(forward_context,
num_input_tokens)
if self.enable_shared_expert_dp:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(