[bugfix]Fix accuracy issue in PCP/DCP with speculative decoding (#6491)
### What this PR does / why we need it? This PR fixes an accuracy issue that occurs when using Prefill/Decode Context Parallelism (PCP/DCP) in conjunction with speculative decoding (MTP). The issue is caused by an irregular attention mask shape when both features are enabled. The fix involves flattening the `block_table` for speculative decoding requests under PCP/DCP to ensure a regular attention mask. This PR also introduces a `use_cp` property for cleaner code and updates dummy runs to handle this scenario correctly. ### Does this PR introduce _any_ user-facing change? No. This is a bug fix that improves accuracy and should not have user-facing API changes. ### How was this patch tested? - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0 --------- Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
@@ -73,9 +73,12 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
|
||||
query_lens) - input_batch.num_computed_tokens_cpu
|
||||
|
||||
query_lens = torch.tensor(query_lens)
|
||||
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
|
||||
result, _ = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
|
||||
input_batch,
|
||||
num_scheduled_tokens)
|
||||
num_scheduled_tokens,
|
||||
torch.tensor([]),
|
||||
num_reqs_padded=num_reqs,
|
||||
num_reqs=num_reqs)
|
||||
|
||||
if not expect_not_none:
|
||||
assert result is None, f"Expected to return None, but got {type(result)}"
|
||||
|
||||
@@ -571,14 +571,14 @@ class NPUModelRunner(GPUModelRunner):
|
||||
self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
|
||||
self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
|
||||
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
if self.use_cp:
|
||||
self.pcp_manager.init_batch_info(
|
||||
num_scheduled_tokens,
|
||||
self.input_batch.num_reqs,
|
||||
)
|
||||
|
||||
# for pcp, prefill mtp should use origin scheduleroutput ,
|
||||
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
||||
if self.speculative_config and self.use_cp:
|
||||
self.pcp_manager.generate_pcp_mtp_input(
|
||||
total_num_scheduled_tokens,
|
||||
scheduler_output.num_scheduled_tokens,
|
||||
@@ -732,7 +732,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
spec_decode_metadata = None
|
||||
num_draft_tokens = None
|
||||
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
if self.use_cp:
|
||||
logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens)
|
||||
logits_indices = logits_indices.pin_memory().to(self.device, non_blocking=True)
|
||||
else:
|
||||
@@ -954,7 +954,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
|
||||
|
||||
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
if self.use_cp:
|
||||
long_seq_metadata = self.long_seq_metadata # type: ignore
|
||||
input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu
|
||||
query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu
|
||||
@@ -1838,11 +1838,17 @@ class NPUModelRunner(GPUModelRunner):
|
||||
|
||||
kv_cache_groups = self.kv_cache_config.kv_cache_groups
|
||||
|
||||
def _get_pcp_metadata(num_tokens):
|
||||
def _get_pcp_metadata(block_table_tensor):
|
||||
if not self.use_cp:
|
||||
return None
|
||||
return None, block_table_tensor
|
||||
return self.pcp_manager.generate_pcp_metadata(
|
||||
num_tokens, self.query_lens, self.input_batch, num_scheduled_tokens_np
|
||||
num_tokens,
|
||||
self.query_lens,
|
||||
self.input_batch,
|
||||
num_scheduled_tokens_np,
|
||||
block_table_tensor,
|
||||
num_reqs_padded,
|
||||
num_reqs,
|
||||
)
|
||||
|
||||
def _get_block_table_and_slot_mapping(kv_cache_gid: int):
|
||||
@@ -1883,8 +1889,8 @@ class NPUModelRunner(GPUModelRunner):
|
||||
)
|
||||
return blk_table_tensor, slot_mapping
|
||||
|
||||
self.long_seq_metadata = _get_pcp_metadata(num_tokens)
|
||||
block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
|
||||
self.long_seq_metadata, block_table_gid_0 = _get_pcp_metadata(block_table_gid_0)
|
||||
|
||||
cm_base = AscendCommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
|
||||
@@ -2080,11 +2086,14 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# LoRA state when determining the batch descriptor for capture
|
||||
force_has_lora=activate_lora,
|
||||
)
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
if self.use_cp:
|
||||
self.pcp_manager.init_batch_info(
|
||||
num_scheduled_tokens,
|
||||
num_reqs,
|
||||
)
|
||||
if self.speculative_config:
|
||||
self.pcp_manager.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(num_scheduled_tokens)
|
||||
self.pcp_manager.query_lens_pcp_full.copy_to_gpu()
|
||||
if cudagraph_runtime_mode is None:
|
||||
cudagraph_runtime_mode = _cudagraph_mode
|
||||
else:
|
||||
|
||||
@@ -24,6 +24,8 @@ import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
@@ -514,13 +516,23 @@ class PCPManager:
|
||||
dcp_local_seq_lens = (base + remainder).reshape([-1, pcp_world_size, dcp_world_size])
|
||||
return dcp_local_seq_lens
|
||||
|
||||
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_batch, num_scheduled_tokens):
|
||||
def generate_pcp_metadata(
|
||||
self,
|
||||
total_num_scheduled_tokens: int,
|
||||
query_lens: torch.Tensor,
|
||||
input_batch: "NPUInputBatch",
|
||||
num_scheduled_tokens: np.ndarray | None,
|
||||
block_table_tensor: torch.Tensor,
|
||||
num_reqs_padded: int,
|
||||
num_reqs: int,
|
||||
):
|
||||
from vllm_ascend.attention.utils import AscendPrefillContextParallelMetadata
|
||||
|
||||
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
|
||||
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
|
||||
long_seq_metadata = None
|
||||
if self.pcp_world_size * self.dcp_world_size > 1:
|
||||
assert num_scheduled_tokens is not None
|
||||
decode_context_lens = (
|
||||
input_batch.num_computed_tokens_cpu[: self.num_decode_reqs]
|
||||
+ num_scheduled_tokens[: self.num_decode_reqs]
|
||||
@@ -544,6 +556,7 @@ class PCPManager:
|
||||
self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
|
||||
)
|
||||
)
|
||||
ori_query_lens_cpu = None
|
||||
if self.decode_threshold > 1:
|
||||
num_computed_tokens_of_pcp_dcp_list = []
|
||||
if self.num_decode_reqs:
|
||||
@@ -563,10 +576,37 @@ class PCPManager:
|
||||
]
|
||||
)
|
||||
num_computed_tokens_of_pcp_dcp = torch.cat(num_computed_tokens_of_pcp_dcp_list, dim=0)
|
||||
|
||||
# For pcp + spec decode, we flatten block_table
|
||||
# to avoid irregular attn_mask shape, e.g.,
|
||||
# num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
|
||||
# ori block_table: # [d0, d1, p0, p1, p2]
|
||||
# (num_reqs_d + num_reqs_p, max_num_blocks),
|
||||
# flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
|
||||
# (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
|
||||
ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs_padded]
|
||||
ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs_padded]
|
||||
num_prefill_reqs = self.num_prefill_reqs
|
||||
num_decode_reqs = self.num_decode_reqs
|
||||
num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
|
||||
block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
|
||||
block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
|
||||
)
|
||||
block_table_tensor[:num_decode_reqs_flatten].copy_(
|
||||
block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
|
||||
)
|
||||
block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
|
||||
if num_reqs_padded > num_reqs:
|
||||
pad_size = num_reqs_padded - num_reqs
|
||||
ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
|
||||
|
||||
long_seq_metadata = AscendPrefillContextParallelMetadata(
|
||||
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
||||
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(),
|
||||
)
|
||||
if ori_query_lens_cpu is not None:
|
||||
long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
|
||||
long_seq_metadata.max_query_len_pcp_full = ori_query_lens_cpu.max().item()
|
||||
if self.pcp_world_size > 1:
|
||||
q_head_idx, q_tail_idx = [], []
|
||||
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
|
||||
@@ -685,8 +725,9 @@ class PCPManager:
|
||||
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
|
||||
long_seq_metadata.head_attn_nomask_seqlens = head_attn_nomask_seqlens_list
|
||||
long_seq_metadata.tail_attn_nomask_seqlens = tail_attn_nomask_seqlens_list
|
||||
|
||||
self.long_seq_metadata = long_seq_metadata
|
||||
return long_seq_metadata
|
||||
return long_seq_metadata, block_table_tensor
|
||||
|
||||
def _list_to_tensor(self, lst, device, dtype=torch.int32):
|
||||
tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device)
|
||||
|
||||
Reference in New Issue
Block a user