[bugfix] Fix accuracy issue in PCP/DCP with speculative decoding (#6491)

### What this PR does / why we need it?

This PR fixes an accuracy issue that occurs when using Prefill/Decode
Context Parallelism (PCP/DCP) in conjunction with speculative decoding
(MTP). The issue is caused by an irregular attention mask shape when
both features are enabled.

The fix flattens the `block_table` for speculative decoding requests
under PCP/DCP so that the attention mask keeps a regular shape (see the
sketch below). This PR also introduces a `use_cp` property for cleaner
code and updates the dummy runs to handle this scenario correctly.
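
As a minimal sketch of the idea (simplified, standalone code, not the
actual vllm-ascend implementation): each decode request's block-table
row is duplicated once per query token (1 + num_speculative_tokens =
decode_threshold), so every row of the flattened table maps to exactly
one query token and the attention mask stays rectangular.

```python
# Minimal sketch of the block_table flattening; names are simplified
# illustrations, not the actual vllm-ascend API.
import torch

def flatten_block_table(block_table: torch.Tensor,
                        query_lens: torch.Tensor,
                        num_decode_reqs: int) -> torch.Tensor:
    # Repeat each decode request's row once per query token; prefill
    # rows (one query chunk each) are kept as-is.
    decode_rows = block_table[:num_decode_reqs].repeat_interleave(
        query_lens[:num_decode_reqs], dim=0)
    prefill_rows = block_table[num_decode_reqs:]
    return torch.cat([decode_rows, prefill_rows], dim=0)

# 2 decode requests with 2 query tokens each (1 speculative token)
# and 3 prefill requests, max_num_blocks = 4:
block_table = torch.arange(5 * 4).reshape(5, 4)
query_lens = torch.tensor([2, 2, 1, 1, 1])
flat = flatten_block_table(block_table, query_lens, num_decode_reqs=2)
print(flat.shape)  # torch.Size([7, 4]): [d0, d0, d1, d1, p0, p1, p2]
```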

### Does this PR introduce _any_ user-facing change?

No. This is a bug fix that improves accuracy and should not have
user-facing API changes.

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Author: Wang Kunpeng
Date: 2026-02-05 10:06:14 +08:00
Committed by: GitHub
Parent: 0ead5e8681
Commit: 13c4a9c78b
3 changed files with 66 additions and 13 deletions


```diff
@@ -24,6 +24,8 @@ import torch
 from vllm.config import VllmConfig
 from vllm.v1.utils import CpuGpuBuffer
+from vllm_ascend.worker.npu_input_batch import NPUInputBatch
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
```
```diff
@@ -514,13 +516,23 @@ class PCPManager:
         dcp_local_seq_lens = (base + remainder).reshape([-1, pcp_world_size, dcp_world_size])
         return dcp_local_seq_lens
 
-    def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_batch, num_scheduled_tokens):
+    def generate_pcp_metadata(
+        self,
+        total_num_scheduled_tokens: int,
+        query_lens: torch.Tensor,
+        input_batch: "NPUInputBatch",
+        num_scheduled_tokens: np.ndarray | None,
+        block_table_tensor: torch.Tensor,
+        num_reqs_padded: int,
+        num_reqs: int,
+    ):
         from vllm_ascend.attention.utils import AscendPrefillContextParallelMetadata
 
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
         if self.pcp_world_size * self.dcp_world_size > 1:
+            assert num_scheduled_tokens is not None
             decode_context_lens = (
                 input_batch.num_computed_tokens_cpu[: self.num_decode_reqs]
                 + num_scheduled_tokens[: self.num_decode_reqs]
```
```diff
@@ -544,6 +556,7 @@ class PCPManager:
                     self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
                 )
             )
+            ori_query_lens_cpu = None
             if self.decode_threshold > 1:
                 num_computed_tokens_of_pcp_dcp_list = []
                 if self.num_decode_reqs:
```
```diff
@@ -563,10 +576,37 @@ class PCPManager:
                         ]
                     )
                 num_computed_tokens_of_pcp_dcp = torch.cat(num_computed_tokens_of_pcp_dcp_list, dim=0)
+                # For pcp + spec decode, we flatten block_table
+                # to avoid irregular attn_mask shape, e.g.,
+                # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
+                # ori block_table: [d0, d1, p0, p1, p2]
+                #   (num_reqs_d + num_reqs_p, max_num_blocks),
+                # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
+                #   (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
+                ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs_padded]
+                ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs_padded]
+                num_prefill_reqs = self.num_prefill_reqs
+                num_decode_reqs = self.num_decode_reqs
+                num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
+                block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
+                    block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
+                )
+                block_table_tensor[:num_decode_reqs_flatten].copy_(
+                    block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
+                )
+                block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
+                if num_reqs_padded > num_reqs:
+                    pad_size = num_reqs_padded - num_reqs
+                    ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
             long_seq_metadata = AscendPrefillContextParallelMetadata(
                 num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
                 num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(),
             )
+            if ori_query_lens_cpu is not None:
+                long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
+                long_seq_metadata.max_query_len_pcp_full = ori_query_lens_cpu.max().item()
             if self.pcp_world_size > 1:
                 q_head_idx, q_tail_idx = [], []
                 kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
```
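
One detail worth calling out in the hunk above: the updates happen in
place on a padded `block_table_tensor`, so the prefill rows must be
relocated to their post-flatten offset before the decode rows are
expanded, and the first copy needs `.clone()` because its source and
destination slices can overlap. A hypothetical standalone repro of
that ordering (assumed toy shapes, not the actual code):

```python
# Toy repro of the in-place flattening order; shapes are assumptions.
import torch

num_decode_reqs, num_prefill_reqs = 2, 3
decode_threshold, max_num_blocks = 2, 4

# Table padded with room for the flattened decode rows.
bt = torch.zeros(num_decode_reqs * decode_threshold + num_prefill_reqs,
                 max_num_blocks, dtype=torch.int64)
bt[:num_decode_reqs + num_prefill_reqs] = torch.arange(20).reshape(5, 4)
query_lens = torch.tensor([2, 2, 1, 1, 1])

flat_decode = int(query_lens[:num_decode_reqs].sum())  # 4
# Step 1: relocate prefill rows first; .clone() protects against the
# overlap between source rows [2:5] and destination rows [4:7].
bt[flat_decode:flat_decode + num_prefill_reqs].copy_(
    bt[num_decode_reqs:num_decode_reqs + num_prefill_reqs].clone())
# Step 2: expand each decode row once per query token.
bt[:flat_decode].copy_(
    bt[:num_decode_reqs].repeat_interleave(query_lens[:num_decode_reqs],
                                           dim=0))
bt = bt[:flat_decode + num_prefill_reqs]  # rows: d0, d0, d1, d1, p0, p1, p2
```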
```diff
@@ -685,8 +725,9 @@ class PCPManager:
             long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
             long_seq_metadata.head_attn_nomask_seqlens = head_attn_nomask_seqlens_list
             long_seq_metadata.tail_attn_nomask_seqlens = tail_attn_nomask_seqlens_list
-        return long_seq_metadata
+        self.long_seq_metadata = long_seq_metadata
+        return long_seq_metadata, block_table_tensor
 
     def _list_to_tensor(self, lst, device, dtype=torch.int32):
         tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device)
```
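
Since `generate_pcp_metadata` now returns the (possibly flattened and
truncated) block table alongside the metadata, call sites have to
unpack both values. A hypothetical call site, with argument names
following the new signature above (the actual model-runner wiring is
in the other changed files, not shown here):

```python
# Hypothetical call site; not the actual vllm-ascend model-runner code.
long_seq_metadata, block_table_tensor = pcp_manager.generate_pcp_metadata(
    total_num_scheduled_tokens,
    query_lens,
    input_batch,
    num_scheduled_tokens,
    block_table_tensor,  # may come back flattened for spec decode
    num_reqs_padded,
    num_reqs,
)
```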