support multistep decode (#299)
Add multi step scheduler support for vllm-ascend Signed-off-by: new-TonyWang <wangtonyyu222@gmail.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from itertools import accumulate
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -38,7 +39,8 @@ from vllm.attention.backends.utils import (CommonAttentionState,
|
|||||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm_ascend.worker.model_runner import ModelInputForNPUBuilder
|
from vllm_ascend.worker.model_runner import (
|
||||||
|
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
||||||
|
|
||||||
|
|
||||||
def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
|
def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
|
||||||
@@ -197,26 +199,52 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
|
|
||||||
# FIXME: It is for flash attn.
|
# FIXME: It is for flash attn.
|
||||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||||
|
# Avoid mypy error
|
||||||
|
# Total number of prefill requests.
|
||||||
|
num_prefills: int
|
||||||
|
# Number of prefill tokens.
|
||||||
|
num_prefill_tokens: int
|
||||||
|
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||||
|
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
||||||
|
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
||||||
|
# in block 0, and 1st slot in block 1, respectively.
|
||||||
|
slot_mapping: torch.Tensor
|
||||||
|
|
||||||
# requests only.
|
# requests only.
|
||||||
max_prefill_seq_len: int
|
max_prefill_seq_len: int
|
||||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||||
# requests only.
|
# requests only.
|
||||||
max_decode_seq_len: int
|
max_decode_seq_len: int
|
||||||
|
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||||
|
# so far).
|
||||||
|
context_lens_tensor: Optional[torch.Tensor]
|
||||||
|
|
||||||
# (batch_size, max_blocks_per_seq).
|
# (batch_size, max_blocks_per_seq).
|
||||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||||
block_tables: Optional[torch.Tensor]
|
block_tables: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
# seq_lens stored as a tensor.
|
||||||
|
seq_lens_tensor: Optional[torch.Tensor]
|
||||||
|
|
||||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||||
# the computed tokens + new tokens None if it is a decoding.
|
# the computed tokens + new tokens None if it is a decoding.
|
||||||
seq_lens: Optional[List[int]] = None
|
seq_lens: Optional[List[int]] = None
|
||||||
|
|
||||||
# seq_lens stored as a tensor.
|
|
||||||
seq_lens_tensor: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
# Maximum query length in the batch. None for decoding.
|
# Maximum query length in the batch. None for decoding.
|
||||||
max_query_len: Optional[int] = None
|
max_query_len: Optional[int] = None
|
||||||
|
|
||||||
|
# Max number of query tokens among request in the batch.
|
||||||
|
max_decode_query_len: Optional[int] = None
|
||||||
|
|
||||||
|
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||||
|
# the batch, used to index into subquery. E.g., if the subquery length
|
||||||
|
# is [4, 6], it is [0, 4, 10].
|
||||||
|
query_start_loc: Optional[torch.Tensor] = None
|
||||||
|
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||||
|
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||||
|
# [4, 6], it is [0, 4, 10].
|
||||||
|
seq_start_loc: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# Self-attention prefill/decode metadata cache
|
# Self-attention prefill/decode metadata cache
|
||||||
_cached_prefill_metadata: Optional["AscendMetadata"] = None
|
_cached_prefill_metadata: Optional["AscendMetadata"] = None
|
||||||
_cached_decode_metadata: Optional["AscendMetadata"] = None
|
_cached_decode_metadata: Optional["AscendMetadata"] = None
|
||||||
@@ -254,10 +282,18 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
or (self.encoder_seq_lens is not None))
|
or (self.encoder_seq_lens is not None))
|
||||||
|
|
||||||
# Compute some attn_metadata fields which default to None.
|
# Compute some attn_metadata fields which default to None.
|
||||||
|
query_start_loc = (None if self.query_start_loc is None else
|
||||||
|
self.query_start_loc[:self.num_prefills + 1])
|
||||||
slot_mapping = (None if self.slot_mapping is None else
|
slot_mapping = (None if self.slot_mapping is None else
|
||||||
self.slot_mapping[:self.num_prefill_tokens])
|
self.slot_mapping[:self.num_prefill_tokens])
|
||||||
seq_lens = (None if self.seq_lens is None else
|
seq_lens = (None if self.seq_lens is None else
|
||||||
self.seq_lens[:self.num_prefills])
|
self.seq_lens[:self.num_prefills])
|
||||||
|
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||||
|
self.seq_lens_tensor[:self.num_prefills])
|
||||||
|
seq_start_loc = (None if self.seq_start_loc is None else
|
||||||
|
self.seq_start_loc[:self.num_prefills + 1])
|
||||||
|
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||||
|
self.context_lens_tensor[:self.num_prefills])
|
||||||
block_tables = (None if self.block_tables is None else
|
block_tables = (None if self.block_tables is None else
|
||||||
self.block_tables[:self.num_prefills])
|
self.block_tables[:self.num_prefills])
|
||||||
|
|
||||||
@@ -274,7 +310,11 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
seq_lens_tensor=seq_lens_tensor,
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_query_len=self.max_query_len,
|
max_query_len=self.max_query_len,
|
||||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||||
|
max_decode_query_len=0,
|
||||||
max_decode_seq_len=0,
|
max_decode_seq_len=0,
|
||||||
|
query_start_loc=query_start_loc,
|
||||||
|
seq_start_loc=seq_start_loc,
|
||||||
|
context_lens_tensor=context_lens_tensor,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
# Begin encoder & cross attn fields below...
|
# Begin encoder & cross attn fields below...
|
||||||
encoder_seq_lens=self.encoder_seq_lens,
|
encoder_seq_lens=self.encoder_seq_lens,
|
||||||
@@ -302,6 +342,8 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
self.slot_mapping[self.num_prefill_tokens:])
|
self.slot_mapping[self.num_prefill_tokens:])
|
||||||
seq_lens = (None if self.seq_lens is None else
|
seq_lens = (None if self.seq_lens is None else
|
||||||
self.seq_lens[self.num_prefills:])
|
self.seq_lens[self.num_prefills:])
|
||||||
|
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||||
|
self.seq_lens_tensor[self.num_prefills:])
|
||||||
block_tables = (None if self.block_tables is None else
|
block_tables = (None if self.block_tables is None else
|
||||||
self.block_tables[self.num_prefills:])
|
self.block_tables[self.num_prefills:])
|
||||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||||
@@ -314,8 +356,19 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
|
max_decode_query_len=self.max_decode_query_len,
|
||||||
|
max_query_len=self.max_query_len,
|
||||||
max_prefill_seq_len=0,
|
max_prefill_seq_len=0,
|
||||||
max_decode_seq_len=self.max_decode_seq_len,
|
max_decode_seq_len=self.max_decode_seq_len,
|
||||||
|
# Batch may be composed of prefill|decodes, adjust query start
|
||||||
|
# indices to refer to the start of decodes. E.g.
|
||||||
|
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||||
|
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
||||||
|
self.query_start_loc[self.num_prefills])
|
||||||
|
if self.query_start_loc is not None else None,
|
||||||
|
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||||
|
if self.seq_start_loc is not None else None,
|
||||||
|
context_lens_tensor=None,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
# Begin encoder & cross attn fields below...
|
# Begin encoder & cross attn fields below...
|
||||||
encoder_seq_lens=self.encoder_seq_lens,
|
encoder_seq_lens=self.encoder_seq_lens,
|
||||||
@@ -328,6 +381,98 @@ class AscendMetadata(AttentionMetadata):
|
|||||||
enable_kv_scales_calculation=False)
|
enable_kv_scales_calculation=False)
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
|
def advance_step(self,
                 model_input: "ModelInputForNPUWithSamplingMetadata",
                 sampled_token_ids: Optional[torch.Tensor],
                 block_size: int,
                 num_seqs: int,
                 num_queries: int,
                 turn_prefills_into_decodes: bool = False):
    """
    Update metadata in-place to advance one decode step.

    Args:
        model_input: Model input whose ``input_tokens`` and
            ``input_positions`` are overwritten in place for the first
            ``num_queries`` entries.
        sampled_token_ids: Token ids sampled in the previous step; the
            first ``num_queries`` rows become the next input tokens.
            Must not be None.
        block_size: Number of token slots per KV-cache block.
        num_seqs: Number of sequences in the (possibly padded) batch.
        num_queries: Number of real requests in the batch.
        turn_prefills_into_decodes: When True, the prefills counted in
            this metadata are re-labelled as decodes before advancing.
    """
    # Checked before anything is mutated so that a missing sampler output
    # cannot leave the metadata half-advanced.
    assert sampled_token_ids is not None, (
        "advance_step requires the sampled token ids of the previous step")

    # With graph capture, num_seqs is padded to the next captured batch
    # size, while num_queries tracks the actual number of requests in the
    # batch. In --enforce-eager mode, num_seqs == num_queries.
    if num_seqs != num_queries:
        assert num_seqs > num_queries

    if turn_prefills_into_decodes:
        # When Multi-Step is enabled with Chunked-Prefill, prefills and
        # decodes are scheduled together. In the first step, all the
        # prefills turn into decodes. This update reflects that
        # conversion.
        assert self.num_decode_tokens + self.num_prefills == num_seqs
        self.num_decode_tokens += self.num_prefills
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.max_prefill_seq_len = 0
        self.max_query_len = 1

        self.slot_mapping = self.slot_mapping[:num_seqs]
    else:
        assert self.seq_lens is not None
        assert self.max_decode_seq_len == max(self.seq_lens)

    assert self.num_prefills == 0
    assert self.num_prefill_tokens == 0
    assert self.num_decode_tokens == num_seqs
    assert self.slot_mapping.shape == (num_seqs, )

    assert self.seq_lens is not None
    assert len(self.seq_lens) == num_seqs
    assert self.seq_lens_tensor is not None
    assert self.seq_lens_tensor.shape == (num_seqs, )
    assert self.max_query_len == 1
    assert self.max_prefill_seq_len == 0

    assert self.query_start_loc is not None
    assert self.query_start_loc.shape == (num_queries + 1, )
    assert self.seq_start_loc is not None
    assert self.seq_start_loc.shape == (num_seqs + 1, )

    assert self.context_lens_tensor is not None
    assert self.context_lens_tensor.shape == (num_queries, )

    assert self.block_tables is not None
    assert self.block_tables.shape[0] == num_seqs

    # Update query lengths. Note that we update only queries and not seqs,
    # since tensors may be padded due to the captured graph batch size.
    for i in range(num_queries):
        self.seq_lens[i] += 1
    self.max_decode_seq_len = max(self.seq_lens)

    # TODO: optimize this with an AscendC kernel, the way the flash
    # attention backend does it with a CUDA kernel.

    # Update input_tokens with the tokens sampled in the previous step.
    sampled_token_ids_list = sampled_token_ids[:num_queries].squeeze(-1)
    model_input.input_tokens[:num_queries] = (  # type: ignore
        sampled_token_ids_list)

    # Each real sequence grows by one token; the new token sits at the
    # last position of the grown sequence.
    seq_lens = self.seq_lens_tensor[:num_queries]
    next_seq_lens = seq_lens + 1
    next_input_pos = next_seq_lens - 1

    # Update seq_lens and input_positions.
    self.seq_lens_tensor[:num_queries] = next_seq_lens
    model_input.input_positions[:num_queries] = (  # type: ignore
        next_input_pos)

    # Compute the block index and the offset inside that block.
    block_idx = next_input_pos // block_size
    block_offset = next_input_pos % block_size

    # torch.gather requires an int64 index; the positions derive from
    # seq_lens_tensor, which may be int32.
    gather_index = block_idx.unsqueeze(-1).to(torch.int64)
    current_block_table = self.block_tables.gather(
        1, gather_index).squeeze(-1)
    slot_num = current_block_table * block_size + block_offset

    # Update slot_mapping.
    self.slot_mapping[:num_queries] = slot_num
|
||||||
|
|
||||||
|
|
||||||
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||||
|
|
||||||
@@ -430,6 +575,11 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
|||||||
device = self.runner.device
|
device = self.runner.device
|
||||||
|
|
||||||
max_query_len = max(query_lens)
|
max_query_len = max(query_lens)
|
||||||
|
decode_query_lens = query_lens[self.num_prefills:]
|
||||||
|
if len(decode_query_lens) > 0:
|
||||||
|
max_decode_query_len = max(decode_query_lens)
|
||||||
|
else:
|
||||||
|
max_decode_query_len = 1
|
||||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||||
|
|
||||||
@@ -440,6 +590,9 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
|||||||
self.input_builder.runner.device)
|
self.input_builder.runner.device)
|
||||||
else:
|
else:
|
||||||
self.attn_mask = None
|
self.attn_mask = None
|
||||||
|
num_decode_tokens = self.num_decode_tokens
|
||||||
|
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||||
|
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
||||||
|
|
||||||
block_tables = make_tensor_with_pad(
|
block_tables = make_tensor_with_pad(
|
||||||
self.block_tables,
|
self.block_tables,
|
||||||
@@ -450,9 +603,17 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
|||||||
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
||||||
|
|
||||||
assert device is not None
|
assert device is not None
|
||||||
|
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
||||||
|
device, self.runner.pin_memory)
|
||||||
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
|
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
|
||||||
device, self.runner.pin_memory)
|
device, self.runner.pin_memory)
|
||||||
|
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
||||||
|
self.runner.pin_memory)
|
||||||
|
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
||||||
|
device,
|
||||||
|
self.runner.pin_memory)
|
||||||
|
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
||||||
|
device, self.runner.pin_memory)
|
||||||
placeholder_index_maps = {
|
placeholder_index_maps = {
|
||||||
modality: placeholder_map.index_map()
|
modality: placeholder_map.index_map()
|
||||||
for modality, placeholder_map in
|
for modality, placeholder_map in
|
||||||
@@ -466,15 +627,19 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
|||||||
return AscendMetadata(
|
return AscendMetadata(
|
||||||
num_prefills=self.num_prefills,
|
num_prefills=self.num_prefills,
|
||||||
slot_mapping=slot_mapping_tensor,
|
slot_mapping=slot_mapping_tensor,
|
||||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
|
||||||
enable_kv_scales_calculation=False,
|
|
||||||
num_prefill_tokens=self.num_prefill_tokens,
|
num_prefill_tokens=self.num_prefill_tokens,
|
||||||
num_decode_tokens=self.num_decode_tokens,
|
num_decode_tokens=num_decode_tokens,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
|
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||||
|
enable_kv_scales_calculation=True,
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
|
max_decode_query_len=max_decode_query_len,
|
||||||
max_prefill_seq_len=max_prefill_seq_len,
|
max_prefill_seq_len=max_prefill_seq_len,
|
||||||
max_decode_seq_len=max_decode_seq_len,
|
max_decode_seq_len=max_decode_seq_len,
|
||||||
|
query_start_loc=query_start_loc_tensor,
|
||||||
|
seq_start_loc=seq_start_loc_tensor,
|
||||||
|
context_lens_tensor=context_lens_tensor,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
attn_mask=self.attn_mask,
|
attn_mask=self.attn_mask,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -105,7 +105,11 @@ class NPUPlatform(Platform):
|
|||||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
if parallel_config.worker_cls == "auto":
|
if parallel_config.worker_cls == "auto":
|
||||||
|
if vllm_config.scheduler_config.is_multi_step:
|
||||||
|
parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker"
|
||||||
|
else:
|
||||||
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
|
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
|
||||||
|
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
if cache_config and cache_config.block_size is None:
|
if cache_config and cache_config.block_size is None:
|
||||||
cache_config.block_size = 128
|
cache_config.block_size = 128
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import torch
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@@ -33,3 +33,23 @@ def try_register_lib(lib_name: str, lib_info: str = ""):
|
|||||||
logger.info(lib_info)
|
logger.info(lib_info)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_current_stream = None
|
||||||
|
|
||||||
|
|
||||||
|
def current_stream() -> torch.npu.Stream:
    """
    Return the cached current NPU stream, resolving it on first use.

    Intended as a cheap replacement for `torch.npu.current_stream()`, which
    is reported to be expensive because it constructs a new stream object
    on every call. The first call stores the result of
    `torch.npu.current_stream()` in the module-level `_current_stream`;
    later calls return that stored object.

    NOTE(review): the original note says `torch.npu.set_stream` is patched
    to keep `_current_stream` up to date. That patch is not visible in this
    part of the file - confirm it exists, otherwise the cached stream is
    never refreshed after a stream switch.
    """
    global _current_stream
    cached = _current_stream
    if cached is not None:
        return cached
    # Called before any stream was recorded: fall back to whatever torch
    # reports right now (the default stream) and remember it.
    _current_stream = torch.npu.current_stream()
    return _current_stream
|
||||||
|
|||||||
674
vllm_ascend/worker/multi_step_runner.py
Normal file
674
vllm_ascend/worker/multi_step_runner.py
Normal file
@@ -0,0 +1,674 @@
|
|||||||
|
import dataclasses
|
||||||
|
import functools
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from vllm.distributed import get_pp_group
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
|
||||||
|
SamplerOutput,
|
||||||
|
SamplingMetadata, get_logprobs,
|
||||||
|
get_pythonized_sample_results)
|
||||||
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||||
|
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
||||||
|
Logprob, SequenceGroupMetadata, SequenceOutput)
|
||||||
|
from vllm.worker.multi_step_model_runner import (ModelOutput,
|
||||||
|
PythonizationCache,
|
||||||
|
StatefulModelInput)
|
||||||
|
|
||||||
|
from vllm_ascend.utils import current_stream
|
||||||
|
from vllm_ascend.worker.model_runner import (
|
||||||
|
ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=False)
class NPUStatefulModelInput(StatefulModelInput):
    """`StatefulModelInput` variant that records its step events as NPU events."""

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the parent initializer.
        super().__init__(*args, **kwargs)

    def record_step_event(self, current_stream: torch.npu.Stream):
        """Record a fresh blocking event for the current step on `current_stream`.

        The event is stored in `step_cuda_events` at index
        `current_step & 1`, so the two entries are reused alternately as a
        circular buffer and the next step can synchronize on it. The
        two-slot layout also leaves room for attention backends that want
        to overlap CPU and NPU work with two wrappers.
        """
        slot = self.current_step & 1
        step_event = torch.npu.Event(blocking=True)
        self.step_cuda_events[slot] = step_event
        step_event.record(current_stream)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=False)
class NPUModelOutput(ModelOutput):
    """`ModelOutput` variant whose pythonization runs on an NPU copy stream."""

    # Device-side logprobs tensor; dropped once pythonization has completed.
    logprobs: Optional["torch.Tensor"] = None

    def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput",
                                  copy_stream: torch.npu.Stream,
                                  pinned_sampled_token_buffer: torch.Tensor,
                                  blocking: bool) -> bool:
        """
        Pythonize the sampler output if it is ready (or wait for it).

        With `blocking` set, waits on `sampler_output_ready_event` and then
        pythonizes. Without it, returns False immediately when the event
        has not completed yet, leaving `self.logprobs` untouched.

        Returns True once pythonization has run; `self.logprobs` is then
        reset to None.
        """
        assert self.sampled_token_ids is not None

        ready_event = self.sampler_output_ready_event
        if blocking:
            ready_event.synchronize()
        elif not ready_event.query():
            return False

        with torch.npu.stream(copy_stream):
            _pythonize_sampler_output(input_metadata, self.sampler_output,
                                      pinned_sampled_token_buffer,
                                      self.sampled_token_ids, self.logprobs,
                                      self.pythonization_cache)

        # The helper above runs on its own stream but does not return
        # until pythonization is complete, so by the time the CPU gets
        # here the device-side logprobs tensor is no longer needed.
        self.logprobs = None
        return True
