support async mtp (#4511)
### What this PR does / why we need it?
This PR adds async_scheduling support for MTP, following vLLM PR
https://github.com/vllm-project/vllm/pull/24799.
It also fixes some synchronization problems in vllm-ascend.
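For context, a rough usage sketch of what this enables (illustrative only: the model tag is a placeholder, and the `speculative_config` / `async_scheduling` argument names follow upstream vLLM engine args and may differ per release):

```python
from vllm import LLM, SamplingParams

# Hypothetical example: MTP speculative decoding combined with async scheduling.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3",                  # placeholder model that ships MTP weights
    speculative_config={"method": "deepseek_mtp",     # draft tokens come from the MTP head
                        "num_speculative_tokens": 1},
    async_scheduling=True,                            # overlap scheduling with model execution
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```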
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
@@ -142,6 +142,9 @@ class MtpProposer(Proposer):
         self.arange = torch.arange(max_num_slots_for_arange,
                                    device=device,
                                    dtype=torch.int32)
+        self.arange_cpu = torch.arange(max_num_slots_for_arange,
+                                       device="cpu",
+                                       dtype=torch.int32)

         self.inputs_embeds = torch.zeros(
             (self.max_num_tokens, self.hidden_size),
@@ -157,6 +160,7 @@ class MtpProposer(Proposer):
         )
         self.use_sparse = hasattr(vllm_config.model_config.hf_config,
                                   "index_topk")
+        self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling

     def load_model(self, model) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
@@ -351,6 +355,8 @@ class MtpProposer(Proposer):
             self.runner.discard_request_indices.gpu,
             self.runner.num_discarded_requests
         )
+        self._copy_valid_sampled_token_count(next_token_ids,
+                                             valid_sampled_tokens_count)

         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
         if self.pcp_size > 1:
@@ -430,6 +436,28 @@ class MtpProposer(Proposer):

         return draft_token_ids

+    def _copy_valid_sampled_token_count(
+            self, next_token_ids: torch.Tensor,
+            valid_sampled_tokens_count: torch.Tensor) -> None:
+        if self.runner.valid_sampled_token_count_event is not None:
+            default_stream = torch.npu.current_stream()
+            # Use a separate stream to overlap the copy operation with
+            # prepare_input of the draft model.
+            with torch.npu.stream(
+                    self.runner.valid_sampled_token_count_copy_stream):
+                self.runner.valid_sampled_token_count_copy_stream.wait_stream(
+                    default_stream)  # type: ignore
+                self.runner.valid_sampled_token_count_cpu[:valid_sampled_tokens_count.shape[0]].copy_(
+                    valid_sampled_tokens_count, non_blocking=True)
+                self.runner.valid_sampled_token_count_event.record()
+
+        self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)
+
     def _init_mtp_model(self):
         architecture = self.vllm_config.model_config.architecture
         target_device = self.vllm_config.device_config.device
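For readers unfamiliar with the pattern above, here is a self-contained sketch of the side-stream copy that `_copy_valid_sampled_token_count` performs (illustrative only; the helper name and buffers are made up, and the CUDA API is shown as a stand-in for the equivalent `torch.npu` calls used on Ascend):

```python
import torch

def overlapped_d2h_copy(counts_device: torch.Tensor,
                        counts_cpu: torch.Tensor,
                        copy_stream: torch.cuda.Stream,
                        ready_event: torch.cuda.Event) -> None:
    """Copy counts into pinned host memory on a side stream so the default
    stream can keep preparing the drafter's next inputs."""
    default_stream = torch.cuda.current_stream()
    with torch.cuda.stream(copy_stream):
        # Don't start until the producer of counts_device has been enqueued
        # on the default stream.
        copy_stream.wait_stream(default_stream)
        counts_cpu[:counts_device.shape[0]].copy_(counts_device, non_blocking=True)
        ready_event.record()  # consumers wait on this event before reading counts_cpu

if torch.cuda.is_available():
    counts = torch.randint(0, 4, (8,), dtype=torch.int32, device="cuda")
    host_buf = torch.empty(16, dtype=torch.int32, pin_memory=True)
    side_stream, ready = torch.cuda.Stream(), torch.cuda.Event()
    overlapped_d2h_copy(counts, host_buf, side_stream, ready)
    ready.synchronize()  # block only when the CPU actually needs the values
    print(host_buf[:8])
```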
@@ -696,6 +724,11 @@ class MtpProposer(Proposer):
         has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
         aclgraph_runtime_mode, batch_descriptor = \
             self.runner.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
+        if self.use_async_scheduling:
+            # There is a synchronization between MTP steps when aclgraph is
+            # enabled; disable aclgraph under async scheduling to avoid that
+            # synchronization overhead.
+            aclgraph_runtime_mode = CUDAGraphMode.NONE

         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
         ) and aclgraph_runtime_mode == CUDAGraphMode.FULL:
@@ -822,7 +855,7 @@ class MtpProposer(Proposer):
         # When disable_padded_drafter_batch=False, these params may not need to be updated.
         if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \
                                             aclgraph_runtime_mode != CUDAGraphMode.FULL):
-            decode_metadata.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
+            decode_metadata.actual_seq_lengths_q = self.arange_cpu[
                 1:batch_size + 1].tolist()
         if aclgraph_runtime_mode == CUDAGraphMode.FULL:
             decode_metadata.actual_seq_lengths_q = \
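The `query_start_loc` → `arange_cpu` switch above appears to rely on each decode request contributing exactly one query token per draft step, so the offsets are simply `1..batch_size`; reading them from a pre-built CPU tensor avoids the device-to-host synchronization that `.tolist()` on a device tensor forces. A small illustration (not the vllm-ascend code; CUDA shown, `torch.npu` behaves the same):

```python
import torch

batch_size = 4
arange_cpu = torch.arange(0, 16, dtype=torch.int32, device="cpu")

if torch.cuda.is_available():
    # One query token per request per step -> cumulative starts are 0..batch_size.
    query_start_loc = torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda")
    # .tolist() on a device tensor copies it to the host and blocks until the
    # device stream has produced the values.
    seq_lengths_q = query_start_loc[1:batch_size + 1].tolist()

# The same values read from a host-side arange never touch the device stream.
seq_lengths_q_no_sync = arange_cpu[1:batch_size + 1].tolist()
print(seq_lengths_q_no_sync)  # [1, 2, 3, 4]
```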
@@ -847,7 +880,9 @@ class MtpProposer(Proposer):
         clamped_positions = torch.where(exceeds_max_model_len, 0,
                                         positions[:batch_size])
         # Increment the sequence lengths.
-        attn_metadata_i.seq_lens[:batch_size] += 1
+        # This is an out-of-place operation to avoid modifying the original
+        # tensor when async_scheduling is enabled.
+        attn_metadata_i.seq_lens = attn_metadata_i.seq_lens + 1
         # For the requests that exceed the max model length, we set the
         # sequence length to 1 to minimize their overheads in attention.
         exceeds_max_model_len_cpu = exceeds_max_model_len.to(
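The out-of-place `seq_lens` update matters because, with async scheduling, metadata captured for an earlier step may still hold a reference to the same tensor; a minimal demonstration of the difference (illustrative only):

```python
import torch

seq_lens = torch.tensor([5, 7, 9], dtype=torch.int32)
captured = seq_lens              # e.g. metadata still referenced by the previous step

seq_lens[:3] += 1                # in-place: mutates the shared storage
print(captured)                  # the earlier step now sees [6, 8, 10]

seq_lens = torch.tensor([5, 7, 9], dtype=torch.int32)
captured = seq_lens
seq_lens = seq_lens + 1          # out-of-place: allocates a new tensor
print(captured)                  # the earlier step still sees [5, 7, 9]
```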