[v0.18.0][Bugfix] fix ds3.2 dcp mtp (#7681)

### What this PR does / why we need it?
Fixed an issue where DCP (decode context parallelism) conflicted with the
MTP (multi-token prediction) scenario in the DeepSeek v3.2 (ds3.2) setting.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

Cherry-picked from: https://github.com/vllm-project/vllm-ascend/pull/7617

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
weiguihua2
2026-03-27 14:24:53 +08:00
committed by GitHub
parent 048c8d1afe
commit bc8e87f3db
10 changed files with 18 additions and 7 deletions

View File

@@ -168,6 +168,7 @@ class AscendMetadata:
# should simplified these parameters once attention schema in vLLM-Ascend
# is unified.
seq_lens: torch.Tensor = None
seq_lens_cpu: torch.Tensor = None
seq_lens_list: list[int] = None # type: ignore
actual_seq_lengths_q: list[int] = None # type: ignore
@@ -307,6 +308,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
block_tables=block_table,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_query_len=common_attn_metadata.max_query_len,
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),

View File

@@ -232,6 +232,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
block_tables=block_table,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_query_len=common_attn_metadata.max_query_len,
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),

View File

@@ -257,8 +257,8 @@ class AscendSFACPImpl(AscendSFAImpl):
return self._align_to_graph_bucket_tokens(attn_output, attn_metadata)
def _align_to_graph_bucket_tokens(self, attn_output: torch.Tensor | None, attn_metadata: M) -> torch.Tensor | None:
if attn_output is None:
return None
if attn_output is None or self.pcp_size == 1:
return attn_output
# In graph/piecewise mode, output buffer uses graph bucket token size
# (forward_context.num_tokens), while PCP path may compute only valid
# tokens. Align to the larger one to avoid later write-back mismatch.

View File

@@ -174,6 +174,7 @@ class AscendMLAMetadata:
slot_mapping: torch.Tensor
query_start_loc: torch.Tensor
seq_lens: torch.Tensor
seq_lens_cpu: torch.Tensor
block_tables: torch.Tensor
# New for MLA (compared to FlashAttention)
@@ -457,6 +458,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
query_start_loc=query_start_loc,
block_tables=self.block_table,
seq_lens=self.seq_lens,
seq_lens_cpu=self.seq_lens,
)
def build_chunked_metadata(

View File

@@ -130,6 +130,7 @@ class AscendSFAMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
slot_mapping: torch.Tensor
seq_lens: torch.Tensor
seq_lens_cpu: torch.Tensor
cum_query_lens: torch.Tensor
block_table: torch.Tensor
sin: torch.Tensor
@@ -233,6 +234,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
cum_query_lens = common_attn_metadata.query_start_loc[1 : num_reqs + 1]
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_reqs]
cos, sin = get_cos_and_sin_mla(input_positions, True)
@@ -320,6 +322,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
num_actual_tokens=num_actual_tokens,
cum_query_lens=cum_query_lens,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
slot_mapping=slot_mapping,
head_dim=self.model_config.get_head_size(),
attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config),

View File

@@ -600,7 +600,7 @@ class SpecDecodeBaseProposer(EagleProposer):
- 1
)
num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens
ori_seq_len = attn_metadata_i.seq_lens[:batch_size].clone()
ori_seq_len = attn_metadata_i.seq_lens_cpu[:batch_size].clone()
mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad
# slot_mapping index base offset:
@@ -1247,7 +1247,8 @@ class SpecDecodeBaseProposer(EagleProposer):
if self.pcp_size * self.dcp_size > 1:
if self.vllm_config.model_config.use_mla:
attn_metadata.decode.cp_seq_len = cp_seq_len
if getattr(attn_metadata, "decode", None):
attn_metadata.decode.cp_seq_len = cp_seq_len
else:
attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp