support async mtp (#4511)

### What this PR does / why we need it?
This PR adds async_scheduling support for MTP, following vLLM PR
https://github.com/vllm-project/vllm/pull/24799.
It also fixes some synchronization issues in vllm-ascend.
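
For reference, a minimal sketch of the combined setup this PR enables (flag and config names follow upstream vLLM; the model name and token count are illustrative, not part of this PR):

```python
from vllm import LLM

# Hypothetical usage sketch: async scheduling together with MTP-based
# speculative decoding. Model name and num_speculative_tokens are
# illustrative only.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3",
    speculative_config={"method": "mtp", "num_speculative_tokens": 1},
    async_scheduling=True,  # the path this PR makes compatible with MTP
)
```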
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Commit 3480094d7c (parent f067623afd), authored by Ronald, committed by GitHub on 2025-12-06 17:15:57 +08:00.
8 changed files with 477 additions and 83 deletions

@@ -142,6 +142,9 @@ class MtpProposer(Proposer):
        self.arange = torch.arange(max_num_slots_for_arange,
                                   device=device,
                                   dtype=torch.int32)
        self.arange_cpu = torch.arange(max_num_slots_for_arange,
                                       device="cpu",
                                       dtype=torch.int32)
        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
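
The CPU twin of `arange` exists so that later slicing plus `.tolist()` stays entirely on the host. A minimal sketch of the difference (device name is illustrative):

```python
import torch

# On-device variant: slicing then .tolist() triggers a device-to-host
# copy and an implicit synchronization with the compute stream.
#   arange_npu = torch.arange(8, device="npu", dtype=torch.int32)
#   lengths = arange_npu[1:5].tolist()  # blocks until the NPU catches up

# CPU-resident twin (what this hunk adds): the same slice is pure host
# work, so the async-scheduled device stream is never stalled.
arange_cpu = torch.arange(8, device="cpu", dtype=torch.int32)
lengths = arange_cpu[1:5].tolist()
print(lengths)  # [1, 2, 3, 4]
```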
@@ -157,6 +160,7 @@ class MtpProposer(Proposer):
        )
        self.use_sparse = hasattr(vllm_config.model_config.hf_config,
                                  "index_topk")
        self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling

    def load_model(self, model) -> None:
        loader = get_model_loader(self.vllm_config.load_config)
@@ -351,6 +355,8 @@ class MtpProposer(Proposer):
            self.runner.discard_request_indices.gpu,
            self.runner.num_discarded_requests
        )
        self._copy_valid_sampled_token_count(next_token_ids,
                                             valid_sampled_tokens_count)
        req_scheduled_tokens = scheduler_output.num_scheduled_tokens
        if self.pcp_size > 1:
@@ -430,6 +436,28 @@ class MtpProposer(Proposer):
        return draft_token_ids

    def _copy_valid_sampled_token_count(
            self, next_token_ids: torch.Tensor,
            valid_sampled_tokens_count: torch.Tensor) -> None:
        if self.runner.valid_sampled_token_count_event is not None:
            default_stream = torch.npu.current_stream()
            # Run the copy on a separate stream so it overlaps with the
            # draft model's prepare_input.
            with torch.npu.stream(
                    self.runner.valid_sampled_token_count_copy_stream):
                self.runner.valid_sampled_token_count_copy_stream.wait_stream(
                    default_stream)  # type: ignore
                self.runner.valid_sampled_token_count_cpu[:valid_sampled_tokens_count.shape[0]].copy_(
                    valid_sampled_tokens_count, non_blocking=True)
                self.runner.valid_sampled_token_count_event.record()
        self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)

    def _init_mtp_model(self):
        architecture = self.vllm_config.model_config.architecture
        target_device = self.vllm_config.device_config.device
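
The method above follows a standard copy/compute overlap pattern. A self-contained sketch of that pattern (shown with `torch.cuda` names; `torch.npu` mirrors the same Stream/Event API on Ascend, and the function name here is hypothetical):

```python
import torch

def overlapped_d2h_copy(src: torch.Tensor,
                        dst_cpu: torch.Tensor,
                        copy_stream: "torch.cuda.Stream",
                        done: "torch.cuda.Event") -> None:
    """Sketch of the copy/compute overlap used in the hunk above."""
    default_stream = torch.cuda.current_stream()
    with torch.cuda.stream(copy_stream):
        # Don't start copying before the producer kernel on the default
        # stream has finished writing `src`.
        copy_stream.wait_stream(default_stream)
        # non_blocking=True only truly overlaps when dst_cpu is pinned.
        dst_cpu[:src.shape[0]].copy_(src, non_blocking=True)
        # Consumers synchronize on this event just before reading
        # dst_cpu, so the copy hides behind other host-side prep work.
        done.record()
```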
@@ -696,6 +724,11 @@ class MtpProposer(Proposer):
        has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
        aclgraph_runtime_mode, batch_descriptor = \
            self.runner.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
        if self.use_async_scheduling:
            # aclgraph introduces a synchronization between MTP steps;
            # disable it under async scheduling to avoid that overhead.
            aclgraph_runtime_mode = CUDAGraphMode.NONE
        if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
        ) and aclgraph_runtime_mode == CUDAGraphMode.FULL:
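
The override can be read as a small pure function. A hypothetical restatement (`resolve_runtime_mode` is not code from the PR, and the `CUDAGraphMode` import path may differ across vLLM versions):

```python
from vllm.config import CUDAGraphMode  # enum with NONE / PIECEWISE / FULL

def resolve_runtime_mode(dispatched: CUDAGraphMode,
                         use_async_scheduling: bool) -> CUDAGraphMode:
    # Graph replay inserts a synchronization between MTP steps, so fall
    # back to eager execution whenever async scheduling is active.
    return CUDAGraphMode.NONE if use_async_scheduling else dispatched
```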
@@ -822,7 +855,7 @@ class MtpProposer(Proposer):
         # When disable_padded_drafter_batch=False, these params may not
         # need to be updated.
         if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \
             aclgraph_runtime_mode != CUDAGraphMode.FULL):
-            decode_metadata.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
+            decode_metadata.actual_seq_lengths_q = self.arange_cpu[
                 1:batch_size + 1].tolist()
         if aclgraph_runtime_mode == CUDAGraphMode.FULL:
             decode_metadata.actual_seq_lengths_q = \
@@ -847,7 +880,9 @@ class MtpProposer(Proposer):
             clamped_positions = torch.where(exceeds_max_model_len, 0,
                                             positions[:batch_size])
             # Increment the sequence lengths.
-            attn_metadata_i.seq_lens[:batch_size] += 1
+            # Use an out-of-place add so the original tensor is not
+            # mutated when async_scheduling is enabled.
+            attn_metadata_i.seq_lens = attn_metadata_i.seq_lens + 1
             # For the requests that exceed the max model length, we set the
             # sequence length to 1 to minimize their overheads in attention.
             exceeds_max_model_len_cpu = exceeds_max_model_len.to(
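
A minimal sketch of why the out-of-place add matters under async scheduling (variable names here are illustrative):

```python
import torch

# With async scheduling, a previous step may still hold (and read) the
# old seq_lens buffer. An in-place `+= 1` would mutate it under that
# reader's feet; `seq_lens + 1` allocates a fresh tensor instead.
seq_lens = torch.tensor([5, 7, 9], dtype=torch.int32)
held_by_in_flight_step = seq_lens       # alias kept by the async step

seq_lens = seq_lens + 1                 # out-of-place update
assert held_by_in_flight_step.tolist() == [5, 7, 9]   # old view intact
assert seq_lens.tolist() == [6, 8, 10]
```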