|
||||||
|
|
||||||
|
|
||||||
|
class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
|
||||||
|
# mypy: enable-error-code=type-var
|
||||||
|
|
||||||
|
def __init__(self, base_model_runner: NPUModelRunnerBase, *args, **kwargs):
    """Wrap `base_model_runner` with multi-step bookkeeping.

    All remaining positional and keyword arguments are forwarded to the
    parent initializer unchanged.
    """
    super().__init__(*args, **kwargs)

    # The base runner performs the actual model execution; this class
    # layers the multi-step logic on top of it.
    self._base_model_runner: NPUModelRunnerBase = base_model_runner

    self.is_multi_step = self.scheduler_config.is_multi_step
    self.pinned_sampled_token_ids: Optional[torch.Tensor] = None

    # Using the PythonizationCache in Pipeline-Parallel clobbers the
    # SequenceOutput and CompletionSequenceGroupOutput objects: when the
    # cache is reset at the last step of a multi-step execution, other
    # single-step/multi-step executions may still be in flight, and the
    # current caching implementation does not check for that. The cache
    # is therefore only enabled without pipeline parallelism.
    if self.parallel_config.pipeline_parallel_size == 1:
        self.pythonization_cache = PythonizationCache()
    else:
        self.pythonization_cache = None
|
||||||
|
|
||||||
|
def get_model(self) -> nn.Module:
    """Return the `model` attribute of this runner."""
    model = self.model
    return model
