diff --git a/tests/e2e/multicard/spec_decode/test_mtp_qwen3_next.py b/tests/e2e/multicard/spec_decode/test_mtp_qwen3_next.py
index 7ea0ca94..816c25f0 100644
--- a/tests/e2e/multicard/spec_decode/test_mtp_qwen3_next.py
+++ b/tests/e2e/multicard/spec_decode/test_mtp_qwen3_next.py
@@ -34,7 +34,6 @@ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
 MODELS = ["Qwen/Qwen3-Next-80B-A3B-Instruct"]
 
-# TODO: add full decode only (when ready)
 @pytest.mark.parametrize("model_name", MODELS)
 def test_qwen3_next_mtp_acceptance_tp4(model_name):
     golden = [0.85, 0.46, 0.19]
@@ -55,6 +54,7 @@ def test_qwen3_next_mtp_acceptance_tp4(model_name):
         distributed_executor_backend="mp",
         disable_log_stats=False,
         speculative_config={
+            "cudagraph_mode": "FULL_DECODE_ONLY",
            "method": "qwen3_next_mtp",
            "num_speculative_tokens": 3,
         },
@@ -88,6 +88,8 @@ def test_qwen3_next_mtp_acceptance_tp4(model_name):
     cleanup_dist_env_and_memory()
 
 
+# FIXME: When `FULL_DECODE_ONLY` is applied in this e2e test, CI fails.
+# The failure cannot be reproduced locally.
 @pytest.mark.parametrize("model_name", MODELS)
 @pytest.mark.parametrize("num_speculative_tokens", [1])
 @pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 1405ed9f..54999de2 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -293,7 +293,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
             )
         else:
             raise NotImplementedError(
-                "Currently we only support building dummy metadata for DecodeOnly state"
+                "Currently we only support building dummy metadata for DecodeOnly and ChunkedPrefill states"
             )
 
         attn_metadata.attn_state = attn_state
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 51bc6325..08990671 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -30,6 +30,8 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
+                                               update_attn_dcp_pcp_params,
+                                               update_attn_params,
                                                update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
@@ -66,6 +68,24 @@ class MtpProposer(EagleProposer):
     # TODO: Find out why ModelRunner does not this explicit typing?
     model: Union[nn.Module, ACLGraphWrapper]
 
+    # Update full-graph attention params for a single speculative decode step.
+    def _update_full_graph_params(self, forward_context, num_tokens):
+        if self.vllm_config.model_config.use_mla:
+            if self.pcp_size * self.dcp_size > 1:
+                update_mla_attn_dcp_pcp_params(self.update_stream,
+                                               forward_context, num_tokens)
+            else:
+                update_mla_attn_params(self.update_stream, forward_context,
+                                       num_tokens,
+                                       self.vllm_config.speculative_config)
+        else:
+            if self.pcp_size * self.dcp_size > 1:
+                update_attn_dcp_pcp_params(self.update_stream, forward_context,
+                                           num_tokens)
+            else:
+                update_attn_params(self.update_stream, forward_context,
+                                   num_tokens, self.vllm_config)
+
     def load_model(self, model) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
@@ -141,7 +161,7 @@ class MtpProposer(EagleProposer):
                 num_tokens_across_dp,
                 with_prefill,
             ) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill)
-            if self.use_async_scheduling:
+            if not self.use_cuda_graph:
                 # there is synchronization between mtp steps when enabling aclgraph,
                 # disable aclgraph when use async scheduling to avoid the
                 # synchronization overhead.
@@ -185,8 +205,10 @@ class MtpProposer(EagleProposer):
                                                   :num_reqs * self.decode_threshold]
             builder = self.runner.attn_groups[0][0].get_metadata_builder()
+            # `AscendAttentionState.SpecDecoding` is designed only for MLA; `AscendAttentionState.ChunkedPrefill` is used for self-attention.
+            attn_state = AscendAttentionState.SpecDecoding if self.vllm_config.model_config.use_mla else AscendAttentionState.ChunkedPrefill
             attn_metadata_mtp = builder.build_for_graph_capture(
-                common_attn_metadata, AscendAttentionState.SpecDecoding)
+                common_attn_metadata, attn_state)
             attn_metadata = {}
             for layer_name in self.attn_layer_name:
                 attn_metadata[layer_name] = attn_metadata_mtp
@@ -222,17 +244,9 @@ class MtpProposer(EagleProposer):
                             hidden_states=previous_hidden_states)
                     forward_context = get_forward_context()
                     if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
-                        not forward_context.capturing:
-                        if self.vllm_config.model_config.use_mla and not self.use_sparse:
-                            if self.pcp_size * self.dcp_size > 1:
-                                update_mla_attn_dcp_pcp_params(
-                                    self.update_stream, forward_context,
-                                    num_tokens)
-                            else:
-                                update_mla_attn_params(
-                                    self.update_stream, forward_context,
-                                    num_tokens,
-                                    self.vllm_config.speculative_config)
+                        not forward_context.capturing and not self.use_sparse:
+                        self._update_full_graph_params(forward_context, num_tokens)
+
                     if self.enable_shared_expert_dp:
                         positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                             positions, True)
@@ -654,7 +668,7 @@ class MtpProposer(EagleProposer):
         has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
         aclgraph_runtime_mode, batch_descriptor = \
             self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
-        if self.use_async_scheduling:
+        if not self.use_cuda_graph:
             # there is synchronization between mtp steps when enabling aclgraph,
             # disable aclgraph when use async scheduling to avoid the
             # synchronization overhead.
@@ -721,17 +735,9 @@ class MtpProposer(EagleProposer):
                         positions=positions,
                         hidden_states=hidden_states)
                 forward_context = get_forward_context()
-                if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-                    if self.vllm_config.model_config.use_mla and not self.use_sparse:
-                        if self.pcp_size * self.dcp_size > 1:
-                            update_mla_attn_dcp_pcp_params(
-                                self.update_stream, forward_context,
-                                num_input_tokens)
-                        else:
-                            update_mla_attn_params(
-                                self.update_stream, forward_context,
-                                num_input_tokens,
-                                self.vllm_config.speculative_config)
+                if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not self.use_sparse:
+                    self._update_full_graph_params(forward_context,
+                                                   num_input_tokens)
 
                 if self.enable_shared_expert_dp:
                     hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index b4b7e3a0..2b21955e 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -894,9 +894,6 @@ class NPUModelRunner(GPUModelRunner):
         self.logits_indices = logits_indices
 
         # Used in the below loop.
-        # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
-        num_computed_tokens_cpu = (
-            self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
         self.spec_decode_common_attn_metadata = None
         if use_spec_decode and self.need_accepted_tokens:
             self.num_accepted_tokens.np[:num_reqs] = (
@@ -991,7 +988,8 @@ class NPUModelRunner(GPUModelRunner):
                 # TODO: change this to the right block table for linear attn
                 block_table_tensor=blk_table_tensor[:num_reqs],
                 slot_mapping=slot_mapping,
-                num_computed_tokens_cpu=num_computed_tokens_cpu,
+                num_computed_tokens_cpu=self.input_batch.
+                num_computed_tokens_cpu_tensor[:num_reqs],
                 positions=self.positions.gpu,
                 attn_mask=self.attn_mask,
                 spec_attn_mask=self.spec_attn_mask,
@@ -1822,7 +1820,11 @@ class NPUModelRunner(GPUModelRunner):
         attn_state = AscendAttentionState.DecodeOnly
         if self.speculative_config and \
             self.speculative_config.method == "mtp":
-            attn_state = AscendAttentionState.SpecDecoding
+            # `AscendAttentionState.SpecDecoding` is designed only for MLA.
+            if self.vllm_config.model_config.use_mla:
+                attn_state = AscendAttentionState.SpecDecoding
+            else:
+                attn_state = AscendAttentionState.ChunkedPrefill
         common_metadata = CommonAttentionMetadata(
             query_start_loc=self.query_start_loc.gpu[:num_reqs + 1],
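For reviewers skimming the new `_update_full_graph_params` helper: it folds the duplicated MLA-only update logic from both call sites into a single four-way dispatch over `use_mla` and the combined context-parallel size, and extends it to the non-MLA path. Below is a minimal standalone sketch of just that dispatch; `Cfg` and `pick_update_fn` are hypothetical stand-ins introduced for illustration, and only the branch structure mirrors the patch, which calls the four `update_*` helpers imported from `vllm_ascend.compilation.acl_graph`.

```python
# Illustrative sketch only: `Cfg` and `pick_update_fn` are not part of the
# patch; the real code calls the selected helper directly with
# (update_stream, forward_context, num_tokens, ...).
from dataclasses import dataclass


@dataclass
class Cfg:
    use_mla: bool  # MLA models take the mla-specific update helpers
    pcp_size: int  # prefill context-parallel world size
    dcp_size: int  # decode context-parallel world size


def pick_update_fn(cfg: Cfg) -> str:
    # pcp_size * dcp_size > 1 means some form of context parallelism is
    # active, so the dcp/pcp-aware variant must update the graph params.
    parallel = cfg.pcp_size * cfg.dcp_size > 1
    if cfg.use_mla:
        return ("update_mla_attn_dcp_pcp_params"
                if parallel else "update_mla_attn_params")
    return "update_attn_dcp_pcp_params" if parallel else "update_attn_params"


# The non-MLA branches are what this patch adds; the MLA ones existed before.
assert pick_update_fn(Cfg(True, 1, 1)) == "update_mla_attn_params"
assert pick_update_fn(Cfg(False, 2, 1)) == "update_attn_dcp_pcp_params"
```

Centralizing the dispatch also means both call sites now share the same sparse-attention guard: the callers check `not self.use_sparse` once, instead of burying it inside the MLA-only branch as before.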