[Feat][main] Support using full-graph with Qwen3-Next-MTP (#5477)
### What this PR does / why we need it?
This PR adds support for using full-graph with Qwen3-Next-MTP.
Specifically, we adapted `AscendAttentionState.ChunkedPrefill` in the main
model and also in the MTP model.
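For reference, a minimal sketch of the state selection this PR introduces (it mirrors the logic in the diff below; the enum comes from `vllm_ascend.attention.attention_v1`, while the helper function here is purely illustrative and not part of the change):

```python
# Illustrative only: how the attention state is chosen for graph capture.
# `AscendAttentionState.SpecDecoding` is only designed for MLA models, so
# self-attention models such as Qwen3-Next fall back to `ChunkedPrefill`.
from vllm_ascend.attention.attention_v1 import AscendAttentionState


def select_graph_capture_attn_state(use_mla: bool) -> AscendAttentionState:
    if use_mla:
        return AscendAttentionState.SpecDecoding
    return AscendAttentionState.ChunkedPrefill
```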
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
We changed the Qwen3-Next-MTP test in
`tests/e2e/multicard/test_qwen3_next.py` into a `FULL_DECODE_ONLY` test,
then ran `pytest -s
tests/e2e/multicard/test_qwen3_next.py::test_qwen3_next_distributed_mp_eager_mtp_similarity_tp4`,
which passed.
```text
.
================================================================================================================================= warnings summary =================================================================================================================================
<frozen importlib._bootstrap>:241
<frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute
<frozen importlib._bootstrap>:241
<frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================================================================================================================== 1 passed, 2 warnings in 271.89s (0:04:31) =====================================================================================================================
```
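For comparison, a rough sketch of how the equivalent configuration might be expressed through the standard vLLM entry point. The e2e test passes these options through its own runner helper, and placing `cudagraph_mode` under `compilation_config` here is an assumption for illustration, not a copy of the test code:

```python
# Hypothetical usage sketch (not the exact e2e test code): enable Qwen3-Next
# MTP speculative decoding together with a decode-only full graph.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",
    tensor_parallel_size=4,
    distributed_executor_backend="mp",
    speculative_config={
        "method": "qwen3_next_mtp",
        "num_speculative_tokens": 3,
    },
    # Assumption: the decode-only full graph is requested via compilation_config.
    compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
)
```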
- vLLM version: v0.13.0
- vLLM main:
5326c89803
Signed-off-by: drslark <slarksblood@qq.com>
```diff
@@ -34,7 +34,6 @@ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
 MODELS = ["Qwen/Qwen3-Next-80B-A3B-Instruct"]
 
 
-# TODO: add full decode only (when ready)
 @pytest.mark.parametrize("model_name", MODELS)
 def test_qwen3_next_mtp_acceptance_tp4(model_name):
     golden = [0.85, 0.46, 0.19]
@@ -55,6 +54,7 @@ def test_qwen3_next_mtp_acceptance_tp4(model_name):
         distributed_executor_backend="mp",
         disable_log_stats=False,
         speculative_config={
+            "cudagraph_mode": "FULL_DECODE_ONLY",
             "method": "qwen3_next_mtp",
             "num_speculative_tokens": 3,
         },
@@ -88,6 +88,8 @@ def test_qwen3_next_mtp_acceptance_tp4(model_name):
     cleanup_dist_env_and_memory()
 
 
+# FIXME: When applying `FULL_DECODE_ONLY` in this e2e, ci will fail.
+# The failure can not be reproduced locally.
 @pytest.mark.parametrize("model_name", MODELS)
 @pytest.mark.parametrize("num_speculative_tokens", [1])
 @pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
@@ -293,7 +293,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
             )
         else:
             raise NotImplementedError(
-                "Currently we only support building dummy metadata for DecodeOnly state"
+                "Currently we only support building dummy metadata for DecodeOnly and ChunkedPrefill state"
             )
 
         attn_metadata.attn_state = attn_state
@@ -30,6 +30,8 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
+                                               update_attn_dcp_pcp_params,
+                                               update_attn_params,
                                                update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
@@ -66,6 +68,24 @@ class MtpProposer(EagleProposer):
     # TODO: Find out why ModelRunner does not this explicit typing?
     model: Union[nn.Module, ACLGraphWrapper]
 
+    # update full-graph params for one spec token
+    def _update_full_graph_params(self, forward_context, num_tokens):
+        if self.vllm_config.model_config.use_mla:
+            if self.pcp_size * self.dcp_size > 1:
+                update_mla_attn_dcp_pcp_params(self.update_stream,
+                                               forward_context, num_tokens)
+            else:
+                update_mla_attn_params(self.update_stream, forward_context,
+                                       num_tokens,
+                                       self.vllm_config.speculative_config)
+        else:
+            if self.pcp_size * self.dcp_size > 1:
+                update_attn_dcp_pcp_params(self.update_stream, forward_context,
+                                           num_tokens)
+            else:
+                update_attn_params(self.update_stream, forward_context,
+                                   num_tokens, self.vllm_config)
+
     def load_model(self, model) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
 
@@ -141,7 +161,7 @@ class MtpProposer(EagleProposer):
             num_tokens_across_dp,
             with_prefill,
         ) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill)
-        if self.use_async_scheduling:
+        if not self.use_cuda_graph:
             # there is synchronization between mtp steps when enabling aclgraph,
             # disable aclgraph when use async scheduling to avoid the
             # synchronization overhead.
@@ -185,8 +205,10 @@ class MtpProposer(EagleProposer):
             :num_reqs * self.decode_threshold]
 
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        # `AscendAttentionState.SpecDecoding` is only designed for mla, `AscendAttentionState.ChunkedPrefill` is used in self-attention.
+        attn_state = AscendAttentionState.SpecDecoding if self.vllm_config.model_config.use_mla else AscendAttentionState.ChunkedPrefill
         attn_metadata_mtp = builder.build_for_graph_capture(
-            common_attn_metadata, AscendAttentionState.SpecDecoding)
+            common_attn_metadata, attn_state)
         attn_metadata = {}
         for layer_name in self.attn_layer_name:
             attn_metadata[layer_name] = attn_metadata_mtp
@@ -222,17 +244,9 @@ class MtpProposer(EagleProposer):
                 hidden_states=previous_hidden_states)
         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
-                not forward_context.capturing:
-            if self.vllm_config.model_config.use_mla and not self.use_sparse:
-                if self.pcp_size * self.dcp_size > 1:
-                    update_mla_attn_dcp_pcp_params(
-                        self.update_stream, forward_context,
-                        num_tokens)
-                else:
-                    update_mla_attn_params(
-                        self.update_stream, forward_context,
-                        num_tokens,
-                        self.vllm_config.speculative_config)
+                not forward_context.capturing and not self.use_sparse:
+            self._update_full_graph_params(forward_context, num_tokens)
 
         if self.enable_shared_expert_dp:
             positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 positions, True)
@@ -654,7 +668,7 @@ class MtpProposer(EagleProposer):
         has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
         aclgraph_runtime_mode, batch_descriptor = \
             self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
-        if self.use_async_scheduling:
+        if not self.use_cuda_graph:
             # there is synchronization between mtp steps when enabling aclgraph,
             # disable aclgraph when use async scheduling to avoid the
             # synchronization overhead.
@@ -721,17 +735,9 @@ class MtpProposer(EagleProposer):
             positions=positions,
             hidden_states=hidden_states)
         forward_context = get_forward_context()
-        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            if self.vllm_config.model_config.use_mla and not self.use_sparse:
-                if self.pcp_size * self.dcp_size > 1:
-                    update_mla_attn_dcp_pcp_params(
-                        self.update_stream, forward_context,
-                        num_input_tokens)
-                else:
-                    update_mla_attn_params(
-                        self.update_stream, forward_context,
-                        num_input_tokens,
-                        self.vllm_config.speculative_config)
+        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not self.use_sparse:
+            self._update_full_graph_params(forward_context,
+                                           num_input_tokens)
 
         if self.enable_shared_expert_dp:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
@@ -894,9 +894,6 @@ class NPUModelRunner(GPUModelRunner):
         self.logits_indices = logits_indices
 
         # Used in the below loop.
-        # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
-        num_computed_tokens_cpu = (
-            self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
         self.spec_decode_common_attn_metadata = None
         if use_spec_decode and self.need_accepted_tokens:
             self.num_accepted_tokens.np[:num_reqs] = (
@@ -991,7 +988,8 @@ class NPUModelRunner(GPUModelRunner):
             # TODO: change this to the right block table for linear attn
             block_table_tensor=blk_table_tensor[:num_reqs],
             slot_mapping=slot_mapping,
-            num_computed_tokens_cpu=num_computed_tokens_cpu,
+            num_computed_tokens_cpu=self.input_batch.
+            num_computed_tokens_cpu_tensor[:num_reqs],
             positions=self.positions.gpu,
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
@@ -1822,7 +1820,11 @@ class NPUModelRunner(GPUModelRunner):
         attn_state = AscendAttentionState.DecodeOnly
         if self.speculative_config and \
                 self.speculative_config.method == "mtp":
-            attn_state = AscendAttentionState.SpecDecoding
+            # `AscendAttentionState.SpecDecoding` is only designed for mla
+            if self.vllm_config.model_config.use_mla:
+                attn_state = AscendAttentionState.SpecDecoding
+            else:
+                attn_state = AscendAttentionState.ChunkedPrefill
 
         common_metadata = CommonAttentionMetadata(
             query_start_loc=self.query_start_loc.gpu[:num_reqs + 1],
```