|
||||||
|
|
||||||
|
@functools.cached_property
def _copy_stream(self):
    """NPU stream used to copy tensors from NPU to CPU asynchronously.

    As a cached property, the stream is created on first access and the
    same object is returned afterwards.
    """
    copy_stream = torch.npu.Stream()
    return copy_stream
|
||||||
|
|
||||||
|
def make_model_input_from_broadcasted_tensor_dict(
        self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
    """Rebuild an `NPUStatefulModelInput` from a broadcasted tensor dict.

    The runner's attention backend is passed along to
    `from_broadcasted_tensor_dict` as `attn_backend`.
    """
    return NPUStatefulModelInput.from_broadcasted_tensor_dict(
        tensor_dict,
        attn_backend=self.attn_backend,
    )
|
||||||
|
|
||||||
|
def prepare_model_input(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    virtual_engine: int = 0,
    finished_requests_ids: Optional[List[str]] = None
) -> StatefulModelInput:
    """Build the stateful multi-step model input for a scheduled batch.

    The single-step input is prepared by the wrapped base model runner and
    then wrapped in an `NPUStatefulModelInput`, together with the counts
    derived from it.

    Args:
        seq_group_metadata_list: Sequence groups scheduled for this batch.
        virtual_engine: Forwarded unchanged to the base model runner.
        finished_requests_ids: Forwarded unchanged to the base model runner.

    Returns:
        The `NPUStatefulModelInput` wrapping the base runner's model input.
    """
    frozen_model_input: ModelInputForNPUWithSamplingMetadata = \
        self._base_model_runner.prepare_model_input(
            seq_group_metadata_list,
            virtual_engine,
            finished_requests_ids)

    assert frozen_model_input.query_lens is not None
    assert frozen_model_input.seq_lens is not None
    assert frozen_model_input.attn_metadata is not None
    num_queries = len(frozen_model_input.query_lens)
    num_seqs = len(frozen_model_input.seq_lens)
    num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills

    # Build two distinct events. `[event] * 2` would put the *same* Event
    # object in both slots of the two-entry circular buffer, so recording
    # or waiting on one slot would silently affect the other.
    step_events = [torch.npu.Event(blocking=True) for _ in range(2)]

    model_input = NPUStatefulModelInput(
        frozen_model_input=frozen_model_input,
        num_seqs=num_seqs,
        num_queries=num_queries,
        num_single_step_prefills=num_single_step_prefills,
        step_cuda_events=step_events,
    )

    return model_input
|
||||||
|
|
||||||
|
def _async_process_outputs(self, model_input: StatefulModelInput,
                           output_proc_callback: Callable):
    """Pythonize cached outputs in order and feed them to the output processor.

    The callback is invoked once up front. Each cached output that is not
    yet pythonized then gets a `maybe_pythonize` attempt; on success it is
    appended to the callback's `ctx` and the callback is invoked again.
    Processing stops at the first output that fails to pythonize. Outputs
    that were already pythonized are skipped.
    """
    output_proc_callback()

    for step_num, model_output in enumerate(model_input.cached_outputs):
        if model_output.pythonized:
            # Already handled earlier; move on to the next step.
            continue

        model_output.maybe_pythonize(model_input, self._copy_stream,
                                     self.pinned_sampled_token_ids)
        if not model_output.pythonized:
            # Not ready yet: later steps must wait to keep the order.
            break

        ctx = output_proc_callback.keywords["ctx"]  # type: ignore
        ctx.append_output(
            outputs=[model_output.sampler_output],
            seq_group_metadata_list=ctx.seq_group_metadata_list,
            scheduler_outputs=ctx.scheduler_outputs,
            is_async=False,
            is_last_step=False,
            is_first_step_output=step_num == 0)

        output_proc_callback()
|
||||||
|
|
||||||
|
def _final_process_outputs(
        self, model_input: StatefulModelInput,
        output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
    """Pythonize every cached output and collect those to return.

    Without a callback, each cached output is pythonized and its sampler
    output is returned. With a callback, the callback is invoked before
    each output; an output that still needs pythonizing is pythonized and
    then either appended to the callback's `ctx` (any step but the last)
    or added to the returned list (last step).
    """
    assert model_input.frozen_model_input is not None

    collected = []
    for step_num, output in enumerate(model_input.cached_outputs):
        is_last_step = step_num == len(model_input.cached_outputs) - 1

        if output_proc_callback is None:
            # Non-async case: simply pythonize and add the output.
            output.pythonize(model_input, self._copy_stream,
                             self.pinned_sampled_token_ids)
            collected.append(output.sampler_output)
            continue

        # Async case: invoke the callback before pythonizing so that it
        # overlaps with NPU work.
        output_proc_callback()

        if output.pythonized:
            continue

        output.pythonize(model_input, self._copy_stream,
                         self.pinned_sampled_token_ids)

        if is_last_step:
            collected.append(output.sampler_output)
            continue

        # Not the last step: queue the output on the callback context so
        # that callback => pythonize pairs chain (for NPU overlap).
        ctx = output_proc_callback.keywords[  # type: ignore
            "ctx"]  # type: ignore
        ctx.append_output(
            outputs=[output.sampler_output],
            seq_group_metadata_list=ctx.seq_group_metadata_list,
            scheduler_outputs=ctx.scheduler_outputs,
            is_async=False,
            is_last_step=False,
            is_first_step_output=step_num == 0)

    return collected
|
||||||
|
|
||||||
|
@torch.inference_mode()
def execute_model(
    self,
    model_input: StatefulModelInput,
    kv_caches: List[torch.Tensor],
    intermediate_tensors: Optional[IntermediateTensors] = None,
    num_steps: int = 1,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
    """
    Execute the model for a single step and update multi-step
    metadata

    Args:
        model_input: stateful multi-step input; its frozen model input,
            cached outputs and step counter are updated in place.
        kv_caches: forwarded unchanged to the base model runner.
        intermediate_tensors: forwarded unchanged to the base model runner.
        num_steps: must be 1 (asserted).

    Returns:
        - the base runner's result when ``model_input.is_multi_step`` is
          false (warm-up path);
        - the ``IntermediateTensors`` output on a non-last PP rank;
        - ``[]`` on a last-rank worker that is not the driver;
        - the result of ``_final_process_outputs`` on the last step;
        - otherwise the base runner's output list for this step.
    """
    assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
    frozen_model_input = model_input.frozen_model_input
    assert frozen_model_input is not None

    # path for warm up runs
    if not model_input.is_multi_step:
        return self._base_model_runner.execute_model(
            frozen_model_input, kv_caches, intermediate_tensors, num_steps)

    # make sure we skip the sampler on the last rank and only pythonize
    # if CPU is ahead.
    if self.is_driver_worker and get_pp_group().is_last_rank:
        # Lazily allocate the pinned CPU buffer that receives sampled ids.
        if self.pinned_sampled_token_ids is None:
            self.pinned_sampled_token_ids = torch.zeros(
                (self.scheduler_config.max_num_seqs, 1),
                dtype=torch.long,
                device="cpu",
                pin_memory=True)

        self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
            True)
        if frozen_model_input.sampling_metadata:
            frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
                True)

    # some pre-execute model logic for multi-step:
    #   - if it's the first step, we need to reset the sampling tensors
    #   - if it's not the first step, we need to advance the step using the
    #   appended sampler output from last iteration
    #   - also maybe pythonize if CPU is ahead of NPU

    stream = current_stream()
    if not model_input.is_first_multi_step:
        # Explicitly block on the previous step's forward to make sure we
        # don't clobber any NPU tensors still in use.
        # This is not needed for flashattn backend, but for other attn
        # backends such as flashinfer that performs extra CPU operations on
        # input metadata we may need to synchronize any CPU operations that
        # might clobber enqueued forwards. (prevents CPU from running too
        # far ahead if needed)
        model_input.wait_previous_step()
        model_input = self._advance_step(
            model_input, model_input.cached_outputs[-1].sampler_output)

        # frozen_model_input may have been updated
        frozen_model_input = model_input.frozen_model_input
        assert frozen_model_input is not None

    # Remember the callback that came with the frozen input the first time
    # through; later steps wrap it rather than replace it.
    if model_input.base_output_proc_callback is None:
        assert frozen_model_input is not None
        model_input.base_output_proc_callback = \
            frozen_model_input.async_callback

    if frozen_model_input.async_callback is not None:
        assert model_input.base_output_proc_callback is not None
        # Wrap the base callback so it runs through _async_process_outputs
        # with the current model_input bound.
        async_callback = functools.partial(
            self._async_process_outputs,
            model_input=model_input,
            output_proc_callback=model_input.base_output_proc_callback)

        model_input.frozen_model_input = dataclasses.replace(  # type: ignore
            model_input.frozen_model_input,
            async_callback=async_callback)
        # Update the local instance
        frozen_model_input = model_input.frozen_model_input
        assert frozen_model_input is not None

    # Execute the model
    output = self._base_model_runner.execute_model(frozen_model_input,
                                                   kv_caches,
                                                   intermediate_tensors,
                                                   num_steps=1)

    # record the event for the current step so that the next step can sync
    model_input.record_step_event(stream)

    if get_pp_group().is_last_rank and self.is_driver_worker:
        assert isinstance(output, list)
        assert len(
            output
        ) == 1, "MultiStepModelRunner requires single-step base_models"

        # event for the pythonization so that we only pythonize if the
        # tensors are ready. May be able to be combined with the step event
        output_ready_event = torch.npu.Event()
        output_ready_event.record(stream)
        # With pipeline parallelism, also keep a CPU copy of the sampled ids
        # on the sampler output.
        if self.parallel_config.pipeline_parallel_size > 1:
            output[0].sampled_token_ids_cpu = output[
                0].sampled_token_ids.cpu()
        model_input.cached_outputs.append(
            NPUModelOutput(output[0], output_ready_event,
                           output[0].sampled_token_ids, False,
                           output[0].logprobs, self.pythonization_cache))

        # These NPU tensors are not required by multi-step;
        # erase them to ensure they are not pythonized or
        # transferred to CPU
        output[0].sampled_token_ids = None
        output[0].sampled_token_probs = None
        output[0].logprobs = None

        # Pythonize the output if CPU is ahead and the previous step is
        # ready.
        if frozen_model_input.async_callback is None:
            for model_output in model_input.cached_outputs:
                model_output.maybe_pythonize(model_input,
                                             self._copy_stream,
                                             self.pinned_sampled_token_ids)

    model_input.current_step += 1

    if not get_pp_group().is_last_rank:
        # Should be IntermediateTensors
        assert isinstance(output, IntermediateTensors)
        return output
    if not self.is_driver_worker:
        return []

    # Pythonize the output and block if needed since it is the last step
    if model_input.is_last_step:
        outputs = self._final_process_outputs(
            model_input, model_input.base_output_proc_callback)
        if self.pythonization_cache:
            self.pythonization_cache.reset()
        return outputs

    # should be [SamplerOutput]
    return output
|
||||||
|
|
||||||
|
def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
                              num_seqs: Optional[int], num_queries: int):
    """Assert that the sampling metadata describes a pure-decode batch.

    Checks that there are no prompts, that there is exactly one sequence
    group and one selected token index per query, and that every sequence
    group looks like a decode entry. Nothing is modified or returned.
    """
    assert sampling_metadata.num_prompts == 0
    assert len(sampling_metadata.seq_groups) == num_queries
    expected_shape = (num_queries, )
    assert sampling_metadata.selected_token_indices.shape == expected_shape
    # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

    # Every group must be a decode: no prompt data, one sample index each.
    for query_idx in range(num_queries):
        group = sampling_metadata.seq_groups[query_idx]

        assert group.is_prompt is False  # No prompt
        assert group.prompt_logprob_indices == []  # No prompt
        assert group.sample_indices == [query_idx]  # Simple
        assert group.seq_len is None  # Decode
        assert group.query_len is None  # Decode
