support async mtp (#4511)
### What this PR does / why we need it?
This PR adds async-scheduling support for MTP (multi-token prediction), following vLLM PR
https://github.com/vllm-project/vllm/pull/24799.
It also fixes some implicit host-device synchronization problems in vllm-ascend.
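
For context, async scheduling lets the scheduler prepare the next step while the device is still executing the current one, so host-side code must avoid implicit device synchronization. A minimal usage sketch of enabling async scheduling together with MTP speculative decoding — the exact argument and config names here are assumptions based on upstream vLLM and are not defined by this PR; check your build:

```python
# Hypothetical usage sketch, not part of this PR's code.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",   # an MTP-capable model (example)
    speculative_config={
        "method": "deepseek_mtp",      # MTP drafter name (assumed)
        "num_speculative_tokens": 1,
    },
    async_scheduling=True,             # overlap scheduling with execution (assumed arg)
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
```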
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
```diff
@@ -556,35 +556,43 @@ class AscendMLAMetadataBuilder:
                 out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
                 dtype=torch.int32,
             )
-            chunked_context_metadata = \
-                AscendMLAPrefillMetadata.ChunkedContextMetadata(
-                    cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
-                    starts=local_chunk_starts.to(device, non_blocking=True),
-                    seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
-                    max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
-                    chunk_seq_lens=chunk_seq_lens,
-                    chunk_seq_lens_npu=chunk_seq_lens.npu(),
-                    workspace=self.chunked_prefill_workspace,
-                    padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
-                    padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
-                    local_context_lens_allranks=local_context_lens_allranks.tolist(),
-                    padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
-                        device, non_blocking=True
-                    ),
-                    cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
-                    chunk_size=padded_local_max_context_chunk_across_ranks,
-                )
+            chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
+                cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
+                    device, non_blocking=True),
+                starts=local_chunk_starts.pin_memory().to(
+                    device, non_blocking=True),
+                seq_tot=padded_local_chunk_seq_lens.sum(
+                    dim=1).tolist(),
+                max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                chunk_seq_lens=chunk_seq_lens,
+                chunk_seq_lens_npu=chunk_seq_lens.npu(),
+                workspace=self.chunked_prefill_workspace,
+                padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.
+                npu(),
+                padded_local_chunk_seq_lens=padded_local_chunk_seq_lens
+                .tolist(),
+                local_context_lens_allranks=local_context_lens_allranks
+                .tolist(),
+                padded_local_cu_seq_lens=
+                padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(
+                    device, non_blocking=True),
+                cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
+                chunk_size=padded_local_max_context_chunk_across_ranks,
+            )
         else:
-            chunked_context_metadata = \
-                AscendMLAPrefillMetadata.ChunkedContextMetadata(
-                    cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
-                    starts=chunk_starts.to(device, non_blocking=True),
-                    seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
-                    max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
-                    chunk_seq_lens=chunk_seq_lens,
-                    chunk_seq_lens_npu=chunk_seq_lens.npu(),
-                    workspace=self.chunked_prefill_workspace,
-                )
+            chunked_context_metadata = (
+                AscendMLAPrefillMetadata.ChunkedContextMetadata(
+                    cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
+                        device, non_blocking=True),
+                    starts=chunk_starts.pin_memory().to(
+                        device, non_blocking=True),
+                    seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
+                    max_seq_lens=chunk_seq_lens.max(
+                        dim=1).values.tolist(),
+                    chunk_seq_lens=chunk_seq_lens,
+                    chunk_seq_lens_npu=chunk_seq_lens.npu(),
+                    workspace=self.chunked_prefill_workspace,
+                ))
         prefill_input_positions = input_positions[tokens_start:]
         cos = self.cos_cache[
             prefill_input_positions].unsqueeze(  # type: ignore
```
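The key change above is calling `pin_memory()` before each `non_blocking=True` host-to-device copy. A `non_blocking` copy from pageable host memory is not guaranteed to be asynchronous and may silently fall back to a synchronous transfer, which is exactly the kind of stall async scheduling exposes. A minimal standalone sketch of the pattern, written against the generic `torch` CUDA API for illustration (the Ascend backend uses `torch_npu`'s `"npu"` device analogously):

```python
import torch

device = "cuda"  # "npu" on Ascend with torch_npu installed

cu_seq_lens_cpu = torch.arange(0, 4096, 128, dtype=torch.int32)

# From pageable host memory, non_blocking=True may degrade to a
# synchronous copy, blocking the host thread.
maybe_blocking = cu_seq_lens_cpu.to(device, non_blocking=True)

# Pinning the host tensor first lets the copy be queued on the stream
# and genuinely overlap with host-side scheduling work.
truly_async = cu_seq_lens_cpu.pin_memory().to(device, non_blocking=True)
```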
```diff
@@ -616,7 +624,8 @@ class AscendMLAMetadataBuilder:
         cos = common_attn_metadata.cos
         sin = common_attn_metadata.sin
         # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
-        actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
+        actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
+                                                   1].tolist()
         max_seq_lens = seq_lens[:num_decodes].max().item()
         seq_lens = seq_lens[:num_decodes]
         input_positions = input_positions[:num_decode_tokens]
```
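The second hunk switches the decode path from the device-resident `query_start_loc` to its CPU copy. Calling `.tolist()` on a device tensor forces a device-to-host synchronization, serializing the host with the device and defeating async scheduling, whereas reading the CPU-side copy is a pure host operation. A minimal sketch of the difference, with tensor names reused from the diff purely for illustration:

```python
import torch

device = "cuda"  # "npu" on Ascend

query_start_loc = torch.tensor([0, 4, 8, 12], device=device)
query_start_loc_cpu = torch.tensor([0, 4, 8, 12])  # host-side twin

num_decodes = 2

# Device tensor: .tolist() blocks the host until all queued device work
# producing this tensor has finished -- an implicit synchronization.
# actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()

# CPU tensor: a plain host read, no device synchronization involved.
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes + 1].tolist()
```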