[bugfix] pcp + mtp acl graph bugfix (#4221)

Fix a pcp + mtp bug triggered when using ACL graph.
When using pcp + mtp, block_table must be flattened to avoid an irregular
attn mask shape. This was previously done in the MLA attn_metadata builder,
but we found that it changes the block_table address and produces
incorrect results when ACL graph is enabled.
To fix this, we enlarge the block_table buffer and flatten block_table in
model_runner prepare_inputs instead, so the block_table address is left
unchanged.
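As a standalone illustration of the flattening (not the actual vllm-ascend
code; tensor names and sizes are made up), each decode row is duplicated
decode_threshold times while the prefill rows keep their order:

import torch

# Sketch only: one block id per row for brevity.
block_table = torch.tensor([[0], [1], [2], [3], [4]])  # rows: [d0, d1, p0, p1, p2]
decode_threshold = 2  # 1 speculative token -> up to 2 query tokens per decode req
num_decode_reqs, num_prefill_reqs = 2, 3

# Duplicate each decode row decode_threshold times; prefill rows stay as-is.
flattened = torch.cat([
    block_table[:num_decode_reqs].repeat_interleave(decode_threshold, dim=0),
    block_table[num_decode_reqs:],
])
print(flattened.squeeze(-1).tolist())  # [0, 0, 1, 1, 2, 3, 4] = [d0, d0, d1, d1, p0, p1, p2]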

- vLLM version: v0.11.0
- vLLM main: 2918c1b49c

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>

@@ -596,6 +596,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                self.is_pooling_model,
                self.vllm_config.model_config.logits_processors),
            is_pooling_model=self.is_pooling_model,
            num_speculative_tokens=(
                self.vllm_config.speculative_config.num_speculative_tokens
                if self.vllm_config.speculative_config else 0),
            kernel_block_sizes=[[self.vllm_config.cache_config.block_size]],
            cp_kv_cache_interleave_size=self.parallel_config.
            cp_kv_cache_interleave_size,
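The num_speculative_tokens argument added above is what lets the input batch
reserve the enlarged block_table buffer. A minimal sketch of the sizing idea,
assuming (as the hunks below suggest) decode_threshold = 1 +
num_speculative_tokens; the helper name is hypothetical, not the real
vllm-ascend API:

def block_table_buffer_rows(max_num_reqs: int, num_speculative_tokens: int) -> int:
    # Hypothetical helper: reserve enough rows for the worst case where every
    # request is a decode request and each decode row is expanded
    # decode_threshold times during flattening.
    decode_threshold = 1 + num_speculative_tokens
    return max_num_reqs * decode_threshold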
@@ -1922,6 +1925,31 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            prefill_context_parallel_metadata=long_seq_metadata,
        )
        if self.speculative_config and self.pcp_size > 1:
            # For pcp + spec decode, we flatten block_table
            # to avoid irregular spec_attn_mask shape, e.g.,
            # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
            # ori block_table: [d0, d1, p0, p1, p2]
            # (num_reqs_d + num_reqs_p, max_num_blocks),
            # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
            # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
            ori_query_lens = self.query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
                self.query_start_loc_pcp_full_cpu[:num_reqs]
            num_prefill_reqs = (ori_query_lens
                                > self.decode_threshold).sum().item()
            num_decode_reqs = num_reqs - num_prefill_reqs
            num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold
            blk_table_tensor[
                num_decode_reqs_flatten:num_decode_reqs_flatten +
                num_prefill_reqs].copy_(
                    blk_table_tensor[num_decode_reqs:num_decode_reqs +
                                     num_prefill_reqs].clone())
            blk_table_tensor[:num_decode_reqs_flatten].copy_(
                blk_table_tensor[:num_decode_reqs].repeat_interleave(
                    self.decode_threshold, dim=0))
            common_attn_metadata.block_table_tensor = \
                blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs]
        if self.speculative_config and \
                self.spec_decode_common_attn_metadata is None:
            self.spec_decode_common_attn_metadata = common_attn_metadata
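Note the ordering in the hunk above: prefill rows are moved to their shifted
positions before the decode rows are expanded, so the expansion cannot
overwrite prefill data, and the .clone() guards against the overlapping
source/destination ranges of copy_. A standalone repro of the same in-place
sequence (illustrative sizes, not the actual vllm-ascend code):

import torch

# Buffer already enlarged to hold the flattened rows.
decode_threshold, num_decode_reqs, num_prefill_reqs = 2, 2, 3
buf = torch.full((num_decode_reqs * decode_threshold + num_prefill_reqs, 1), -1)
buf[:5] = torch.tensor([[0], [1], [2], [3], [4]])  # rows: [d0, d1, p0, p1, p2]

num_decode_reqs_flatten = num_decode_reqs * decode_threshold
# Step 1: shift prefill rows right; clone() because source and destination overlap.
buf[num_decode_reqs_flatten:num_decode_reqs_flatten + num_prefill_reqs].copy_(
    buf[num_decode_reqs:num_decode_reqs + num_prefill_reqs].clone())
# Step 2: expand each decode row decode_threshold times, in place.
buf[:num_decode_reqs_flatten].copy_(
    buf[:num_decode_reqs].repeat_interleave(decode_threshold, dim=0))
print(buf.squeeze(-1).tolist())  # [0, 0, 1, 1, 2, 3, 4] = [d0, d0, d1, d1, p0, p1, p2]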
@@ -2831,6 +2859,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            sin=self.sin,
            prefill_context_parallel_metadata=long_seq_metadata,
        )
        if self.pcp_size > 1:
            common_attn_metadata.block_table_tensor = \
                block_table_tensor[:num_reqs * self.decode_threshold]
        attn_state = AscendAttentionState.DecodeOnly
        if self.speculative_config and \
                self.speculative_config.method == "deepseek_mtp":