|
||||||
|
|
||||||
|
def _advance_step(self, model_input: StatefulModelInput,
                  out: SamplerOutput) -> StatefulModelInput:
    """Advance the attention metadata in place for the next decode step.

    The sampled token ids are taken from ``model_input.cached_outputs[-1]``;
    the ``out`` argument is not read here.

    Returns:
        The same ``model_input`` object that was passed in.
    """
    model_input.maybe_advance_frozen_model_input(self.device,
                                                 self.pin_memory)

    frozen = model_input.frozen_model_input
    assert frozen is not None
    assert frozen.input_tokens is not None
    assert frozen.input_tokens.shape[0] == model_input.num_seqs
    assert frozen.attn_metadata is not None

    last_sampled_ids = model_input.cached_outputs[-1].sampled_token_ids
    seq_count = model_input.num_seqs
    query_count = model_input.num_queries
    metadata = frozen.attn_metadata

    # Only on step 1, and only if single-step prefills were scheduled.
    convert_prefills: bool = (model_input.current_step == 1
                              and model_input.num_single_step_prefills != 0)
    metadata.advance_step(frozen,
                          last_sampled_ids,
                          self.block_size,
                          seq_count,
                          query_count,
                          turn_prefills_into_decodes=convert_prefills)

    return model_input
