From bc8e87f3db5c6f4642a797364e36977efa786e29 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Fri, 27 Mar 2026 14:24:53 +0800 Subject: [PATCH] [v0.18.0][Bugfix] fix ds3.2 dcp mtp (#7681) ### What this PR does / why we need it? Fixed the issue where the DCP overlaps the MTP scenario in the ds3.2 scenario. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? cherry-pick from: https://github.com/vllm-project/vllm-ascend/pull/7617 Signed-off-by: weiguihua2 --- .../e2e/nightly/multi_node/config/DeepSeek-V3_2-W8A8-cp.yaml | 4 ++-- tests/ut/attention/test_mla_v1.py | 2 +- tests/ut/attention/test_sfa_v1.py | 1 + tests/ut/compilation/test_acl_graph.py | 1 + vllm_ascend/attention/attention_v1.py | 2 ++ vllm_ascend/attention/context_parallel/attention_cp.py | 1 + vllm_ascend/attention/context_parallel/sfa_cp.py | 4 ++-- vllm_ascend/attention/mla_v1.py | 2 ++ vllm_ascend/attention/sfa_v1.py | 3 +++ vllm_ascend/spec_decode/eagle_proposer.py | 5 +++-- 10 files changed, 18 insertions(+), 7 deletions(-) diff --git a/tests/e2e/nightly/multi_node/config/DeepSeek-V3_2-W8A8-cp.yaml b/tests/e2e/nightly/multi_node/config/DeepSeek-V3_2-W8A8-cp.yaml index ad5f0476..77978d4a 100644 --- a/tests/e2e/nightly/multi_node/config/DeepSeek-V3_2-W8A8-cp.yaml +++ b/tests/e2e/nightly/multi_node/config/DeepSeek-V3_2-W8A8-cp.yaml @@ -36,7 +36,7 @@ deployment: --no-enable-prefix-caching --gpu-memory-utilization 0.85 --trust-remote-code - --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}' + --speculative-config '{"num_speculative_tokens": 3, "method":"deepseek_mtp"}' --compilation-config '{"cudagraph_capture_sizes": [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48], "cudagraph_mode": "FULL_DECODE_ONLY"}' --additional-config '{"layer_sharding": ["q_b_proj", "o_proj"]}' --tokenizer-mode deepseek_v32 @@ -62,7 +62,7 @@ deployment: --no-enable-prefix-caching --gpu-memory-utilization 0.85 --trust-remote-code - --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}' + --speculative-config '{"num_speculative_tokens": 3, "method":"deepseek_mtp"}' --compilation-config '{"cudagraph_capture_sizes": [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48], "cudagraph_mode": "FULL_DECODE_ONLY"}' --additional-config '{"layer_sharding": ["q_b_proj", "o_proj"]}' --tokenizer-mode deepseek_v32 diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 1cf661bd..7f457898 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -182,7 +182,7 @@ class TestAscendMLAMetadata(TestBase): metadata = AscendMLAMetadata( num_actual_tokens_pcp_padded, num_actual_tokens, slot_mapping, - query_start_loc, seq_lens, block_tables, num_decodes, + query_start_loc, seq_lens, seq_lens, block_tables, num_decodes, num_decode_tokens, num_prefills, num_input_tokens, query_lens, head_dim, attn_mask, attn_state, decode, prefill) diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py index a90f7325..48bcbd0a 100644 --- a/tests/ut/attention/test_sfa_v1.py +++ b/tests/ut/attention/test_sfa_v1.py @@ -58,6 +58,7 @@ class TestAscendSFAMetadata(TestBase): num_actual_tokens=num_actual_tokens, slot_mapping=slot_mapping, seq_lens=seq_lens, + seq_lens_cpu=seq_lens, cum_query_lens=cum_query_lens, block_table=block_table, sin=sin, diff --git a/tests/ut/compilation/test_acl_graph.py b/tests/ut/compilation/test_acl_graph.py index 7440fc41..828c6e39 100644 --- a/tests/ut/compilation/test_acl_graph.py +++ b/tests/ut/compilation/test_acl_graph.py @@ -803,6 +803,7 @@ class TestPCPDCPGraphParams(TestBase): slot_mapping, query_start_loc, seq_lens, + seq_lens, block_tables, 4, 4, diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 1cca7783..62ea6e8c 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -168,6 +168,7 @@ class AscendMetadata: # should simplified these parameters once attention schema in vLLM-Ascend # is unified. seq_lens: torch.Tensor = None + seq_lens_cpu: torch.Tensor = None seq_lens_list: list[int] = None # type: ignore actual_seq_lengths_q: list[int] = None # type: ignore @@ -307,6 +308,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): block_tables=block_table, query_start_loc=query_start_loc, seq_lens=seq_lens, + seq_lens_cpu=seq_lens, seq_lens_list=seq_lens.tolist(), max_query_len=common_attn_metadata.max_query_len, actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(), diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index f2d3961d..21168894 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -232,6 +232,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): block_tables=block_table, query_start_loc=query_start_loc, seq_lens=seq_lens, + seq_lens_cpu=seq_lens, seq_lens_list=seq_lens.tolist(), max_query_len=common_attn_metadata.max_query_len, actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(), diff --git a/vllm_ascend/attention/context_parallel/sfa_cp.py b/vllm_ascend/attention/context_parallel/sfa_cp.py index dbf6d163..74d88698 100644 --- a/vllm_ascend/attention/context_parallel/sfa_cp.py +++ b/vllm_ascend/attention/context_parallel/sfa_cp.py @@ -257,8 +257,8 @@ class AscendSFACPImpl(AscendSFAImpl): return self._align_to_graph_bucket_tokens(attn_output, attn_metadata) def _align_to_graph_bucket_tokens(self, attn_output: torch.Tensor | None, attn_metadata: M) -> torch.Tensor | None: - if attn_output is None: - return None + if attn_output is None or self.pcp_size == 1: + return attn_output # In graph/piecewise mode, output buffer uses graph bucket token size # (forward_context.num_tokens), while PCP path may compute only valid # tokens. Align to the larger one to avoid later write-back mismatch. diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 1b5a9457..c7c14394 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -174,6 +174,7 @@ class AscendMLAMetadata: slot_mapping: torch.Tensor query_start_loc: torch.Tensor seq_lens: torch.Tensor + seq_lens_cpu: torch.Tensor block_tables: torch.Tensor # New for MLA (compared to FlashAttention) @@ -457,6 +458,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): query_start_loc=query_start_loc, block_tables=self.block_table, seq_lens=self.seq_lens, + seq_lens_cpu=self.seq_lens, ) def build_chunked_metadata( diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 0aec273b..692566bf 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -130,6 +130,7 @@ class AscendSFAMetadata: num_actual_tokens: int # Number of tokens excluding padding. slot_mapping: torch.Tensor seq_lens: torch.Tensor + seq_lens_cpu: torch.Tensor cum_query_lens: torch.Tensor block_table: torch.Tensor sin: torch.Tensor @@ -233,6 +234,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): cum_query_lens = common_attn_metadata.query_start_loc[1 : num_reqs + 1] seq_lens = common_attn_metadata.seq_lens[:num_reqs] + seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_reqs] cos, sin = get_cos_and_sin_mla(input_positions, True) @@ -320,6 +322,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): num_actual_tokens=num_actual_tokens, cum_query_lens=cum_query_lens, seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, slot_mapping=slot_mapping, head_dim=self.model_config.get_head_size(), attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config), diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 4da0baff..20a82ebc 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -600,7 +600,7 @@ class SpecDecodeBaseProposer(EagleProposer): - 1 ) num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens - ori_seq_len = attn_metadata_i.seq_lens[:batch_size].clone() + ori_seq_len = attn_metadata_i.seq_lens_cpu[:batch_size].clone() mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad # slot_mapping index base offset: @@ -1247,7 +1247,8 @@ class SpecDecodeBaseProposer(EagleProposer): if self.pcp_size * self.dcp_size > 1: if self.vllm_config.model_config.use_mla: - attn_metadata.decode.cp_seq_len = cp_seq_len + if getattr(attn_metadata, "decode", None): + attn_metadata.decode.cp_seq_len = cp_seq_len else: attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp