[Bugfix] fix bug of pcp+mtp+async scheduler (#5994)
### What this PR does / why we need it?
Fixed an issue where services using PCP together with MTP could not run
when asynchronous scheduling was enabled.
With PCP, MTP, and asynchronous scheduling all enabled, the service would
hang after receiving a request (e.g. via curl) because of a tensor shape
mismatch. This PR resolves that issue.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -70,12 +70,12 @@ def test_pcp_dcp_mtp3_eager():
|
||||
max_num_batched_tokens=1024,
|
||||
enable_expert_parallel=True,
|
||||
block_size=128,
|
||||
async_scheduling=True,
|
||||
speculative_config={
|
||||
"num_speculative_tokens": 3,
|
||||
"method": "deepseek_mtp",
|
||||
},
|
||||
enforce_eager=True,
|
||||
async_scheduling=False,
|
||||
) as runner:
|
||||
runner.generate_greedy(prompts, 32)
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
|
||||
max_num_reqs=1000,
|
||||
device="cpu",
|
||||
vllm_config=vllm_config,
|
||||
use_async_scheduling=False,
|
||||
pin_memory=False)
|
||||
input_batch = MagicMock()
|
||||
input_batch.num_reqs = num_reqs
|
||||
@@ -65,13 +66,16 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
|
||||
num_prompt_tokens.append(query_lens[i])
|
||||
num_tokens.append(query_lens[i])
|
||||
|
||||
input_batch.num_computed_tokens_cpu = torch.tensor(num_computed_tokens)
|
||||
input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens)
|
||||
input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens)
|
||||
input_batch.num_tokens = torch.tensor(num_tokens)
|
||||
num_scheduled_tokens = np.array(
|
||||
query_lens) - input_batch.num_computed_tokens_cpu
|
||||
|
||||
query_lens = torch.tensor(query_lens)
|
||||
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
|
||||
input_batch)
|
||||
input_batch,
|
||||
num_scheduled_tokens)
|
||||
|
||||
if not expect_not_none:
|
||||
assert result is None, f"Expected to return None, but got {type(result)}"
|
||||
@@ -128,6 +132,7 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
|
||||
max_num_reqs=1000,
|
||||
device="cpu",
|
||||
vllm_config=vllm_config,
|
||||
use_async_scheduling=False,
|
||||
pin_memory=False)
|
||||
input_batch = MagicMock()
|
||||
input_batch.num_reqs = num_reqs
|
||||
@@ -193,6 +198,7 @@ def test_get_cp_local_seq_lens(
|
||||
max_num_reqs=1000,
|
||||
device="cpu",
|
||||
vllm_config=vllm_config,
|
||||
use_async_scheduling=False,
|
||||
pin_memory=False)
|
||||
ret = pcp_manager._get_cp_local_seq_lens(seq_lens, pcp_world_size,
|
||||
dcp_world_size,
|
||||
@@ -276,6 +282,7 @@ def test_generate_pcp_mtp_input(
|
||||
max_num_reqs=max_num_reqs,
|
||||
device="cpu",
|
||||
vllm_config=vllm_config,
|
||||
use_async_scheduling=False,
|
||||
pin_memory=False)
|
||||
arange_np = np.arange(max_model_len)
|
||||
input_batch = MagicMock()
|
||||
|
||||
@@ -226,6 +226,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
self.max_num_reqs,
|
||||
self.device,
|
||||
self.vllm_config,
|
||||
self.use_async_scheduling,
|
||||
self.pin_memory,
|
||||
)
|
||||
# TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
|
||||
@@ -540,10 +541,18 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# for pcp, prefill mtp should use origin scheduleroutput ,
|
||||
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
|
||||
self.pcp_manager.generate_pcp_mtp_input(
|
||||
num_reqs, total_num_scheduled_tokens,
|
||||
scheduler_output.num_scheduled_tokens, with_prefill,
|
||||
self.input_batch, self.arange_np, req_indices, positions_np,
|
||||
cu_num_tokens)
|
||||
num_reqs,
|
||||
total_num_scheduled_tokens,
|
||||
scheduler_output.num_scheduled_tokens,
|
||||
with_prefill,
|
||||
self.input_batch,
|
||||
self.arange_np,
|
||||
req_indices,
|
||||
positions_np,
|
||||
cu_num_tokens,
|
||||
self._draft_token_ids, # type: ignore[has-type]
|
||||
scheduler_output,
|
||||
self.num_spec_tokens)
|
||||
|
||||
if self.pcp_size > 1:
|
||||
if not self.vllm_config.model_config.use_mla:
|
||||
@@ -929,7 +938,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata(
|
||||
total_num_scheduled_tokens, self.query_lens,
|
||||
self.input_batch)
|
||||
self.input_batch, num_scheduled_tokens)
|
||||
blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1)
|
||||
if self.pcp_size > 1:
|
||||
slot_mapping_pcp = self.pcp_manager.get_padded_slot_mapping(
|
||||
@@ -1946,7 +1955,8 @@ class NPUModelRunner(GPUModelRunner):
|
||||
slot_mapping = self.input_batch.block_table[
|
||||
kv_cache_group_id].slot_mapping
|
||||
long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata(
|
||||
num_tokens, self.query_lens, self.input_batch)
|
||||
num_tokens, self.query_lens, self.input_batch,
|
||||
num_scheduled_tokens)
|
||||
if long_seq_metadata is not None:
|
||||
pcp_world_size = get_pcp_group().world_size
|
||||
dcp_world_size = get_dcp_group().world_size
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
||||
#
|
||||
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -25,6 +25,9 @@ from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
|
||||
class PCPManager:
|
||||
"""
|
||||
@@ -44,6 +47,7 @@ class PCPManager:
|
||||
max_num_reqs: int,
|
||||
device: torch.device,
|
||||
vllm_config: VllmConfig,
|
||||
use_async_scheduling: bool,
|
||||
pin_memory: bool = False,
|
||||
) -> None:
|
||||
self.pcp_world_size = pcp_world_size
|
||||
@@ -58,6 +62,7 @@ class PCPManager:
|
||||
self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs
|
||||
self.device = device
|
||||
self.use_async_scheduling = use_async_scheduling
|
||||
self.pcp_allgather_restore_idx = CpuGpuBuffer(
|
||||
max_buffer_num_tokens,
|
||||
dtype=torch.int64,
|
||||
@@ -354,6 +359,9 @@ class PCPManager:
|
||||
req_indices=None,
|
||||
positions_np=None,
|
||||
cu_num_tokens=None,
|
||||
draft_token_ids=None,
|
||||
scheduler_output=None,
|
||||
num_spec_tokens=None,
|
||||
):
|
||||
"""
|
||||
While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group,
|
||||
@@ -390,6 +398,12 @@ class PCPManager:
|
||||
torch.from_numpy(token_indices_pcp_full),
|
||||
out=self.input_ids_pcp_full.
|
||||
cpu[:total_num_scheduled_tokens_pcp_full])
|
||||
if self.use_async_scheduling:
|
||||
self._update_input_ids_pcp_full_ids(input_batch, draft_token_ids,
|
||||
scheduler_output,
|
||||
total_num_scheduled_tokens,
|
||||
cu_num_tokens_pcp_full,
|
||||
num_spec_tokens)
|
||||
self.query_lens_pcp_full.copy_to_gpu()
|
||||
self.query_start_loc_pcp_full.copy_to_gpu()
|
||||
self.input_ids_pcp_full.copy_to_gpu(
|
||||
@@ -428,6 +442,99 @@ class PCPManager:
|
||||
mtp_slot_pad[unpad_mask] = mtp_slot_ori
|
||||
self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
|
||||
|
||||
def _update_input_ids_pcp_full_ids(
|
||||
self,
|
||||
input_batch,
|
||||
draft_token_ids,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
total_num_scheduled_tokens: int,
|
||||
cu_num_tokens: np.ndarray,
|
||||
num_spec_tokens: int,
|
||||
) -> None:
|
||||
"""Prepare the input IDs for the current batch.
|
||||
|
||||
Carefully handles the `prev_sampled_token_ids` which can be cached
|
||||
from the previous engine iteration, in which case those tokens on the
|
||||
GPU need to be copied into the corresponding slots into input_ids."""
|
||||
|
||||
if (input_batch.prev_sampled_token_ids is None
|
||||
or input_batch.prev_req_id_to_index is None):
|
||||
return
|
||||
|
||||
# Async scheduling case, where some decode requests from the previous
|
||||
# iteration won't have entries in input_ids_cpu and need to be copied
|
||||
# on the GPU from prev_sampled_token_ids.
|
||||
prev_req_id_to_index = input_batch.prev_req_id_to_index
|
||||
sample_flattened_indices: list[int] = []
|
||||
spec_flattened_indices: list[int] = []
|
||||
prev_common_req_indices: list[int] = []
|
||||
prev_draft_token_indices: list[int] = []
|
||||
total_num_spec_tokens = 0
|
||||
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
|
||||
|
||||
for req_id, cur_index in input_batch.req_id_to_index.items():
|
||||
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
|
||||
prev_common_req_indices.append(prev_index)
|
||||
# We need to compute the flattened input_ids index of the
|
||||
# last token in each common request.
|
||||
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
|
||||
total_num_spec_tokens += draft_len
|
||||
flattened_index = cu_num_tokens[cur_index].item() - 1
|
||||
# example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
|
||||
# sample_flattened_indices = [0, 2, 5]
|
||||
# spec_flattened_indices = [1, 3, 4, 6, 7]
|
||||
sample_flattened_indices.append(flattened_index - draft_len)
|
||||
spec_flattened_indices.extend(
|
||||
range(flattened_index - draft_len + 1,
|
||||
flattened_index + 1))
|
||||
start = prev_index * num_spec_tokens
|
||||
# prev_draft_token_indices is used to find which draft_tokens_id
|
||||
# should be copied to input_ids
|
||||
# example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
|
||||
# flatten draft_tokens_id [1,2,3,4,5,6]
|
||||
# draft_len of each request [1, 2, 1]
|
||||
# then prev_draft_token_indices is [0, 2, 3, 4]
|
||||
prev_draft_token_indices.extend(range(start,
|
||||
start + draft_len))
|
||||
num_commmon_tokens = len(sample_flattened_indices)
|
||||
|
||||
if num_commmon_tokens == 0:
|
||||
# No requests in common with the previous iteration
|
||||
# So input_ids.cpu will have all the input ids.
|
||||
return
|
||||
# Upload the index tensors asynchronously so the scatter can be non-blocking.
|
||||
sampled_tokens_index_tensor = torch.tensor(sample_flattened_indices,
|
||||
dtype=torch.int64)
|
||||
prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices,
|
||||
dtype=torch.int64)
|
||||
self.input_ids_pcp_full.cpu.scatter_(
|
||||
dim=0,
|
||||
index=sampled_tokens_index_tensor,
|
||||
src=input_batch.prev_sampled_token_ids[
|
||||
prev_common_req_indices_tensor, 0].cpu(),
|
||||
)
|
||||
|
||||
# Scatter the draft tokens after the sampled tokens are scattered.
|
||||
if draft_token_ids is None or not spec_flattened_indices:
|
||||
return
|
||||
|
||||
assert isinstance(draft_token_ids, torch.Tensor)
|
||||
draft_tokens_index_tensor = torch.tensor(spec_flattened_indices,
|
||||
dtype=torch.int64)
|
||||
prev_draft_token_indices_tensor = torch.tensor(
|
||||
prev_draft_token_indices, dtype=torch.int64)
|
||||
|
||||
# because input_ids dtype is torch.int32,
|
||||
# so convert draft_token_ids to torch.int32 here.
|
||||
draft_token_ids = draft_token_ids.to(dtype=torch.int32)
|
||||
|
||||
self.input_ids_pcp_full.cpu.scatter_(
|
||||
dim=0,
|
||||
index=draft_tokens_index_tensor,
|
||||
src=draft_token_ids.flatten()
|
||||
[prev_draft_token_indices_tensor].cpu(),
|
||||
)
|
||||
|
||||
def _get_cp_local_seq_lens(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
@@ -498,7 +605,7 @@ class PCPManager:
|
||||
torch.float32).argsort().to(torch.int32)
|
||||
|
||||
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
|
||||
input_batch):
|
||||
input_batch, num_scheduled_tokens):
|
||||
from vllm_ascend.attention.utils import \
|
||||
AscendPrefillContextParallelMetadata
|
||||
num_reqs = input_batch.num_reqs or query_lens.size(0)
|
||||
@@ -510,7 +617,9 @@ class PCPManager:
|
||||
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
|
||||
long_seq_metadata = None
|
||||
if self.pcp_world_size * self.dcp_world_size > 1:
|
||||
decode_context_lens = input_batch.num_tokens[:num_decodes]
|
||||
decode_context_lens = input_batch.num_computed_tokens_cpu[:
|
||||
num_decodes] + num_scheduled_tokens[:
|
||||
num_decodes]
|
||||
prefill_context_lens = input_batch.num_computed_tokens_cpu[
|
||||
num_decodes:num_reqs]
|
||||
context_lens = np.concatenate(
|
||||
|
||||
Reference in New Issue
Block a user