support async mtp (#4511)

### What this PR does / why we need it?
this pr aims to support async_scheduling for mtp, which refer to vllm pr
https://github.com/vllm-project/vllm/pull/24799.
and this pr fix some synchronize problem in vllm-ascend.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
Ronald
2025-12-06 17:15:57 +08:00
committed by GitHub
parent f067623afd
commit 3480094d7c
8 changed files with 477 additions and 83 deletions

View File

@@ -97,6 +97,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
make_empty_encoder_model_runner_output)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
@@ -213,6 +214,7 @@ class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
sampled_token_ids: torch.Tensor,
invalid_req_indices: list[int],
async_output_copy_stream: torch.npu.Stream,
vocab_size: int,
):
self._model_runner_output = model_runner_output
self._invalid_req_indices = invalid_req_indices
@@ -223,7 +225,7 @@ class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
# Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host.
self._sampled_token_ids = sampled_token_ids
self.vocab_size = vocab_size
# Initiate the copy on a separate stream, but do not synchronize it.
default_stream = torch.npu.current_stream()
with torch.npu.stream(async_output_copy_stream):
@@ -242,10 +244,17 @@ class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
# Release the device tensor once the copy has completed
del self._sampled_token_ids
valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
max_gen_len = self._sampled_token_ids_cpu.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
else:
valid_sampled_token_ids, _ = RejectionSampler.parse_output(
self._sampled_token_ids_cpu,
self.vocab_size,
self._invalid_req_indices,
return_cu_num_tokens=False)
output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids
return output
@@ -567,6 +576,20 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
self.use_async_scheduling = self.scheduler_config.async_scheduling
self.async_output_copy_stream = torch.npu.Stream() if \
self.use_async_scheduling else None
self.num_spec_tokens = 0
if self.speculative_config:
self.num_spec_tokens = self.speculative_config.num_speculative_tokens # noqa
self.valid_sampled_token_count_event: torch.npu.Event | None = None
self.valid_sampled_token_count_copy_stream: torch.npu.Stream | None = None
if self.use_async_scheduling and self.num_spec_tokens:
self.valid_sampled_token_count_event = torch.npu.Event()
self.valid_sampled_token_count_copy_stream = torch.npu.Stream()
self.valid_sampled_token_count_cpu = torch.empty(
self.max_num_reqs,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory,
)
# Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside
# `initialize_kv_cache` based on the kv cache config. However, as in
@@ -791,13 +814,40 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs
# wait until valid_sampled_tokens_count is copied to cpu,
# then use it to update actual num_computed_tokens of each request.
valid_sampled_token_count = self._get_valid_sampled_token_count()
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states.
resumed_from_preemption = req_id in req_data.resumed_req_ids
num_output_tokens = req_data.num_output_tokens[i]
req_index = self.input_batch.req_id_to_index.get(req_id)
# prev_num_draft_len is used in async scheduling mode with
# spec decode. it indicates if need to update num_computed_tokens
# of the request. for example:
# fist step: num_computed_tokens = 0, spec_tokens = [],
# prev_num_draft_len = 0.
# second step: num_computed_tokens = 100(prompt length),
# spec_tokens = [a,b], prev_num_draft_len = 0.
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
# prev_num_draft_len = 2.
# num_computed_tokens in first step and second step doesn't contain
# the spec tokens length, but in third step it contains the
# spec tokens length. we only need to update num_computed_tokens
# when prev_num_draft_len > 0.
if req_state.prev_num_draft_len:
if req_index is None:
req_state.prev_num_draft_len = 0
else:
assert self.input_batch.prev_req_id_to_index is not None
prev_req_index = self.input_batch.prev_req_id_to_index[
req_id]
num_accepted = valid_sampled_token_count[prev_req_index] - 1
num_rejected = req_state.prev_num_draft_len - num_accepted
num_computed_tokens -= num_rejected
req_state.output_token_ids.extend([-1] * num_accepted)
req_state.num_computed_tokens = num_computed_tokens
if not is_last_rank:
@@ -828,12 +878,20 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = new_block_ids
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
# The request was either preempted and resumed later, or was
# not scheduled in the previous step and needs to be added
# again.
if self.use_async_scheduling and num_output_tokens > 0:
# We must recover the output token ids for resumed requests
# in the async scheduling case, so that correct input_ids
# are obtained.
resumed_token_ids = req_data.all_token_ids[req_id]
req_state.output_token_ids = resumed_token_ids[
-num_output_tokens:]
req_ids_to_add.append(req_id)
continue
@@ -860,8 +918,10 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
if spec_token_ids:
num_spec_tokens = len(spec_token_ids)
num_spec_tokens = len(spec_token_ids)
if self.use_async_scheduling:
req_state.prev_num_draft_len = num_spec_tokens
if num_spec_tokens:
start_index = self.input_batch.num_tokens_no_spec[req_index]
end_token_index = start_index + num_spec_tokens
self.input_batch.token_ids_cpu[
@@ -882,6 +942,17 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
def _get_valid_sampled_token_count(self) -> list[int]:
# Wait until valid_sampled_tokens_count is copied to cpu,
prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
if (self.valid_sampled_token_count_event is None
or prev_sampled_token_ids is None):
return []
counts_cpu = self.valid_sampled_token_count_cpu
self.valid_sampled_token_count_event.synchronize()
return counts_cpu[:prev_sampled_token_ids.shape[0]].tolist()
def _init_mrope_positions(self, req_state: CachedRequestState):
assert supports_mrope(self.model), "MROPE is not supported"
req_state.mrope_positions, req_state.mrope_position_delta = \
@@ -901,26 +972,25 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# immediately once the other two flags are no longer needed.
if self.dp_size == 1:
return num_tokens, None, with_prefill
# Sync num_tokens, with_prefill across dp ranks
num_tokens_tensor = torch.tensor([
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
],
dtype=torch.int32,
device="npu")
device="cpu")
flags_tensor = torch.tensor([int(with_prefill)],
dtype=torch.int32,
device="npu")
device="cpu")
packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])
dist.all_reduce(packed_tensor, group=get_dp_group().device_group)
# use cpu_group to avoid cpu synchronization issue.
# it can be overlapped with main moell execution on npu.
dist.all_reduce(packed_tensor, group=get_dp_group().cpu_group)
# Unpack the results
num_tokens_across_dp = packed_tensor[:-1]
synced_flags = packed_tensor[-1:]
max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
global_with_prefill = bool(synced_flags[0])
@@ -1195,7 +1265,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
return cu_num_tokens, arange
def _prepare_input_ids(self, total_num_scheduled_tokens: int,
def _prepare_input_ids(self, scheduler_output: "SchedulerOutput",
total_num_scheduled_tokens: int,
cu_num_tokens: np.ndarray) -> None:
"""Prepare the input IDs for the current batch.
@@ -1218,21 +1289,44 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# on the NPU from prev_sampled_token_ids.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
assert prev_req_id_to_index is not None
flattened_indices = []
prev_common_req_indices = []
sample_flattened_indices: list[int] = []
spec_flattened_indices: list[int] = []
prev_common_req_indices: list[int] = []
prev_draft_token_indices: list[int] = []
indices_match = True
max_flattened_index = -1
total_num_spec_tokens = 0
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id, cur_index in self.input_batch.req_id_to_index.items():
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
prev_common_req_indices.append(prev_index)
# We need to compute the flattened input_ids index of the
# last token in each common request.
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
total_num_spec_tokens += draft_len
flattened_index = cu_num_tokens[cur_index].item() - 1
flattened_indices.append(flattened_index)
indices_match &= (prev_index == flattened_index)
# example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
# sample_flattened_indices = [0, 2, 5]
# spec_flattened_indices = [1, 3, 4, 6, 7]
sample_flattened_indices.append(flattened_index - draft_len)
spec_flattened_indices.extend(
range(flattened_index - draft_len + 1,
flattened_index + 1))
start = prev_index * self.num_spec_tokens
# prev_draft_token_indices is used to find which draft_tokens_id
# should be copied to input_ids
# example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
# flatten draft_tokens_id [1,2,3,4,5,6]
# draft_len of each request [1, 2, 1]
# then prev_draft_token_indices is [0, 2, 3, 4]
prev_draft_token_indices.extend(range(start,
start + draft_len))
indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index)
num_commmon_tokens = len(flattened_indices)
if num_commmon_tokens < total_num_scheduled_tokens:
num_commmon_tokens = len(sample_flattened_indices)
total_without_spec = (total_num_scheduled_tokens -
total_num_spec_tokens)
if num_commmon_tokens < total_without_spec:
# If not all requests are decodes from the last iteration,
# We need to copy the input_ids_cpu to the NPU first.
self.input_ids[:total_num_scheduled_tokens].copy_(
@@ -1256,21 +1350,45 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
non_blocking=True)
self.is_token_ids.gpu[:num_commmon_tokens] = True
return
# Upload the index tensors asynchronously
# so the scatter can be non-blocking.
input_ids_index_tensor = torch.tensor(flattened_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(
self.device,
non_blocking=True)
# Upload the index tensors asynchronously so the scatter can be non-blocking.
sampled_tokens_index_tensor = torch.tensor(
sample_flattened_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
self.input_ids.scatter_(dim=0,
index=input_ids_index_tensor,
src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0])
self.input_ids.scatter_(
dim=0,
index=sampled_tokens_index_tensor,
src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0],
)
# scatter the draft tokens after the sampled tokens are scattered.
if self._draft_token_ids is None or not spec_flattened_indices:
return
assert isinstance(self._draft_token_ids, torch.Tensor)
draft_tokens_index_tensor = torch.tensor(
spec_flattened_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
prev_draft_token_indices_tensor = torch.tensor(
prev_draft_token_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
# because input_ids dtype is torch.int32,
# so convert draft_token_ids to torch.int32 here.
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
self._draft_token_ids = None
self.input_ids.scatter_(
dim=0,
index=draft_tokens_index_tensor,
src=draft_token_ids.flatten()[prev_draft_token_indices_tensor],
)
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
@@ -1544,7 +1662,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
self.query_lens = torch.from_numpy(num_scheduled_tokens)
# Copy the tensors to the NPU.
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens,
cu_num_tokens)
self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions[:num_input_tokens].copy_(
self.positions_cpu[:num_input_tokens], non_blocking=True)
@@ -1993,8 +2112,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
cu_num_scheduled_tokens - num_sampled_tokens,
num_sampled_tokens)
logits_indices_pcp += arange
logits_indices_pcp = torch.from_numpy(logits_indices_pcp).to(
self.device, non_blocking=True)
logits_indices_pcp = torch.from_numpy(
logits_indices_pcp).pin_memory().to(self.device,
non_blocking=True)
# Compute the bonus logits indices.
bonus_logits_indices = cu_num_sampled_tokens - 1
@@ -2015,16 +2135,20 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
target_logits_indices += arange
# TODO: Optimize the CPU -> NPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True)
cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).to(self.device,
non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).to(
self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
self.device, non_blocking=True)
cu_num_draft_tokens = (
torch.from_numpy(cu_num_draft_tokens).pin_memory().to(
self.device, non_blocking=True))
cu_num_sampled_tokens = (
torch.from_numpy(cu_num_sampled_tokens).pin_memory().to(
self.device, non_blocking=True))
logits_indices = (torch.from_numpy(logits_indices).pin_memory().to(
self.device, non_blocking=True))
target_logits_indices = (
torch.from_numpy(target_logits_indices).pin_memory().to(
self.device, non_blocking=True))
bonus_logits_indices = torch.from_numpy(
bonus_logits_indices).pin_memory().to(self.device,
non_blocking=True)
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
@@ -2466,7 +2590,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
sampler_output.sampled_token_ids = output_token_ids
if self.need_accepted_tokens:
self._update_states_after_model_execute(output_token_ids)
discard_sampled_tokens_req_indices = \
self.discard_request_indices.np[:self.num_discarded_requests]
for i in discard_sampled_tokens_req_indices:
@@ -2494,6 +2617,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
@@ -2514,13 +2638,14 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
)
invalid_req_indices_set = set(invalid_req_indices)
assert sampled_token_ids.shape[-1] == 1
if self.num_spec_tokens <= 0:
assert sampled_token_ids.shape[-1] == 1
# Cache the sampled tokens on the NPU and avoid CPU sync.
# These will be copied into input_ids in the next step
# when preparing inputs.
self.input_batch.prev_sampled_token_ids = sampled_token_ids
# Cache the sampled tokens on the NPU and avoid CPU sync.
# These will be copied into input_ids in the next step
# when preparing inputs.
self.input_batch.prev_sampled_token_ids = \
sampled_token_ids
self.input_batch.prev_sampled_token_ids_invalid_indices = \
invalid_req_indices_set
self.input_batch.prev_req_id_to_index = {
@@ -2629,6 +2754,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
sampled_token_ids=sampled_token_ids,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
vocab_size=self.input_batch.vocab_size,
)
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:

View File

@@ -68,6 +68,8 @@ class CachedRequestState:
lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None
prev_num_draft_len: int = 0 # previous number of draft tokens
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)