[bugfix] pcp + mtp acl graph bugfix (#4221)

Fix pcp + mtp bug while using acl graph.
While using pcp + mtp, we need to flatten block_table to avoid irregular
attn mask shape, this was done in mla attn_metadata builder, but we
found out that this influences block_table address and leads to
incorrect results while enable acl graph.
To fix this, we enlarge block_table buffer size and flatten block_table
in model_runner prepare_inputs, so this will not influence block_table
address.

- vLLM version: v0.11.0
- vLLM main:
2918c1b49c

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
zhangsicheng5
2025-11-19 11:21:46 +08:00
committed by GitHub
parent 9328f377b4
commit df777e9faa
3 changed files with 69 additions and 24 deletions

View File

@@ -369,6 +369,12 @@ class AscendMLAMetadataBuilder:
device = self.device
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
if self.pcp_size > 1:
num_decodes_flatten = num_decodes * self.decode_threshold
block_table = common_attn_metadata.block_table_tensor[:
num_decodes_flatten
+
num_prefills]
if num_actual_tokens_pcp_padded is None:
num_actual_tokens_pcp_padded = num_actual_tokens
@@ -546,6 +552,9 @@ class AscendMLAMetadataBuilder:
cos=cos,
pcp_metadata=pcp_metadata,
)
if self.pcp_size > 1:
prefill_metadata.block_table = block_table[
num_decodes_flatten:, ...]
decode_metadata = None
if num_decodes > 0:
@@ -556,12 +565,12 @@ class AscendMLAMetadataBuilder:
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decodes]
input_positions = input_positions[:num_decode_tokens]
block_table = block_table[:num_decodes, ...]
# For pcp + spec decode, we flatten seq_lens and block_table
# to avoid irregular spec_attn_mask shape
if self.pcp_size > 1 and self.decode_threshold > 1:
block_table = block_table.repeat_interleave(
self.decode_threshold, dim=0)
if self.pcp_size > 1:
# For pcp + spec decode, we flatten seq_lens and block_table
# to avoid irregular spec_attn_mask shape
block_table = block_table[:num_decodes_flatten, ...]
else:
block_table = block_table[:num_decodes, ...]
seq_lens_list = seq_lens.tolist()
if num_computed_tokens_of_pcp_dcp is not None: