[bugfix] Fix accuracy issue in PCP/DCP with speculative decoding (#6491)

### What this PR does / why we need it?

This PR fixes an accuracy issue that occurs when using Prefill/Decode
Context Parallelism (PCP/DCP) in conjunction with speculative decoding
(MTP). The issue is caused by an irregular attention mask shape when
both features are enabled.

The fix involves flattening the `block_table` for speculative decoding
requests under PCP/DCP to ensure a regular attention mask. This PR also
introduces a `use_cp` property for cleaner code and updates dummy runs
to handle this scenario correctly.

### Does this PR introduce _any_ user-facing change?

No. This is a bug fix that improves accuracy and should not have
user-facing API changes.

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
Wang Kunpeng
2026-02-05 10:06:14 +08:00
committed by GitHub
parent 0ead5e8681
commit 13c4a9c78b
3 changed files with 66 additions and 13 deletions

View File

@@ -73,9 +73,12 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
query_lens) - input_batch.num_computed_tokens_cpu
query_lens = torch.tensor(query_lens)
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
result, _ = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
input_batch,
num_scheduled_tokens)
num_scheduled_tokens,
torch.tensor([]),
num_reqs_padded=num_reqs,
num_reqs=num_reqs)
if not expect_not_none:
assert result is None, f"Expected to return None, but got {type(result)}"

View File

@@ -571,14 +571,14 @@ class NPUModelRunner(GPUModelRunner):
self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
if self.pcp_size * self.dcp_size > 1:
if self.use_cp:
self.pcp_manager.init_batch_info(
num_scheduled_tokens,
self.input_batch.num_reqs,
)
# for pcp, prefill mtp should use the original SchedulerOutput,
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
if self.speculative_config and self.use_cp:
self.pcp_manager.generate_pcp_mtp_input(
total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens,
@@ -732,7 +732,7 @@ class NPUModelRunner(GPUModelRunner):
spec_decode_metadata = None
num_draft_tokens = None
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
if self.pcp_size * self.dcp_size > 1:
if self.use_cp:
logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens)
logits_indices = logits_indices.pin_memory().to(self.device, non_blocking=True)
else:
@@ -954,7 +954,7 @@ class NPUModelRunner(GPUModelRunner):
self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.pcp_size * self.dcp_size > 1:
if self.use_cp:
long_seq_metadata = self.long_seq_metadata # type: ignore
input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu
query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu
@@ -1838,11 +1838,17 @@ class NPUModelRunner(GPUModelRunner):
kv_cache_groups = self.kv_cache_config.kv_cache_groups
def _get_pcp_metadata(num_tokens):
def _get_pcp_metadata(block_table_tensor):
if not self.use_cp:
return None
return None, block_table_tensor
return self.pcp_manager.generate_pcp_metadata(
num_tokens, self.query_lens, self.input_batch, num_scheduled_tokens_np
num_tokens,
self.query_lens,
self.input_batch,
num_scheduled_tokens_np,
block_table_tensor,
num_reqs_padded,
num_reqs,
)
def _get_block_table_and_slot_mapping(kv_cache_gid: int):
@@ -1883,8 +1889,8 @@ class NPUModelRunner(GPUModelRunner):
)
return blk_table_tensor, slot_mapping
self.long_seq_metadata = _get_pcp_metadata(num_tokens)
block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
self.long_seq_metadata, block_table_gid_0 = _get_pcp_metadata(block_table_gid_0)
cm_base = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
@@ -2080,11 +2086,14 @@ class NPUModelRunner(GPUModelRunner):
# LoRA state when determining the batch descriptor for capture
force_has_lora=activate_lora,
)
if self.pcp_size * self.dcp_size > 1:
if self.use_cp:
self.pcp_manager.init_batch_info(
num_scheduled_tokens,
num_reqs,
)
if self.speculative_config:
self.pcp_manager.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(num_scheduled_tokens)
self.pcp_manager.query_lens_pcp_full.copy_to_gpu()
if cudagraph_runtime_mode is None:
cudagraph_runtime_mode = _cudagraph_mode
else:

View File

@@ -24,6 +24,8 @@ import torch
from vllm.config import VllmConfig
from vllm.v1.utils import CpuGpuBuffer
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -514,13 +516,23 @@ class PCPManager:
dcp_local_seq_lens = (base + remainder).reshape([-1, pcp_world_size, dcp_world_size])
return dcp_local_seq_lens
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_batch, num_scheduled_tokens):
def generate_pcp_metadata(
self,
total_num_scheduled_tokens: int,
query_lens: torch.Tensor,
input_batch: "NPUInputBatch",
num_scheduled_tokens: np.ndarray | None,
block_table_tensor: torch.Tensor,
num_reqs_padded: int,
num_reqs: int,
):
from vllm_ascend.attention.utils import AscendPrefillContextParallelMetadata
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
long_seq_metadata = None
if self.pcp_world_size * self.dcp_world_size > 1:
assert num_scheduled_tokens is not None
decode_context_lens = (
input_batch.num_computed_tokens_cpu[: self.num_decode_reqs]
+ num_scheduled_tokens[: self.num_decode_reqs]
@@ -544,6 +556,7 @@ class PCPManager:
self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
)
)
ori_query_lens_cpu = None
if self.decode_threshold > 1:
num_computed_tokens_of_pcp_dcp_list = []
if self.num_decode_reqs:
@@ -563,10 +576,37 @@ class PCPManager:
]
)
num_computed_tokens_of_pcp_dcp = torch.cat(num_computed_tokens_of_pcp_dcp_list, dim=0)
# For pcp + spec decode, we flatten block_table
# to avoid irregular attn_mask shape, e.g.,
# num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
# ori block_table: # [d0, d1, p0, p1, p2]
# (num_reqs_d + num_reqs_p, max_num_blocks),
# flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
# (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs_padded]
ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs_padded]
num_prefill_reqs = self.num_prefill_reqs
num_decode_reqs = self.num_decode_reqs
num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
)
block_table_tensor[:num_decode_reqs_flatten].copy_(
block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
)
block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
if num_reqs_padded > num_reqs:
pad_size = num_reqs_padded - num_reqs
ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
long_seq_metadata = AscendPrefillContextParallelMetadata(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(),
)
if ori_query_lens_cpu is not None:
long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
long_seq_metadata.max_query_len_pcp_full = ori_query_lens_cpu.max().item()
if self.pcp_world_size > 1:
q_head_idx, q_tail_idx = [], []
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
@@ -685,8 +725,9 @@ class PCPManager:
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
long_seq_metadata.head_attn_nomask_seqlens = head_attn_nomask_seqlens_list
long_seq_metadata.tail_attn_nomask_seqlens = tail_attn_nomask_seqlens_list
self.long_seq_metadata = long_seq_metadata
return long_seq_metadata
return long_seq_metadata, block_table_tensor
def _list_to_tensor(self, lst, device, dtype=torch.int32):
tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device)