|
||||||
|
|
||||||
|
def load_model(self) -> None:
    """Load the model via the base runner and mirror its memory usage."""
    base_runner = self._base_model_runner
    base_runner.load_model()
    self.model_memory_usage = base_runner.model_memory_usage
|
||||||
|
|
||||||
|
def save_sharded_state(
    self,
    path: str,
    pattern: Optional[str] = None,
    max_size: Optional[int] = None,
) -> None:
    """Delegate saving of sharded state to the base model runner."""
    base_runner = self._base_model_runner
    return base_runner.save_sharded_state(path, pattern, max_size)
|
||||||
|
|
||||||
|
def save_tensorized_model(self,
                          tensorizer_config: TensorizerConfig) -> None:
    """Delegate saving of the tensorized model to the base model runner."""
    base_runner = self._base_model_runner
    return base_runner.save_tensorized_model(tensorizer_config)
|
||||||
|
|
||||||
|
def profile_run(self) -> None:
    """Delegate the profiling run to the base model runner."""
    base_runner = self._base_model_runner
    return base_runner.profile_run()
|
||||||
|
|
||||||
|
def remove_all_loras(self):
    """Delegate LoRA removal to the base model runner."""
    base_runner = self._base_model_runner
    return base_runner.remove_all_loras()
|
||||||
|
|
||||||
|
def capture_model(self, kv_caches: List[List]) -> None:
    """Delegate model capture to the base model runner."""
    base_runner = self._base_model_runner
    return base_runner.capture_model(kv_caches)
|
||||||
|
|
||||||
|
@property
def vocab_size(self) -> int:
    """Vocabulary size as reported by the base model runner."""
    base_runner = self._base_model_runner
    return base_runner.vocab_size
|
||||||
|
|
||||||
|
|
||||||
|
# Return type of `deferred_pythonize_logprobs`: a
# (prompt_logprobs, sample_logprobs) pair in which either element may be None.
DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
                                   Optional[List[SampleLogprobs]]]
|
||||||
|
|
||||||
|
|
||||||
|
def deferred_pythonize_logprobs(
    output: SamplerOutput,
    sampling_metadata: SamplingMetadata,
    logprobs_tensor: Optional[torch.Tensor],
) -> DeferredLogprobsReturnType:
    """Run the logprob Pythonization that was deferred during sampling.

    First the deferred sample-result arguments stored on ``output`` are
    turned into a Pythonized sampler result, then that result is used to
    convert ``logprobs_tensor`` into CPU-side logprob lists.

    These deferred computations are not required for single-step scheduling
    or the `profile_run()` phase of multi-step scheduling.

    Args:
        output: sampler output (under deferred Pythonization); its
            ``deferred_sample_results_args`` is cleared by this call.
        sampling_metadata: metadata whose ``seq_groups`` define how many
            entries each returned list must have.
        logprobs_tensor: logprobs computed during sampling.

    Returns:
        prompt_logprobs (CPU), sample_logprobs (CPU)
    """
    # Pythonize the sample result from the args stashed on the output.
    pythonized_samples = get_pythonized_sample_results(
        output.deferred_sample_results_args)

    # Drop the NPU-side deferred args so they are never pythonized again
    # or transferred to CPU.
    output.deferred_sample_results_args = None

    # Pythonize the logprobs using the sample result obtained above.
    logprob_pair = get_logprobs(logprobs_tensor, sampling_metadata,
                                pythonized_samples)
    prompt_logprobs, sample_logprobs = logprob_pair

    assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
    assert len(sample_logprobs) == len(sampling_metadata.seq_groups)

    return prompt_logprobs, sample_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
