diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py b/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
index d8f38ba3..3cf269c5 100644
--- a/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
+++ b/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
@@ -70,12 +70,12 @@ def test_pcp_dcp_mtp3_eager():
             max_num_batched_tokens=1024,
             enable_expert_parallel=True,
             block_size=128,
+            async_scheduling=True,
             speculative_config={
                 "num_speculative_tokens": 3,
                 "method": "deepseek_mtp",
             },
             enforce_eager=True,
-            async_scheduling=False,
     ) as runner:
         runner.generate_greedy(prompts, 32)
 
diff --git a/tests/ut/worker/test_pcp_manager.py b/tests/ut/worker/test_pcp_manager.py
index eaa34a2e..9a6779c1 100644
--- a/tests/ut/worker/test_pcp_manager.py
+++ b/tests/ut/worker/test_pcp_manager.py
@@ -47,6 +47,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
                              max_num_reqs=1000,
                              device="cpu",
                              vllm_config=vllm_config,
+                             use_async_scheduling=False,
                              pin_memory=False)
     input_batch = MagicMock()
     input_batch.num_reqs = num_reqs
@@ -65,13 +66,16 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
         num_prompt_tokens.append(query_lens[i])
         num_tokens.append(query_lens[i])
 
-    input_batch.num_computed_tokens_cpu = torch.tensor(num_computed_tokens)
+    input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens)
     input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens)
     input_batch.num_tokens = torch.tensor(num_tokens)
+    num_scheduled_tokens = np.array(
+        query_lens) - input_batch.num_computed_tokens_cpu
     query_lens = torch.tensor(query_lens)
 
     result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
-                                               input_batch)
+                                               input_batch,
+                                               num_scheduled_tokens)
 
     if not expect_not_none:
         assert result is None, f"Expected to return None, but got {type(result)}"
@@ -128,6 +132,7 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
                              max_num_reqs=1000,
                              device="cpu",
                              vllm_config=vllm_config,
+                             use_async_scheduling=False,
                              pin_memory=False)
     input_batch = MagicMock()
     input_batch.num_reqs = num_reqs
@@ -193,6 +198,7 @@ def test_get_cp_local_seq_lens(
                              max_num_reqs=1000,
                              device="cpu",
                              vllm_config=vllm_config,
+                             use_async_scheduling=False,
                              pin_memory=False)
     ret = pcp_manager._get_cp_local_seq_lens(seq_lens, pcp_world_size,
                                              dcp_world_size,
@@ -276,6 +282,7 @@ def test_generate_pcp_mtp_input(
                              max_num_reqs=max_num_reqs,
                              device="cpu",
                              vllm_config=vllm_config,
+                             use_async_scheduling=False,
                              pin_memory=False)
     arange_np = np.arange(max_model_len)
     input_batch = MagicMock()
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 210dea7b..4718551c 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -226,6 +226,7 @@ class NPUModelRunner(GPUModelRunner):
             self.max_num_reqs,
             self.device,
             self.vllm_config,
+            self.use_async_scheduling,
             self.pin_memory,
         )
         # TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
@@ -540,10 +541,18 @@ class NPUModelRunner(GPUModelRunner):
         # for pcp, prefill mtp should use origin scheduleroutput ,
         if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             self.pcp_manager.generate_pcp_mtp_input(
-                num_reqs, total_num_scheduled_tokens,
-                scheduler_output.num_scheduled_tokens, with_prefill,
-                self.input_batch, self.arange_np, req_indices, positions_np,
-                cu_num_tokens)
+                num_reqs,
+                total_num_scheduled_tokens,
+                scheduler_output.num_scheduled_tokens,
+                with_prefill,
+                self.input_batch,
+                self.arange_np,
+                req_indices,
+                positions_np,
+                cu_num_tokens,
+                self._draft_token_ids,  # type: ignore[has-type]
+                scheduler_output,
+                self.num_spec_tokens)
 
         if self.pcp_size > 1:
             if not self.vllm_config.model_config.use_mla:
@@ -929,7 +938,7 @@ class NPUModelRunner(GPUModelRunner):
         if self.pcp_size * self.dcp_size > 1:
             self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata(
                 total_num_scheduled_tokens, self.query_lens,
-                self.input_batch)
+                self.input_batch, num_scheduled_tokens)
             blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1)
         if self.pcp_size > 1:
             slot_mapping_pcp = self.pcp_manager.get_padded_slot_mapping(
@@ -1946,7 +1955,8 @@ class NPUModelRunner(GPUModelRunner):
             slot_mapping = self.input_batch.block_table[
                 kv_cache_group_id].slot_mapping
             long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata(
-                num_tokens, self.query_lens, self.input_batch)
+                num_tokens, self.query_lens, self.input_batch,
+                num_scheduled_tokens)
             if long_seq_metadata is not None:
                 pcp_world_size = get_pcp_group().world_size
                 dcp_world_size = get_dcp_group().world_size
diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py
index f0ace8a4..40294807 100644
--- a/vllm_ascend/worker/pcp_utils.py
+++ b/vllm_ascend/worker/pcp_utils.py
@@ -17,7 +17,7 @@
 # Adapted from vllm-project/vllm/vllm/worker/worker.py
 #
 
-from typing import List
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
@@ -25,6 +25,9 @@ from vllm.config import VllmConfig
 from vllm.utils.math_utils import cdiv
 from vllm.v1.utils import CpuGpuBuffer
 
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+
 
 class PCPManager:
     """
@@ -44,6 +47,7 @@ class PCPManager:
         max_num_reqs: int,
         device: torch.device,
         vllm_config: VllmConfig,
+        use_async_scheduling: bool,
         pin_memory: bool = False,
     ) -> None:
         self.pcp_world_size = pcp_world_size
@@ -58,6 +62,7 @@ class PCPManager:
         self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
         self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs
         self.device = device
+        self.use_async_scheduling = use_async_scheduling
         self.pcp_allgather_restore_idx = CpuGpuBuffer(
             max_buffer_num_tokens,
             dtype=torch.int64,
@@ -354,6 +359,9 @@ class PCPManager:
         req_indices=None,
         positions_np=None,
         cu_num_tokens=None,
+        draft_token_ids=None,
+        scheduler_output=None,
+        num_spec_tokens=None,
     ):
         """
         While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group,
@@ -390,6 +398,12 @@ class PCPManager:
             torch.from_numpy(token_indices_pcp_full),
             out=self.input_ids_pcp_full.
             cpu[:total_num_scheduled_tokens_pcp_full])
+        if self.use_async_scheduling:
+            self._update_input_ids_pcp_full_ids(input_batch, draft_token_ids,
+                                                scheduler_output,
+                                                total_num_scheduled_tokens,
+                                                cu_num_tokens_pcp_full,
+                                                num_spec_tokens)
         self.query_lens_pcp_full.copy_to_gpu()
         self.query_start_loc_pcp_full.copy_to_gpu()
         self.input_ids_pcp_full.copy_to_gpu(
@@ -428,6 +442,99 @@ class PCPManager:
             mtp_slot_pad[unpad_mask] = mtp_slot_ori
         self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
 
+    def _update_input_ids_pcp_full_ids(
+        self,
+        input_batch,
+        draft_token_ids,
+        scheduler_output: "SchedulerOutput",
+        total_num_scheduled_tokens: int,
+        cu_num_tokens: np.ndarray,
+        num_spec_tokens: int,
+    ) -> None:
+        """Prepare the input IDs for the current batch.
+
+        Carefully handles the `prev_sampled_token_ids` which can be cached
+        from the previous engine iteration, in which case those tokens on the
+        GPU need to be copied into the corresponding slots of input_ids."""
+
+        if (input_batch.prev_sampled_token_ids is None
+                or input_batch.prev_req_id_to_index is None):
+            return
+
+        # Async scheduling case, where some decode requests from the previous
+        # iteration won't have entries in input_ids_cpu and need to be copied
+        # on the GPU from prev_sampled_token_ids.
+        prev_req_id_to_index = input_batch.prev_req_id_to_index
+        sample_flattened_indices: list[int] = []
+        spec_flattened_indices: list[int] = []
+        prev_common_req_indices: list[int] = []
+        prev_draft_token_indices: list[int] = []
+        total_num_spec_tokens = 0
+        scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
+
+        for req_id, cur_index in input_batch.req_id_to_index.items():
+            if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
+                prev_common_req_indices.append(prev_index)
+                # We need to compute the flattened input_ids index of the
+                # last token in each common request.
+                draft_len = len(scheduled_spec_tokens.get(req_id, ()))
+                total_num_spec_tokens += draft_len
+                flattened_index = cu_num_tokens[cur_index].item() - 1
+                # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
+                # sample_flattened_indices = [0, 2, 5]
+                # spec_flattened_indices = [1, 3, 4, 6, 7]
+                sample_flattened_indices.append(flattened_index - draft_len)
+                spec_flattened_indices.extend(
+                    range(flattened_index - draft_len + 1,
+                          flattened_index + 1))
+                start = prev_index * num_spec_tokens
+                # prev_draft_token_indices is used to find which draft_token_ids
+                # should be copied to input_ids
+                # example: prev draft_token_ids [[1, 2], [3, 4], [5, 6]]
+                # flattened draft_token_ids [1, 2, 3, 4, 5, 6]
+                # draft_len of each request [1, 2, 1]
+                # then prev_draft_token_indices is [0, 2, 3, 4]
+                prev_draft_token_indices.extend(range(start,
+                                                      start + draft_len))
+        num_common_tokens = len(sample_flattened_indices)
+
+        if num_common_tokens == 0:
+            # No requests in common with the previous iteration,
+            # so input_ids.cpu will have all the input ids.
+            return
+        # Upload the index tensors asynchronously so the scatter can be non-blocking.
+        sampled_tokens_index_tensor = torch.tensor(sample_flattened_indices,
+                                                   dtype=torch.int64)
+        prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices,
+                                                      dtype=torch.int64)
+        self.input_ids_pcp_full.cpu.scatter_(
+            dim=0,
+            index=sampled_tokens_index_tensor,
+            src=input_batch.prev_sampled_token_ids[
+                prev_common_req_indices_tensor, 0].cpu(),
+        )
+
+        # Scatter the draft tokens after the sampled tokens are scattered.
+        if draft_token_ids is None or not spec_flattened_indices:
+            return
+
+        assert isinstance(draft_token_ids, torch.Tensor)
+        draft_tokens_index_tensor = torch.tensor(spec_flattened_indices,
+                                                 dtype=torch.int64)
+        prev_draft_token_indices_tensor = torch.tensor(
+            prev_draft_token_indices, dtype=torch.int64)
+
+        # input_ids dtype is torch.int32, so convert
+        # draft_token_ids to torch.int32 here.
+        draft_token_ids = draft_token_ids.to(dtype=torch.int32)
+
+        self.input_ids_pcp_full.cpu.scatter_(
+            dim=0,
+            index=draft_tokens_index_tensor,
+            src=draft_token_ids.flatten()
+            [prev_draft_token_indices_tensor].cpu(),
+        )
+
     def _get_cp_local_seq_lens(
         self,
         seq_lens: torch.Tensor,
@@ -498,7 +605,7 @@ class PCPManager:
             torch.float32).argsort().to(torch.int32)
 
     def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
-                              input_batch):
+                              input_batch, num_scheduled_tokens):
         from vllm_ascend.attention.utils import \
             AscendPrefillContextParallelMetadata
         num_reqs = input_batch.num_reqs or query_lens.size(0)
@@ -510,7 +617,9 @@ class PCPManager:
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
         if self.pcp_world_size * self.dcp_world_size > 1:
-            decode_context_lens = input_batch.num_tokens[:num_decodes]
+            decode_context_lens = input_batch.num_computed_tokens_cpu[:
+                num_decodes] + num_scheduled_tokens[:
+                num_decodes]
             prefill_context_lens = input_batch.num_computed_tokens_cpu[
                 num_decodes:num_reqs]
             context_lens = np.concatenate(
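
# For reference, a minimal standalone sketch (not part of the patch; all names
# below are illustrative) of the index bookkeeping that the new
# _update_input_ids_pcp_full_ids performs under async scheduling, using the
# example values from its comments:

import numpy as np

cu_num_tokens = np.array([2, 5, 8])  # cumulative scheduled tokens per request
draft_lens = [1, 2, 2]               # scheduled draft tokens per request

sample_flattened_indices = []
spec_flattened_indices = []
for cur_index, draft_len in enumerate(draft_lens):
    # Flattened index of the last token of this request.
    flattened_index = int(cu_num_tokens[cur_index]) - 1
    # The previously sampled token sits just before the draft tokens.
    sample_flattened_indices.append(flattened_index - draft_len)
    # The draft tokens occupy the remaining slots up to the last token.
    spec_flattened_indices.extend(
        range(flattened_index - draft_len + 1, flattened_index + 1))

assert sample_flattened_indices == [0, 2, 5]
assert spec_flattened_indices == [1, 3, 4, 6, 7]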