support multistep decode (#299)

Add multi step scheduler support for vllm-ascend

Signed-off-by: new-TonyWang <wangtonyyu222@gmail.com>
Author: Tony
Date: 2025-03-11 19:20:06 +08:00
Committed by: GitHub
Parent: feb6bdb12e
Commit: 4c9d78a035
5 changed files with 1067 additions and 10 deletions
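Note: multi-step scheduling is switched on through vLLM's `num_scheduler_steps` engine argument; a minimal usage sketch follows (the model name and step count are illustrative, not part of this commit):

# Minimal sketch, assuming the standard vLLM entry point.
# Model name and step count below are illustrative placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    num_scheduler_steps=8,  # run several decode steps per scheduler call
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)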


@@ -16,6 +16,7 @@
#
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import numpy as np
@@ -38,7 +39,8 @@ from vllm.attention.backends.utils import (CommonAttentionState,
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm_ascend.worker.model_runner import (
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
@@ -197,26 +199,52 @@ class AscendMetadata(AttentionMetadata):
# Avoid mypy error
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# FIXME: It is for flash attn.
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
block_tables: Optional[torch.Tensor]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens: Optional[List[int]] = None
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata: Optional["AscendMetadata"] = None
_cached_decode_metadata: Optional["AscendMetadata"] = None
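For readers new to these fields, the following standalone sketch (not part of the diff) shows how the cumulative start locations relate to per-request lengths, mirroring the accumulate(..., initial=0) construction the builder uses further down:

# Illustrative only: how query_start_loc / seq_start_loc relate to lengths.
from itertools import accumulate

query_lens = [4, 6]   # new tokens per request
seq_lens = [10, 16]   # computed + new tokens per request

query_start_loc = list(accumulate(query_lens, initial=0))  # [0, 4, 10]
seq_start_loc = list(accumulate(seq_lens, initial=0))      # [0, 10, 26]

# Tokens of request i occupy flat indices
# [query_start_loc[i], query_start_loc[i + 1]) in the packed query tensor.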
@@ -254,10 +282,18 @@ class AscendMetadata(AttentionMetadata):
or (self.encoder_seq_lens is not None))
# Compute some attn_metadata fields which default to None.
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
@@ -274,7 +310,11 @@ class AscendMetadata(AttentionMetadata):
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_query_len=0,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
@@ -302,6 +342,8 @@ class AscendMetadata(AttentionMetadata):
self.slot_mapping[self.num_prefill_tokens:])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[self.num_prefills:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
@@ -314,8 +356,19 @@ class AscendMetadata(AttentionMetadata):
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc=(self.query_start_loc[self.num_prefills:] -
self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
@@ -328,6 +381,98 @@ class AscendMetadata(AttentionMetadata):
enable_kv_scales_calculation=False)
return self._cached_decode_metadata
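The query_start_loc re-basing in the decode branch above is easiest to see with concrete numbers; an illustrative sketch (values invented):

# Illustrative only: re-basing query_start_loc for the decode-only slice.
import torch

num_prefills = 1
# one prefill of 3 tokens followed by two single-token decodes
query_start_loc = torch.tensor([0, 3, 4, 5])

decode_start_loc = (query_start_loc[num_prefills:] -
                    query_start_loc[num_prefills])
# tensor([0, 1, 2]) -- indices now start at the first decode token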
def advance_step(self,
model_input: "ModelInputForNPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, num_seqs is padded to the next captured
# batch size, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries.
if num_seqs != num_queries:
assert num_seqs > num_queries
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to the captured cuda graph batch size.
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)
# TODO: optimize this code with AscendC, as the flash attention backend does with CUDA.
# update input_tokens
sampled_token_ids_list = sampled_token_ids[:num_queries].squeeze(-1)  # type: ignore
model_input.input_tokens[:num_queries] = sampled_token_ids_list  # type: ignore
# get seq_lens and input_positions
seq_lens = self.seq_lens_tensor[:num_queries]
next_seq_lens = seq_lens + 1
next_input_pos = next_seq_lens - 1
# update seq_lens and input_positions
self.seq_lens_tensor[:num_queries] = next_seq_lens
model_input.input_positions[:num_queries] = next_input_pos  # type: ignore
# Compute the block index and the offset within the block.
block_idx = next_input_pos // block_size
block_offset = next_input_pos % block_size
current_block_table = self.block_tables.gather(
1, block_idx.unsqueeze(-1)).squeeze(-1)
slot_num = current_block_table * block_size + block_offset
# update slot_mapping
self.slot_mapping[:num_queries] = slot_num
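As a sanity check on the slot arithmetic in advance_step, a standalone illustrative example (block table and positions invented):

# Illustrative only: the slot computation used by advance_step.
import torch

block_size = 16
next_input_pos = torch.tensor([5, 16, 33])   # next token position per sequence
block_tables = torch.tensor([[7, 2, 9],
                             [1, 4, 6],
                             [0, 3, 8]])     # physical blocks per sequence

block_idx = next_input_pos // block_size     # tensor([0, 1, 2])
block_offset = next_input_pos % block_size   # tensor([5, 0, 1])
current_block = block_tables.gather(
    1, block_idx.unsqueeze(-1)).squeeze(-1)  # tensor([7, 4, 8])
slot_num = current_block * block_size + block_offset  # tensor([117, 64, 129])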
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
@@ -430,6 +575,11 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
device = self.runner.device
max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
max_decode_query_len = max(decode_query_lens)
else:
max_decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
@@ -440,6 +590,9 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
self.input_builder.runner.device)
else:
self.attn_mask = None
num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
block_tables = make_tensor_with_pad(
self.block_tables,
@@ -450,9 +603,17 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
assert max_query_len > 0, "query_lens: {}".format(query_lens)
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
device,
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
@@ -466,15 +627,19 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
return AscendMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
attn_mask=self.attn_mask,
)