def _pythonize_sampler_output(
    model_input: StatefulModelInput,
    output: SamplerOutput,
    pinned_sampled_token_buffer: torch.Tensor,
    sampled_token_ids: torch.Tensor,
    logprobs_tensor: Optional[torch.Tensor],
    cache: Optional[PythonizationCache],
) -> None:
    """ This function is only called when the output tensors are ready.
    See :class:`ModelOutput`.

    Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
    adding a Pythonized output data structure
    (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.

    Args:
      model_input
      output: sampler output
      pinned_sampled_token_buffer: CPU-side pinned memory
                                   (receives copy of
                                   NPU-side token buffer.)
      sampled_token_ids: NPU-side token buffer
      logprobs_tensor: NPU-side tensor containing
                       logprobs computed during sampling
      cache: optional pool of reusable output objects; when given,
             `CompletionSequenceGroupOutput` / `SequenceOutput` instances
             are fetched from it instead of being constructed.
    """

    assert model_input.frozen_model_input is not None

    frozen_model_input = model_input.frozen_model_input
    assert frozen_model_input.sampling_metadata is not None
    sampling_metadata = frozen_model_input.sampling_metadata
    # samples generation should have been skipped
    assert not output.outputs

    # Only the first `num_queries` rows of the pinned buffer are used.
    pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]

    # We guarantee output tensors are ready, so it is safe to
    # pythonize the sampler output & obtain CPU-side logprobs.
    #
    # However we should check whether logprobs pythonization may
    # be skipped entirely, i.e. because no logprobs were requested
    # or pythonization was not deferred. To that end,
    #
    # * `prompt_logprobs_are_requested_for_prefill` signals that
    #   there are *any* prefill-phase requests which specify that
    #   prompt logprobs should be returned.
    #
    # * `any_logprobs_are_requested` signals that there are any
    #   requests which (1) specify that sample logprobs should be
    #   returned, or (2) are in the prefill phase AND specify that
    #   prompt logprobs should be returned.
    #
    # Later on, these flags cause adjustments to the pythonization
    # process to accommodate logprobs.

    seq_groups = sampling_metadata.seq_groups
    prompt_logprobs_are_requested_for_prefill = any([
        sg.sampling_params.prompt_logprobs is not None and sg.is_prompt
        for sg in seq_groups
    ])
    any_logprobs_are_requested = (
        prompt_logprobs_are_requested_for_prefill
        or any([sg.sampling_params.logprobs is not None for sg in seq_groups]))

    if prompt_logprobs_are_requested_for_prefill:
        # CPU NPU sync, after gathering *only* sampled tokens (since
        # requesting prompt logprobs leads `sampled_token_ids` to
        # include prompt token ids in addition to sampled token ids.)
        sample_idx_tensor = torch.tensor(
            [sdx for sg in seq_groups for sdx in sg.sample_indices])
        pinned_buffer = pinned_buffer.copy_(
            sampled_token_ids[sample_idx_tensor, :], non_blocking=False)
    else:
        # CPU NPU sync
        pinned_buffer = pinned_buffer.copy_(sampled_token_ids,
                                            non_blocking=False)

    # this will not block as the tensors are already on CPU
    samples_list = pinned_buffer.tolist()

    skip_sampler_cpu_output = (
        frozen_model_input.sampling_metadata.skip_sampler_cpu_output)

    # *Don't* skip logprobs pythonization *if*:
    # * Any requests require logprobs to be returned in this
    # iteration AND
    # * These requests are being scheduled in a fashion which
    # defers pythonization (i.e. multi-step scheduling.)
    do_pythonize_logprobs = (skip_sampler_cpu_output
                             and any_logprobs_are_requested)
    (
        prompt_logprobs,
        sample_logprobs,
    ) = (deferred_pythonize_logprobs(output, sampling_metadata,
                                     logprobs_tensor)
         if do_pythonize_logprobs else (None, None))

    # One iteration per sequence group, paired with its row of sampled ids.
    for sgdx, (seq_group,
               sample_result) in enumerate(zip(seq_groups, samples_list)):
        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # If the feature combo become valid
        # (Check for Guided Decoding)
        if seq_group.sampling_params.logits_processors:
            assert len(seq_group.sampling_params.logits_processors) == 0, (
                "Logits Processors are not supported in multi-step decoding")

        if do_pythonize_logprobs:
            assert prompt_logprobs is not None
            assert sample_logprobs is not None

            (
                group_prompt_logprobs,
                group_sample_logprobs,
            ) = (  # Utilize deferred pythonization results
                prompt_logprobs[sgdx],
                sample_logprobs[sgdx],
            )
        elif any_logprobs_are_requested:
            # NOTE(review): `output.outputs` was asserted empty above and
            # only grows at the end of each iteration, so indexing it with
            # `sgdx` here looks unreachable in practice -- confirm.
            (
                group_prompt_logprobs,
                group_sample_logprobs,
            ) = (
                # profile_run: use already-computed logprobs
                output.outputs[sgdx].prompt_logprobs,
                [sample.logprobs for sample in output.outputs[sgdx].samples])

        seq_ids = seq_group.seq_ids
        next_token_ids = sample_result
        # A single parent index, so the zip below yields at most one sample
        # per sequence group.
        parent_ids = [0]
        seq_outputs: List[SequenceOutput]

        if cache is not None:
            # Reuse a pooled group output and its (cleared) samples list.
            completion_seq_group_output: CompletionSequenceGroupOutput = \
                cache.cached_completion_seq_group_output.get_object()
            completion_seq_group_output.samples.clear()
            seq_outputs = completion_seq_group_output.samples
        else:
            seq_outputs = []

        for tdx, (parent_id,
                  next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
            if cache is not None:
                seq_output: SequenceOutput = cache.cached_seq_output.get_object(
                )
                seq_output.parent_seq_id = seq_ids[parent_id]
                seq_output.output_token = next_token_id

                if any_logprobs_are_requested:
                    seq_output.logprobs = group_sample_logprobs[tdx]
                else:
                    # No logprobs requested: recycle the pooled object's
                    # existing Logprob entry as a placeholder keyed by the
                    # new token id.
                    logprobs = next(iter(seq_output.logprobs.values()))
                    seq_output.logprobs.clear()

                    logprobs.logprob = float('inf')
                    logprobs.rank = None
                    logprobs.decoded_token = None

                    seq_output.logprobs[next_token_id] = logprobs

                seq_outputs.append(seq_output)

            else:
                # No cache: build a fresh SequenceOutput, using a placeholder
                # Logprob when no logprobs were requested.
                seq_outputs.append(
                    SequenceOutput(seq_ids[parent_id], next_token_id,
                                   (group_sample_logprobs[tdx]
                                    if any_logprobs_are_requested else {
                                        next_token_id:
                                        Logprob(logprob=float('inf'),
                                                rank=None,
                                                decoded_token=None)
                                    })))
        if cache is not None:
            completion_seq_group_output.prompt_logprobs = \
                group_prompt_logprobs if any_logprobs_are_requested else None
            output.outputs.append(completion_seq_group_output)
        else:
            output.outputs.append(
                CompletionSequenceGroupOutput(
                    seq_outputs, (group_prompt_logprobs
                                  if any_logprobs_are_requested else None)))

    assert len(output.outputs) > 0
|
||||||
194
vllm_ascend/worker/multi_step_worker.py
Normal file
194
vllm_ascend/worker/multi_step_worker.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
import dataclasses
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.distributed import broadcast_tensor_dict, get_pp_group
|
||||||
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
|
from vllm.sequence import ExecuteModelRequest
|
||||||
|
from vllm.worker.model_runner_base import BroadcastableModelInput
|
||||||
|
from vllm.worker.multi_step_model_runner import StatefulModelInput
|
||||||
|
|
||||||
|
from vllm_ascend.worker.multi_step_runner import MultiStepModelNPURunner
|
||||||
|
from vllm_ascend.worker.worker import NPUWorker, WorkerInput
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MultiStepState:
    """Worker and model inputs cached between the steps of one multi-step run.

    `MultiStepWorker.prepare_input` stores one instance per virtual engine
    on the first step; the driver reads it back on subsequent steps.
    """
    # Worker input prepared on the first step.
    worker_input: WorkerInput
    # Stateful model input prepared on the first step.
    model_input: StatefulModelInput
|
||||||
|
|
||||||
|
|
||||||
|
class MultiStepWorker(NPUWorker):
    """NPU worker that runs several decode steps per scheduler invocation.

    The regular model runner created by ``NPUWorker`` is wrapped in a
    ``MultiStepModelNPURunner``. On the driver, the worker/model inputs
    prepared on the first step are cached per virtual engine and reused
    (instead of re-prepared) on the following steps.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        base_model_runner = self.model_runner
        # for multi-step model, wrap the model runner with MultiStepModelRunner
        self.model_runner = MultiStepModelNPURunner(
            base_model_runner,
            vllm_config=base_model_runner.vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=base_model_runner.is_driver_worker,
        )

        # One cached state slot per virtual engine (pipeline stage).
        pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
        self.multi_step_states: List[
            Optional[MultiStepState]] = [None] * pipeline_parallel_size
        self.temp_output = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
        """
        Get the driver input and broadcast it to other workers.

        On the first multi-step the inputs are prepared normally; on later
        steps the inputs cached for the request's virtual engine are reused.

        Returns:
            ``(model_input, worker_input, {})``.
        """
        assert self.is_driver_worker
        virtual_engine = execute_model_req.virtual_engine
        is_first_multi_step = execute_model_req.is_first_multi_step
        if is_first_multi_step:
            # on first step we prepare the worker input and model input normally
            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req)
            model_input: StatefulModelInput = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids))

            if execute_model_req.async_callback:
                model_input.frozen_model_input = dataclasses.replace(  # type: ignore
                    model_input.frozen_model_input,
                    async_callback=execute_model_req.async_callback)
        else:
            # on subsequent steps we reuse the worker input and model input
            multi_step_state = self.multi_step_states[virtual_engine]
            # The slot is only filled by prepare_input() on the first step;
            # fail with a clear message instead of an AttributeError on None.
            assert multi_step_state is not None, (
                "no cached multi-step state for this virtual engine; "
                "the first multi-step must run before subsequent steps")
            worker_input = multi_step_state.worker_input
            model_input = multi_step_state.model_input
            frozen_model_input = model_input.frozen_model_input
            assert frozen_model_input is not None
            assert frozen_model_input.attn_metadata is not None
            # clear the cached metadata so that it can be recomputed on
            # the workers.
            frozen_model_input.attn_metadata._cached_prefill_metadata = None
            frozen_model_input.attn_metadata._cached_decode_metadata = None

        model_input.is_first_multi_step = is_first_multi_step
        model_input.is_last_step = execute_model_req.is_last_step

        if not is_first_multi_step:
            # we broadcast the last sampled token ids to all TP workers so they
            # can update their model input metadata in-place.
            self._prepare_last_sampled_token_ids_for_tp_workers(
                execute_model_req=execute_model_req, model_input=model_input)

        if self.do_metadata_broadcast:
            broadcast_data = worker_input.as_broadcastable_tensor_dict()
            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_tensor_dict(broadcast_data, src=0)

        # Returning empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
        return model_input, worker_input, {}

    def _prepare_last_sampled_token_ids_for_tp_workers(
        self,
        execute_model_req: ExecuteModelRequest,
        model_input: StatefulModelInput,
    ) -> None:
        """
        Prepare the last sampled token ids for TP workers. If it's the last
        PP rank, then the last sampled token ids are already in the model_input.
        If it is NOT the last PP rank, then we need to get the last sampled
        token that is cached in the execute_model_req.
        """
        if get_pp_group().is_last_rank:
            assert model_input.cached_outputs[
                -1].sampler_output.sampled_token_ids is None
            assert model_input.cached_outputs[-1].sampled_token_ids is not None
            model_input.last_sampled_token_ids = model_input.cached_outputs[
                -1].sampled_token_ids
            # free sampled token ids from the previous step if it has been
            # pythonized. Cannot free the last sampled token ids because
            # we need it for NPU advance_step.
            for output in model_input.cached_outputs[:-1]:
                if output.pythonized:
                    output.sampled_token_ids = None
        else:
            # otherwise we need to get the cached sampled token ids from the
            # execute_model_req
            assert execute_model_req.last_sampled_token_ids is not None
            # Move the ids onto the NPU. `.cuda()` is not the right call on
            # an Ascend worker; the rest of this backend uses `torch.npu`.
            model_input.last_sampled_token_ids = (
                execute_model_req.last_sampled_token_ids.npu())
            model_input.add_sampler_output(
                SamplerOutput(outputs=[], sampled_token_ids=None),
                model_input.last_sampled_token_ids)

            # free sampled token ids from the previous step.
            # TODO(will) we could reuse the sampled token ids tensor from
            # the previous step instead.
            for output in model_input.cached_outputs[:-1]:
                output.sampled_token_ids = None
            assert model_input.cached_outputs[-1].sampled_token_ids is not None

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
                                                              torch.Tensor]]]:
        """
        Depending on the current state of the request and multi step worker,
        this method may skip the normal _prepare_model_input and
        _prepare_worker_input methods and instead used cached values.

        Returns:
            ``(model_input, worker_input, kwargs)``, or ``None`` when the
            driver has no request (after notifying the other workers) or
            when a non-driver worker receives an empty broadcast.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there's no more requests to process for
                    # now. All workers are running infinite loop with
                    # broadcast_tensor_dict, and it stops the loop when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None

            virtual_engine = execute_model_req.virtual_engine
            (model_input, worker_input,
             kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
            assert isinstance(model_input, StatefulModelInput)
            if execute_model_req.is_first_multi_step:
                # cache the worker input and model input for the next steps
                self.multi_step_states[virtual_engine] = MultiStepState(
                    worker_input=worker_input, model_input=model_input)
        # if TP workers
        else:
            broadcast_data = self._get_worker_input_from_broadcast()
            # if the driver has sent an empty input, we should stop the worker
            # loop
            if broadcast_data is None:
                return None
            model_input, worker_input, kwargs = broadcast_data
            assert isinstance(model_input, StatefulModelInput)
            virtual_engine = worker_input.virtual_engine
            if model_input.is_first_multi_step:
                pass
                # TODO(will) Can cache the worker input and model input for the
                # next steps. See below for details
            else:
                # TODO(will) possible to also cache and reuse the cached worker
                # input and model input. The idea is essentially the delta
                # optimization for model_inputs. Where the TP workers can cache
                # the model input states and we only broadcast the delta need
                # for the next step (sampled_token_ids from the previous step)

                assert isinstance(model_input, StatefulModelInput)
                # we need to update the last sampled token ids in the model
                # input for the workers so that they can run inplace
                # advance_step
                model_input.add_sampler_output(
                    SamplerOutput(outputs=[], sampled_token_ids=None),
                    model_input.last_sampled_token_ids)

        assert model_input is not None
        assert worker_input is not None
        return model_input, worker_input, kwargs
|
||||||
Reference in New Issue
Block a user