diff --git a/tests/ut/kv_connector/utils.py b/tests/ut/kv_connector/utils.py index 1bc535c5..389b5044 100644 --- a/tests/ut/kv_connector/utils.py +++ b/tests/ut/kv_connector/utils.py @@ -164,7 +164,7 @@ def create_request( remote_host="my-host", remote_port=1234, remote_tp_size=1, - remote_cp_size=1, + remote_pcp_size=1, remote_dcp_size=1) max_tokens = 1 if do_remote_decode else max_tokens diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 314d5a55..2fa60ca8 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -281,14 +281,17 @@ class AscendMLAMetadataBuilder: decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs', 0) max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs) - self.batch_seq_mask_buf = torch.empty(max_num_seqs, + self.batch_seq_mask_buf = torch.empty(max_num_seqs * + self.decode_threshold, dtype=torch.uint8, device=device) - self.seq_mask_pcp_buf = torch.empty(max_num_seqs, + self.seq_mask_pcp_buf = torch.empty(max_num_seqs * + self.decode_threshold, self.pcp_size, dtype=torch.uint8, device=device) - self.seq_mask_dcp_buf = torch.empty(max_num_seqs, + self.seq_mask_dcp_buf = torch.empty(max_num_seqs * + self.decode_threshold, self.dcp_size, dtype=torch.uint8, device=device) @@ -504,12 +507,18 @@ class AscendMLAMetadataBuilder: seq_lens = seq_lens[:num_decodes] input_positions = input_positions[:num_decode_tokens] block_table = block_table[:num_decodes, ...] + # For pcp + spec decode, we flatten seq_lens and block_table + # to avoid irregular spec_attn_mask shape + if self.pcp_size > 1: + block_table = block_table.repeat_interleave( + self.decode_threshold, dim=0) seq_lens_list = seq_lens.tolist() if num_computed_tokens_of_pcp_dcp is not None: + # [bs, pcp_size, dcp_size] num_computed_tokens_of_cp_dcp_array = np.array( - num_computed_tokens_of_pcp_dcp - )[:num_decodes] # [bs, pcp_size, dcp_size] + num_computed_tokens_of_pcp_dcp)[:num_decodes * + self.decode_threshold] cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank, diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 5c65936e..8e72ebf0 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -371,7 +371,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape): def update_mla_attn_dcp_pcp_params(update_stream, forward_context, - runtime_shape, speculative_config): + runtime_shape): graph_params = get_graph_params() # FIXME: Behold! We are using a temporary hack here to update the args # for each layer's attention op in the graph. 
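---
Reviewer note on the mla_v1.py hunks above: the mask buffers are now sized by `decode_threshold` (i.e. `1 + num_speculative_tokens`) and the decode block table is repeated per speculative position, so pcp + spec decode operates on one flat row per (request, token) pair instead of an irregular per-request shape. A minimal sketch of that flattening, using toy tensors rather than the real runner state (the names and values here are illustrative only):

```python
import torch

num_decodes = 2                      # decode requests in the batch
num_speculative_tokens = 1
decode_threshold = 1 + num_speculative_tokens

# One block-table row per request ...
block_table = torch.tensor([[3, 7], [5, 9]])
# ... becomes one row per (request, token) pair, so a flat seq_lens of
# length num_decodes * decode_threshold lines up row-for-row with it.
flat_block_table = block_table.repeat_interleave(decode_threshold, dim=0)
assert flat_block_table.shape == (num_decodes * decode_threshold, 2)

# num_computed_tokens_of_pcp_dcp is likewise flattened to
# [num_decodes * decode_threshold, pcp_size, dcp_size] before the
# (pcp_rank, dcp_rank) column is sliced out as cp_seq_len, which is why
# the buffers above allocate max_num_seqs * decode_threshold entries.
```
---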
@@ -388,16 +388,14 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context, decode_meta = forward_context.attn_metadata[key].decode seq_len = decode_meta.cp_seq_len - if speculative_config and speculative_config.method == "deepseek_mtp": - spec_multiple = speculative_config.num_speculative_tokens + 1 - seq_len = seq_len + [0] * (runtime_shape // spec_multiple - - len(seq_len)) - else: - pad_length = runtime_shape - len(seq_len) - pad_tensor = torch.zeros(pad_length, - dtype=seq_len.dtype, - device=seq_len.device) - seq_len = torch.cat([seq_len, pad_tensor], dim=0) + # For pcp + spec decode, we flatten seq_lens + # to avoid irregular spec_attn_mask shape, + # so there's no need to divide runtime_shape by spec_multiple + pad_length = runtime_shape - len(seq_len) + pad_tensor = torch.zeros(pad_length, + dtype=seq_len.dtype, + device=seq_len.device) + seq_len = torch.cat([seq_len, pad_tensor], dim=0) torch.npu.graph_task_update_begin(update_stream, handle) diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index d92b724f..3aa49131 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -79,7 +79,7 @@ class ReqMeta: remote_port: str engine_id: str remote_tp_size: str - remote_cp_size: str + remote_pcp_size: str remote_dcp_size: str @@ -97,7 +97,7 @@ class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata): remote_host=kv_transfer_params["remote_host"], remote_port=kv_transfer_params["remote_port"], remote_tp_size=kv_transfer_params["remote_tp_size"], - remote_cp_size=kv_transfer_params["remote_cp_size"], + remote_pcp_size=kv_transfer_params["remote_pcp_size"], remote_dcp_size=kv_transfer_params["remote_dcp_size"], ) @@ -318,7 +318,7 @@ class LLMDataDistCMgrConnectorScheduler(): remote_port=self.port, remote_tp_size=str( self.vllm_config.parallel_config.tensor_parallel_size), - remote_cp_size=str(self.pcp_size), + remote_pcp_size=str(self.pcp_size), remote_dcp_size=str(self.dcp_size), ) @@ -677,7 +677,7 @@ class LLMDataDistCMgrConnectorWorker(): remote_engine_id=meta.engine_id, request_id=req_id, remote_tp_size=meta.remote_tp_size, - remote_cp_size=meta.remote_cp_size, + remote_pcp_size=meta.remote_pcp_size, remote_dcp_size=meta.remote_dcp_size, ) futures.append(future) @@ -876,39 +876,40 @@ class LLMDataDistCMgrConnectorWorker(): remote_block_ids: list[int], remote_port: int, remote_tp_size: int, - remote_cp_size: int, + remote_pcp_size: int, remote_dcp_size: int, ) -> tuple[int, list[int], list[int]]: """ In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node. Use this function to calculate remote port and remote block number of each remote P node that we need to pull. 
""" - if self.pcp_size == remote_cp_size and self.dcp_size == remote_dcp_size: + if self.pcp_size == remote_pcp_size and self.dcp_size == remote_dcp_size: # remote & local cp/dcp are equal, do kv transfer point-to-point remote_kv_num = 1 remote_ports = [remote_port + self.pcp_rank * self.tp_size + tp_offset \ for tp_offset in range(self.tp_rank, int(remote_tp_size), self.tp_size)] remote_block_nums = [len(remote_block_ids)] elif (self.use_mla and self.pcp_size == 1 and self.dcp_size == 1) \ - or (not self.use_mla and self.pcp_size == 1 and remote_tp_size == self.tp_size and remote_dcp_size == self.dcp_size): + or (not self.use_mla and self.pcp_size == 1 and self.dcp_size == 1 and remote_tp_size == self.tp_size): # remote & local cp/dcp are not equal, each D node needs to pull from cp(*dcp) P nodes # 1. for mla, support D cp_size = dcp_size = 1 # 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size remote_dcp_size = remote_dcp_size // self.dcp_size - remote_kv_num = remote_cp_size * remote_dcp_size + remote_kv_num = remote_pcp_size * remote_dcp_size cp_dcp_offsets = [] - for cp_idx in range(remote_cp_size): + for cp_idx in range(remote_pcp_size): cp_offset = cp_idx * remote_tp_size cp_dcp_offsets += list( range(cp_offset, cp_offset + remote_dcp_size)) - remote_ports = [remote_port + cp_dcp_offset + (self.tp_rank if not self.use_mla else 0) \ + tp_offset = 0 if self.use_mla else self.tp_rank // remote_dcp_size * remote_dcp_size + remote_ports = [remote_port + cp_dcp_offset + tp_offset \ for cp_dcp_offset in cp_dcp_offsets] # recompute cp/dcp block assign here, maybe we can also pass it from P node meta local_block_num = len(local_block_ids) remote_block_nums = [ - local_block_num // (remote_cp_size * remote_dcp_size) - ] * remote_cp_size * remote_dcp_size - num_remain_blocks = local_block_num % (remote_cp_size * + local_block_num // (remote_pcp_size * remote_dcp_size) + ] * remote_pcp_size * remote_dcp_size + num_remain_blocks = local_block_num % (remote_pcp_size * remote_dcp_size) for i in range(num_remain_blocks): remote_block_nums[i] += 1 @@ -921,7 +922,7 @@ class LLMDataDistCMgrConnectorWorker(): # Other cases are not supported now, maybe need to reshard kv_cache. 
raise NotImplementedError( f'Current case is not supported now: use_mla={self.use_mla}, ' - f'P tp={remote_tp_size}, pcp={remote_cp_size}, dcp={remote_dcp_size}, ' + f'P tp={remote_tp_size}, pcp={remote_pcp_size}, dcp={remote_dcp_size}, ' f'D tp={self.tp_size}, pcp={self.pcp_size}, dcp={self.dcp_size}' ) return remote_kv_num, remote_ports, remote_block_nums @@ -935,7 +936,7 @@ class LLMDataDistCMgrConnectorWorker(): remote_engine_id: str, request_id: str, remote_tp_size: str, - remote_cp_size: str, + remote_pcp_size: str, remote_dcp_size: str, ): remote_kv_num, remote_ports, remote_block_nums = self._get_kv_split_metadata( @@ -943,7 +944,7 @@ class LLMDataDistCMgrConnectorWorker(): remote_block_ids=remote_block_ids, remote_port=remote_port, remote_tp_size=int(remote_tp_size), - remote_cp_size=int(remote_cp_size), + remote_pcp_size=int(remote_pcp_size), remote_dcp_size=int(remote_dcp_size), ) logger.debug( diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 2d4e239e..362c6148 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -25,12 +25,15 @@ from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm_ascend.ascend_forward_context import set_ascend_forward_context -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - AscendPrefillContextParallelMetadata) +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, + prefill_context_parallel_enable, vllm_version_is) +if prefill_context_parallel_enable(): + from vllm.distributed import get_pcp_group + if vllm_version_is("0.11.0"): from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.utils import is_pin_memory_available @@ -84,6 +87,7 @@ class MtpProposer(Proposer): self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size self.num_speculative_tokens = self.speculative_config.num_speculative_tokens + self.decode_threshold = 1 + self.num_speculative_tokens self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because @@ -276,16 +280,28 @@ class MtpProposer(Proposer): self.runner.num_discarded_requests ) - is_prefill = len(scheduler_output.scheduled_new_reqs) > 0 req_scheduled_tokens = scheduler_output.num_scheduled_tokens - long_seq_metadata: AscendPrefillContextParallelMetadata = \ - self.runner.long_seq_metadata if self.pcp_size > 1 else None + if self.pcp_size > 1: + long_seq_metadata = self.runner.long_seq_metadata + input_ids_pcp_full = self.runner.input_ids_pcp_full + query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full + query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu + num_reqs = self.runner.input_batch.num_reqs + ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ + query_start_loc_pcp_full_cpu[:num_reqs] + num_prefill_reqs = (ori_query_lens + > self.decode_threshold).sum().item() + num_decode_reqs = num_reqs - num_prefill_reqs + else: + long_seq_metadata = None + num_prefill_reqs = 0 + num_decode_reqs = 0 if spec_decode_metadata is None: # update pcp related params - if self.pcp_size > 1 and is_prefill: - token_indices_to_sample = None - 
target_token_ids = self.runner.input_ids_pcp_full[: - num_scheduled_tokens] + if self.pcp_size > 1: + token_indices_to_sample = \ + query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1 + target_token_ids = input_ids_pcp_full[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens] target_hidden_states = hidden_states else: @@ -295,6 +311,11 @@ class MtpProposer(Proposer): target_positions = positions[:num_scheduled_tokens] target_hidden_states = hidden_states[:num_scheduled_tokens] else: + if self.pcp_size > 1: + common_attn_metadata.query_start_loc_cpu = \ + query_start_loc_pcp_full_cpu[:num_reqs + 1] + common_attn_metadata.query_start_loc = \ + query_start_loc_pcp_full[:num_reqs + 1] if self.speculative_config.disable_padded_drafter_batch: token_indices_to_sample = None common_attn_metadata, token_indices =\ @@ -309,9 +330,14 @@ class MtpProposer(Proposer): common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count) - target_token_ids = self.runner.input_ids[token_indices] - target_positions = positions[token_indices] - target_hidden_states = hidden_states[token_indices] + if self.pcp_size > 1: + target_token_ids = input_ids_pcp_full[token_indices] + target_positions = positions + target_hidden_states = hidden_states + else: + target_token_ids = self.runner.input_ids[token_indices] + target_positions = positions[token_indices] + target_hidden_states = hidden_states[token_indices] draft_token_ids = self._propose( target_token_ids=target_token_ids, @@ -321,9 +347,10 @@ class MtpProposer(Proposer): last_token_indices=token_indices_to_sample, common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata, - is_prefill=is_prefill, req_scheduled_tokens=req_scheduled_tokens, long_seq_metadata=long_seq_metadata, + num_prefill_reqs=num_prefill_reqs, + num_decode_reqs=num_decode_reqs, ) return draft_token_ids @@ -464,9 +491,10 @@ class MtpProposer(Proposer): sampling_metadata: SamplingMetadata, mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, - is_prefill=False, req_scheduled_tokens=None, long_seq_metadata=None, + num_prefill_reqs=0, + num_decode_reqs=0, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -488,20 +516,65 @@ class MtpProposer(Proposer): self.input_ids[last_token_indices] = next_token_ids # update pcp related params - if self.pcp_size > 1 and is_prefill: - num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens = \ - self._split_pcp_input(req_scheduled_tokens, num_tokens, target_hidden_states) - # graph mode padding not considered now - num_input_tokens = num_tokens - self.input_ids[:num_input_tokens].copy_(input_ids) + if self.pcp_size > 1: + assert long_seq_metadata is not None common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata + # 1. 
preprocess decode/prefill input_ids & target_hidden_states + # decode input_ids: keep unchanged + # decode target_hidden_states: remove padding + # prefill input_ids: add padding and pcp split + # prefill target_hidden_states: pcp split + num_tokens_d = num_decode_reqs * self.decode_threshold + num_tokens_d_padded = num_tokens_d * self.pcp_size + input_ids_d = self.input_ids[:num_tokens_d] + input_ids_p = self.input_ids[num_tokens_d:num_tokens] + target_hidden_states_d_padded = \ + target_hidden_states[:num_tokens_d_padded] + if num_tokens_d: + # remove padding (from pcp all-gather) in decode part + target_hidden_states_d = target_hidden_states_d_padded.reshape( + [ + num_decode_reqs, self.decode_threshold * self.pcp_size, + -1 + ])[:, :self.decode_threshold, :].reshape( + [num_tokens_d, -1]) + else: + target_hidden_states_d = target_hidden_states_d_padded + target_hidden_states_p = target_hidden_states[num_tokens_d_padded:] + req_scheduled_tokens_p = {} + for i, req_id in enumerate(self.runner.input_batch.req_ids): + if i >= num_decode_reqs: + req_scheduled_tokens_p[req_id] = \ + req_scheduled_tokens[req_id] + (num_tokens_p, input_ids_p, target_hidden_states_p, + max_query_len_p, seq_lens_p, cu_num_tokens_p) = \ + self._split_pcp_input( + req_scheduled_tokens_p, input_ids_p, target_hidden_states_p) + num_tokens = num_tokens_d + num_tokens_p + target_positions = target_positions[:num_tokens] + self.input_ids[:num_tokens].copy_( + torch.cat([input_ids_d, input_ids_p], dim=0)) + target_hidden_states = torch.cat( + [target_hidden_states_d, target_hidden_states_p], dim=0) + # 2. update attn_metadata params that may be influenced by pcp common_attn_metadata.num_actual_tokens = num_tokens - common_attn_metadata.max_query_len = max_query_len - common_attn_metadata.seq_lens_cpu = seq_lens.cpu() - common_attn_metadata.query_start_loc = \ - cu_num_tokens[:batch_size + 1] - common_attn_metadata.query_start_loc_cpu = \ - cu_num_tokens[:batch_size + 1].cpu() + common_attn_metadata.max_query_len = max(self.decode_threshold, + max_query_len_p) + common_attn_metadata.seq_lens[num_decode_reqs:] = seq_lens_p + common_attn_metadata.seq_lens_cpu[num_decode_reqs:] = seq_lens_p + query_start_loc_p = cu_num_tokens_p[1:] + \ + common_attn_metadata.query_start_loc[num_decode_reqs].item() + common_attn_metadata.query_start_loc[num_decode_reqs + 1:] = \ + query_start_loc_p + common_attn_metadata.query_start_loc_cpu[num_decode_reqs + 1:] = \ + query_start_loc_p + # 3. update sample_indices according to main model + if num_decode_reqs: + last_token_indices[:num_decode_reqs] = \ + self.runner.logits_indices[last_token_indices[:num_decode_reqs]] + if num_prefill_reqs: + last_token_indices[-num_prefill_reqs:] = \ + self.runner.logits_indices[-num_prefill_reqs:] assert self.runner is not None @@ -575,6 +648,12 @@ class MtpProposer(Proposer): last_token_indices, (0, max_num_reqs_across_dp - num_indices)) + if self.pcp_size > 1: + hidden_states = get_pcp_group().all_gather(hidden_states, 0) + hidden_states = torch.index_select( + hidden_states, 0, self.runner. 
+ pcp_allgather_restore_idx[:hidden_states.shape[0]]) + sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states) if lmhead_tp_enable() and num_indices < logits.shape[0]: @@ -854,16 +933,26 @@ class MtpProposer(Proposer): return spec_common_attn_metadata, token_indices, token_indices_to_sample - def _split_pcp_input(self, req_scheduled_tokens, num_tokens, + def _split_pcp_input(self, req_scheduled_tokens, input_ids, target_hidden_states): """ - Split input_ids and target_hidden_states in pcp group. + Split prefill input_ids and target_hidden_states in pcp group. 1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad] 2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5] - 3. split target_hidden_states (already include cp padding): + 3. split target_hidden_states (already include pcp padding): [h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5] 4. also update max_query_len, seq_lens, cu_num_tokens according to pcp split. """ + if len(req_scheduled_tokens) == 0: + # no prefill inputs to split, return empty result + return ( + 0, + torch.zeros([0], device='npu'), + torch.zeros([0, target_hidden_states.size(1)], device='npu'), + 0, + torch.zeros([0]), + torch.tensor([0], dtype=torch.int32), + ) def _pcp_pad_and_split(num_tokens): num_pcp_padded_scheduled_tokens = cdiv( @@ -885,7 +974,6 @@ class MtpProposer(Proposer): return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad num_pcp_scheduled_tokens = [] - input_ids_list = self.input_ids[:num_tokens] ori_start_index = 0 pad_start_index = 0 pcp_split_input_ids_list = [] @@ -896,8 +984,8 @@ class MtpProposer(Proposer): actual_num_tokens = len(req_position_pcp) num_pcp_scheduled_tokens.append(actual_num_tokens) pad_input_ids = F.pad( - input_ids_list[ori_start_index:ori_start_index + - ori_num_tokens], (0, num_pcp_pad)) + input_ids[ori_start_index:ori_start_index + ori_num_tokens], + (0, num_pcp_pad)) ori_start_index += ori_num_tokens pcp_chunk_indices = [ pad_start_index + pos for pos in req_position_pcp diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7f8fe1e1..746a1c90 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -487,19 +487,29 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.speculative_config and self.pcp_size > 1: self.input_ids_pcp_full = torch.zeros(self.max_num_tokens, dtype=torch.int32, - device="cpu", - pin_memory=True) + device=self.device) + self.input_ids_pcp_full_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=True) self.query_start_loc_pcp_full = torch.zeros(self.max_num_reqs + 1, dtype=torch.int32, - device="cpu", - pin_memory=True) - self.query_start_loc_pcp_full_np = self.query_start_loc_pcp_full.numpy( - ) + device=self.device) + self.query_start_loc_pcp_full_cpu = \ + torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device="cpu", + pin_memory=True) + self.query_start_loc_pcp_full_np = \ + self.query_start_loc_pcp_full_cpu.numpy() self.positions_pcp_full = torch.zeros(self.max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=True) - self.positions_np_pcp_full = self.positions_pcp_full.numpy() + self.positions_pcp_full_np = self.positions_pcp_full.numpy() + self.decode_threshold = 1 + ( + self.speculative_config.num_speculative_tokens + if self.speculative_config else 0) self.use_aclgraph = self._use_aclgraph() 
self.aclgraph_batch_sizes = list( @@ -1854,8 +1864,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): logits_indices = torch.from_numpy(cu_num_tokens - 1).to( self.device, non_blocking=True) else: - # pcp not supported now - assert self.pcp_size == 1 # Get the number of draft tokens for each request. # Iterate over the dictionary rather than all requests since not all # requests have draft tokens. @@ -1866,11 +1874,13 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_draft_tokens[req_idx] = len(draft_token_ids) spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens) + num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs]) logits_indices = spec_decode_metadata.logits_indices self.num_draft_tokens.np[:num_reqs] = num_draft_tokens self.num_draft_tokens.np[num_reqs:].fill(0) self.num_draft_tokens.copy_to_gpu() + # save logits_indices for pcp spec decode usage + self.logits_indices = logits_indices # Used in the below loop. # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] @@ -1883,8 +1893,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() - is_prefill = len(scheduler_output.scheduled_new_reqs) > 0 - if self.speculative_config and self.pcp_size > 1 and is_prefill: + if self.speculative_config and self.pcp_size > 1: self._generate_pcp_mtp_input( num_reqs, scheduler_output.total_num_scheduled_tokens, scheduler_output.num_scheduled_tokens) @@ -2040,8 +2049,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): # FIXME: Try using `auto_dispatch_capture=True` update_mla_attn_dcp_pcp_params(self.update_stream, forward_context, - maybe_padded_num_tokens, - self.speculative_config) + maybe_padded_num_tokens) else: # FIXME: Try using `auto_dispatch_capture=True` update_mla_attn_params(self.update_stream, forward_context, @@ -2110,6 +2118,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self, num_draft_tokens: np.ndarray, cu_num_scheduled_tokens: np.ndarray, + num_pcp_pads: np.ndarray, ) -> SpecDecodeMetadata: # Inputs: # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209] @@ -2138,6 +2147,17 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] logits_indices += arange + # while pcp > 1, decode results may contain padding (from pcp all-gather), + # update logits_indices after getting draft_token_ids from ori logits_indices + if self.pcp_size > 1: + cu_num_scheduled_tokens = cu_num_scheduled_tokens * self.pcp_size - num_pcp_pads + logits_indices_pcp = np.repeat( + cu_num_scheduled_tokens - num_sampled_tokens, + num_sampled_tokens) + logits_indices_pcp += arange + logits_indices_pcp = torch.from_numpy(logits_indices_pcp).to( + self.device, non_blocking=True) + # Compute the bonus logits indices. 
bonus_logits_indices = cu_num_sampled_tokens - 1 @@ -2173,6 +2193,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): # draft_token_indices: [ 1, 2, 3, 105, 106, 208] draft_token_ids = self.input_ids[logits_indices] draft_token_ids = draft_token_ids[target_logits_indices + 1] + if self.pcp_size > 1: + logits_indices = logits_indices_pcp if vllm_version_is("0.11.0"): metadata = SpecDecodeMetadata( draft_token_ids=draft_token_ids, @@ -2920,8 +2942,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): # FIXME: Try using `auto_dispatch_capture=True` update_mla_attn_dcp_pcp_params(self.update_stream, forward_context, - positions.shape[0], - self.speculative_config) + positions.shape[0]) else: # FIXME: Try using `auto_dispatch_capture=True` update_mla_attn_params(self.update_stream, forward_context, @@ -4328,18 +4349,25 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_decode_reqs = sum( self.input_batch.num_computed_tokens_cpu[:num_reqs] >= self.input_batch.num_prompt_tokens[:num_reqs]) + num_decode_tokens = sum(tokens[:num_decode_reqs]) num_padded_scheduled_tokens = np.ceil( tokens / (2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size) - num_padded_scheduled_tokens[:num_decode_reqs] = self.pcp_size + num_padded_scheduled_tokens[:num_decode_reqs] = ( + tokens[:num_decode_reqs] * self.pcp_size) self.num_pcp_pads = num_padded_scheduled_tokens - tokens cu_padded_tokens, pcp_padded_arange = \ self._get_cumsum_and_arange(num_padded_scheduled_tokens) unpad_mask = torch.from_numpy( pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens)) + unpad_mask_decode = unpad_mask[:num_decode_tokens * self.pcp_size] + unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_size]) + unpad_mask_decode[:, 0] = True + unpad_mask_decode[:, 1:] = False pcp_tokens = num_padded_scheduled_tokens // self.pcp_size pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) + pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs] _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens) _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, @@ -4356,14 +4384,16 @@ class NPUModelRunner(LoRAModelRunnerMixin): np.repeat(head_start_loc, pcp_chunk_sizes) # Decode reqs do not have tail chunks. positions[~pcp_head_chunk_mask] = \ - pcp_chunk_arange[num_decode_reqs:] + \ - np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_reqs:] + pcp_chunk_arange[num_decode_tokens:] + \ + np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:] return positions positions = get_current_rank_positions( np.zeros(num_reqs, dtype=np.int32), self.pcp_rank) # Decode tokens are duplicate and their positions always be 0. 
- positions[:num_decode_reqs] = 0 + if num_decode_reqs > 0: + positions[:num_decode_tokens] = self._get_cumsum_and_arange( + tokens[:num_decode_reqs])[1] all_positions = [ get_current_rank_positions(cu_padded_tokens, rank_i) @@ -4372,7 +4402,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): all_positions_tensor = torch.from_numpy(np.concatenate(all_positions)) self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_( all_positions_tensor.float().argsort().long(), non_blocking=True) - pcp_tokens[:num_decode_reqs] = 1 return pcp_tokens, positions, unpad_mask def _get_pcp_local_seq_lens( @@ -4524,7 +4553,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens, seq_lens_origin): - num_reqs = self.input_batch.num_reqs + # In dummy run num_reqs == 0, update it from seq_lens + num_reqs = self.input_batch.num_reqs or seq_lens.size(0) num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs] >= self.input_batch.num_prompt_tokens[:num_reqs]) num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size @@ -4535,14 +4565,28 @@ class NPUModelRunner(LoRAModelRunnerMixin): local_chunked_kv_lens) long_seq_metadata = None if self.pcp_size * self.dcp_size > 1: + num_computed_tokens_of_pcp_dcp = torch.zeros( + [ + num_reqs * self.decode_threshold, self.pcp_size, + self.dcp_size + ], + dtype=torch.int32, + ) + # For pcp + spec decode, we flatten seq_lens + # to avoid irregular spec_attn_mask shape + for decode_idx in range(self.decode_threshold): + num_computed_tokens_of_pcp_dcp[ + self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \ + self._get_pcp_local_seq_lens( + seq_lens_origin - decode_idx, + self.pcp_size, + self.dcp_size, + self.parallel_config.cp_kv_cache_interleave_size, + ) long_seq_metadata = AscendPrefillContextParallelMetadata( num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, - num_computed_tokens_of_pcp_dcp=self._get_pcp_local_seq_lens( - seq_lens_origin, - self.pcp_size, - self.dcp_size, - self.parallel_config.cp_kv_cache_interleave_size, - ).numpy(), + num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp. + numpy(), local_chunked_kv_lens=local_chunked_kv_lens, mask_for_non_zero_chunk=mask_for_non_zero_chunk, max_chunk_num=max_chunk_num) @@ -4706,16 +4750,25 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_scheduled_tokens_pcp_full) arange_pcp_full = self.arange_np[: total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full - positions_np_pcp_full = self.positions_np_pcp_full[: + positions_pcp_full_np = self.positions_pcp_full_np[: total_num_scheduled_tokens_pcp_full] np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full], arange_pcp_full, - out=positions_np_pcp_full) + out=positions_pcp_full_np) token_indices_pcp_full = ( - positions_np_pcp_full + + positions_pcp_full_np + req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1]) torch.index_select( self.input_batch.token_ids_cpu_tensor.flatten(), 0, torch.from_numpy(token_indices_pcp_full), - out=self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full]) + out=self. + input_ids_pcp_full_cpu[:total_num_scheduled_tokens_pcp_full]) + self.query_start_loc_pcp_full[:num_reqs + 1].copy_( + self.query_start_loc_pcp_full_cpu[:num_reqs + 1], + non_blocking=True, + ) + self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full].copy_( + self.input_ids_pcp_full_cpu[:total_num_scheduled_tokens_pcp_full], + non_blocking=True, + )
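---
Reviewer note on the `_calc_spec_decode_metadata` change: when `pcp_size > 1`, each request's slice of the gathered hidden states grows to `pcp_size` times its scheduled tokens minus the padding it contributed, so the sampled-token window has to be re-derived from the gathered offsets rather than reused from the unpadded pass. A simplified sketch of that index arithmetic with toy numbers (the cumulative-padding bookkeeping is deliberately simplified here, and the real code reuses the precomputed `arange` from the unpadded pass):

```python
import numpy as np

pcp_size = 2
cu_num_scheduled_tokens = np.array([4, 104, 107])   # toy cumulative token counts
num_pcp_pads = np.array([0, 2, 3])                  # toy cumulative pcp padding
num_sampled_tokens = np.array([2, 2, 2])            # 1 bonus + 1 draft per request

# End offset of each request after the pcp all-gather.
cu_gathered = cu_num_scheduled_tokens * pcp_size - num_pcp_pads

# Start of each request's sampled window in the gathered layout ...
starts = np.repeat(cu_gathered - num_sampled_tokens, num_sampled_tokens)
# ... plus a 0..n arange within each window gives the gathered logits indices.
arange = np.concatenate([np.arange(n) for n in num_sampled_tokens])
logits_indices_pcp = starts + arange
print(logits_indices_pcp)  # [  6   7 204 205 209 210]
```

The draft-token ids are still gathered with the original (unpadded) `logits_indices`, and only afterwards is `logits_indices` swapped for `logits_indices_pcp`, which matches the ordering of the hunk above.
---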