[bugfix] Fix accuracy issue in PCP/DCP with speculative decoding (#6491)

### What this PR does / why we need it?
This PR fixes an accuracy issue that occurs when Prefill/Decode Context Parallelism (PCP/DCP) is used together with speculative decoding (MTP). The issue is caused by an irregular attention-mask shape when both features are enabled. The fix flattens the `block_table` for speculative-decoding requests under PCP/DCP so that the attention mask stays regular. This PR also introduces a `use_cp` property for cleaner code and updates dummy runs to handle this scenario correctly.

### Does this PR introduce _any_ user-facing change?
No. This is a bug fix that improves accuracy and does not change any user-facing API.

### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
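To make the flattening concrete, here is a minimal, self-contained sketch (illustrative only, not the patched code; the tensor values and names are assumptions) reproducing the example from the new code comment in the diff below: with 2 decode requests, 3 prefill requests, and 1 speculative token (`decode_threshold = 2`), the block-table rows go from `[d0, d1, p0, p1, p2]` to `[d0, d0, d1, d1, p0, p1, p2]`:

```python
import torch

num_decode_reqs, num_prefill_reqs, decode_threshold = 2, 3, 2
max_num_blocks = 4

# One row of physical block IDs per request: [d0, d1, p0, p1, p2].
block_table = torch.arange(
    (num_decode_reqs + num_prefill_reqs) * max_num_blocks
).reshape(-1, max_num_blocks)

# Each decode request covers decode_threshold query tokens, each
# prefill request keeps a single row.
query_lens = torch.tensor(
    [decode_threshold] * num_decode_reqs + [1] * num_prefill_reqs
)

# Repeat each row once per query token: [d0, d0, d1, d1, p0, p1, p2].
# The patch repeats only the decode rows in place inside a preallocated
# tensor; repeating prefill rows with count 1, as here, gives the same layout.
flattened = torch.repeat_interleave(block_table, query_lens, dim=0)
assert flattened.shape == (
    num_decode_reqs * decode_threshold + num_prefill_reqs,
    max_num_blocks,
)
```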
```diff
@@ -24,6 +24,8 @@ import torch
 from vllm.config import VllmConfig
 from vllm.v1.utils import CpuGpuBuffer
 
+from vllm_ascend.worker.npu_input_batch import NPUInputBatch
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
@@ -514,13 +516,23 @@ class PCPManager:
         dcp_local_seq_lens = (base + remainder).reshape([-1, pcp_world_size, dcp_world_size])
         return dcp_local_seq_lens
 
-    def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_batch, num_scheduled_tokens):
+    def generate_pcp_metadata(
+        self,
+        total_num_scheduled_tokens: int,
+        query_lens: torch.Tensor,
+        input_batch: "NPUInputBatch",
+        num_scheduled_tokens: np.ndarray | None,
+        block_table_tensor: torch.Tensor,
+        num_reqs_padded: int,
+        num_reqs: int,
+    ):
         from vllm_ascend.attention.utils import AscendPrefillContextParallelMetadata
 
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
         if self.pcp_world_size * self.dcp_world_size > 1:
+            assert num_scheduled_tokens is not None
             decode_context_lens = (
                 input_batch.num_computed_tokens_cpu[: self.num_decode_reqs]
                 + num_scheduled_tokens[: self.num_decode_reqs]
@@ -544,6 +556,7 @@ class PCPManager:
                     self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
                 )
             )
+            ori_query_lens_cpu = None
             if self.decode_threshold > 1:
                 num_computed_tokens_of_pcp_dcp_list = []
                 if self.num_decode_reqs:
@@ -563,10 +576,37 @@ class PCPManager:
                         ]
                     )
                 num_computed_tokens_of_pcp_dcp = torch.cat(num_computed_tokens_of_pcp_dcp_list, dim=0)
 
+                # For pcp + spec decode, we flatten block_table
+                # to avoid irregular attn_mask shape, e.g.,
+                # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
+                # ori block_table: [d0, d1, p0, p1, p2]
+                # (num_reqs_d + num_reqs_p, max_num_blocks),
+                # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
+                # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
+                ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs_padded]
+                ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs_padded]
+                num_prefill_reqs = self.num_prefill_reqs
+                num_decode_reqs = self.num_decode_reqs
+                num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
+                block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
+                    block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
+                )
+                block_table_tensor[:num_decode_reqs_flatten].copy_(
+                    block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
+                )
+                block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
+                if num_reqs_padded > num_reqs:
+                    pad_size = num_reqs_padded - num_reqs
+                    ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
 
             long_seq_metadata = AscendPrefillContextParallelMetadata(
                 num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
                 num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(),
             )
+            if ori_query_lens_cpu is not None:
+                long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
+                long_seq_metadata.max_query_len_pcp_full = ori_query_lens_cpu.max().item()
             if self.pcp_world_size > 1:
                 q_head_idx, q_tail_idx = [], []
                 kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
@@ -685,8 +725,9 @@ class PCPManager:
             long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
             long_seq_metadata.head_attn_nomask_seqlens = head_attn_nomask_seqlens_list
             long_seq_metadata.tail_attn_nomask_seqlens = tail_attn_nomask_seqlens_list
 
         self.long_seq_metadata = long_seq_metadata
-        return long_seq_metadata
+        return long_seq_metadata, block_table_tensor
 
     def _list_to_tensor(self, lst, device, dtype=torch.int32):
         tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device)
```
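As a footnote on the dummy-run handling mentioned in the description, here is a minimal, self-contained sketch (assumed values, not the patched code) of the padding step added above: the padded tail entries of `ori_query_lens_cpu` are filled with the last real request's query length, so the flattened block table stays rectangular during padded dummy runs:

```python
import torch

num_reqs, num_reqs_padded = 3, 5
# Query lens for 3 real requests; the last 2 rows are padding slots.
ori_query_lens_cpu = torch.tensor([2, 2, 1, 0, 0])

# Fill the padded tail with the last real query length, as in the patch.
pad_size = num_reqs_padded - num_reqs
ori_query_lens_cpu[-pad_size:] = torch.full(
    [pad_size], ori_query_lens_cpu[-pad_size - 1].item()
)
print(ori_query_lens_cpu)  # tensor([2, 2, 1, 1, 1])
```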