support multistep decode (#299)
Add multi step scheduler support for vllm-ascend Signed-off-by: new-TonyWang <wangtonyyu222@gmail.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
#
|
||||
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
@@ -38,7 +39,8 @@ from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm_ascend.worker.model_runner import ModelInputForNPUBuilder
|
||||
from vllm_ascend.worker.model_runner import (
|
||||
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
||||
|
||||
|
||||
def generate_attn_mask(max_seq_len: int, dtype=torch.float16):
|
||||
@@ -197,26 +199,52 @@ class AscendMetadata(AttentionMetadata):
|
||||
|
||||
# FIXME: It is for flash attn.
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# Avoid mypy error
|
||||
# Total number of prefill requests.
|
||||
num_prefills: int
|
||||
# Number of prefill tokens.
|
||||
num_prefill_tokens: int
|
||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
||||
# in block 0, and 1st slot in block 1, respectively.
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]] = None
|
||||
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
# Max number of query tokens among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# Self-attention prefill/decode metadata cache
|
||||
_cached_prefill_metadata: Optional["AscendMetadata"] = None
|
||||
_cached_decode_metadata: Optional["AscendMetadata"] = None
|
||||
@@ -254,10 +282,18 @@ class AscendMetadata(AttentionMetadata):
|
||||
or (self.encoder_seq_lens is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None.
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
seq_start_loc = (None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
@@ -274,7 +310,11 @@ class AscendMetadata(AttentionMetadata):
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_query_len=0,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
@@ -302,6 +342,8 @@ class AscendMetadata(AttentionMetadata):
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[self.num_prefills:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
@@ -314,8 +356,19 @@ class AscendMetadata(AttentionMetadata):
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_decode_query_len=self.max_decode_query_len,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
# Batch may be composed of prefill|decodes, adjust query start
|
||||
# indices to refer to the start of decodes. E.g.
|
||||
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
||||
self.query_start_loc[self.num_prefills])
|
||||
if self.query_start_loc is not None else None,
|
||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||
if self.seq_start_loc is not None else None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
@@ -328,6 +381,98 @@ class AscendMetadata(AttentionMetadata):
|
||||
enable_kv_scales_calculation=False)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
def advance_step(self,
|
||||
model_input: "ModelInputForNPUWithSamplingMetadata",
|
||||
sampled_token_ids: Optional[torch.Tensor],
|
||||
block_size: int,
|
||||
num_seqs: int,
|
||||
num_queries: int,
|
||||
turn_prefills_into_decodes: bool = False):
|
||||
"""
|
||||
Update metadata in-place to advance one decode step.
|
||||
"""
|
||||
# When using cudagraph, the num_seqs is padded to the next captured
|
||||
# batch sized, but num_queries tracks the actual number of requests in
|
||||
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
||||
if num_seqs != num_queries:
|
||||
assert num_seqs > num_queries
|
||||
|
||||
if turn_prefills_into_decodes:
|
||||
# When Mutli-Step is enabled with Chunked-Prefill, prefills and
|
||||
# decodes are scheduled together. In the first step, all the
|
||||
# prefills turn into decodes. This update reflects that
|
||||
# conversion.
|
||||
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
||||
self.num_decode_tokens += self.num_prefills
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.max_prefill_seq_len = 0
|
||||
self.max_query_len = 1
|
||||
|
||||
self.slot_mapping = self.slot_mapping[:num_seqs]
|
||||
else:
|
||||
assert self.seq_lens is not None
|
||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
||||
|
||||
assert self.num_prefills == 0
|
||||
assert self.num_prefill_tokens == 0
|
||||
assert self.num_decode_tokens == num_seqs
|
||||
assert self.slot_mapping.shape == (num_seqs, )
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert len(self.seq_lens) == num_seqs
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.seq_lens_tensor.shape == (num_seqs, )
|
||||
assert self.max_query_len == 1
|
||||
assert self.max_prefill_seq_len == 0
|
||||
|
||||
assert self.query_start_loc is not None
|
||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
||||
assert self.seq_start_loc is not None
|
||||
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
||||
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.context_lens_tensor.shape == (num_queries, )
|
||||
|
||||
assert self.block_tables is not None
|
||||
assert self.block_tables.shape[0] == num_seqs
|
||||
|
||||
# Update query lengths. Note that we update only queries and not seqs,
|
||||
# since tensors may be padded due to captured cuda graph batch size
|
||||
for i in range(num_queries):
|
||||
self.seq_lens[i] += 1
|
||||
self.max_decode_seq_len = max(self.seq_lens)
|
||||
|
||||
# TODO optimize these codes using ascendc just like flash attention backend using cuda
|
||||
|
||||
# update input_tokens
|
||||
sampled_token_ids_list = sampled_token_ids[:
|
||||
num_queries].squeeze( # type: ignore
|
||||
-1)
|
||||
model_input.input_tokens[:
|
||||
num_queries] = sampled_token_ids_list # type: ignore
|
||||
|
||||
# get seq_lens and input_positions
|
||||
seq_lens = self.seq_lens_tensor[:num_queries]
|
||||
next_seq_lens = seq_lens + 1
|
||||
next_input_pos = next_seq_lens - 1
|
||||
|
||||
# update seq_lens and input_positions
|
||||
self.seq_lens_tensor[:num_queries] = next_seq_lens
|
||||
model_input.input_positions[:
|
||||
num_queries] = next_input_pos # type: ignore
|
||||
|
||||
# 计算 block index 和 offset
|
||||
block_idx = next_input_pos // block_size
|
||||
block_offset = next_input_pos % block_size
|
||||
|
||||
current_block_table = self.block_tables.gather(
|
||||
1, block_idx.unsqueeze(-1)).squeeze(-1)
|
||||
slot_num = current_block_table * block_size + block_offset
|
||||
|
||||
# update slot_mapping
|
||||
self.slot_mapping[:num_queries] = slot_num
|
||||
|
||||
|
||||
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
|
||||
@@ -430,6 +575,11 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
device = self.runner.device
|
||||
|
||||
max_query_len = max(query_lens)
|
||||
decode_query_lens = query_lens[self.num_prefills:]
|
||||
if len(decode_query_lens) > 0:
|
||||
max_decode_query_len = max(decode_query_lens)
|
||||
else:
|
||||
max_decode_query_len = 1
|
||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
|
||||
@@ -440,6 +590,9 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.input_builder.runner.device)
|
||||
else:
|
||||
self.attn_mask = None
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
||||
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
@@ -450,9 +603,17 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
||||
|
||||
assert device is not None
|
||||
|
||||
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
|
||||
device, self.runner.pin_memory)
|
||||
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
||||
self.runner.pin_memory)
|
||||
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
||||
device,
|
||||
self.runner.pin_memory)
|
||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
||||
device, self.runner.pin_memory)
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
@@ -466,15 +627,19 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
return AscendMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=False,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_decode_query_len=max_decode_query_len,
|
||||
max_prefill_seq_len=max_prefill_seq_len,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
query_start_loc=query_start_loc_tensor,
|
||||
seq_start_loc=seq_start_loc_tensor,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
attn_mask=self.attn_mask,
|
||||
)
|
||||
|
||||
@@ -105,7 +105,11 @@ class NPUPlatform(Platform):
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
|
||||
if vllm_config.scheduler_config.is_multi_step:
|
||||
parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = 128
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -33,3 +33,23 @@ def try_register_lib(lib_name: str, lib_info: str = ""):
|
||||
logger.info(lib_info)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
_current_stream = None
|
||||
|
||||
|
||||
def current_stream() -> torch.npu.Stream:
|
||||
"""
|
||||
replace `torch.npu.current_stream()` with `vllm.utils.current_stream()`.
|
||||
it turns out that `torch.npu.current_stream()` is quite expensive,
|
||||
as it will construct a new stream object at each call.
|
||||
here we patch `torch.npu.set_stream` to keep track of the current stream
|
||||
directly, so that we can avoid calling `torch.npu.current_stream()`.
|
||||
|
||||
"""
|
||||
global _current_stream
|
||||
if _current_stream is None:
|
||||
# when this function is called before any stream is set,
|
||||
# we return the default stream.
|
||||
_current_stream = torch.npu.current_stream()
|
||||
return _current_stream
|
||||
|
||||
674
vllm_ascend/worker/multi_step_runner.py
Normal file
674
vllm_ascend/worker/multi_step_runner.py
Normal file
@@ -0,0 +1,674 @@
|
||||
import dataclasses
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
|
||||
SamplerOutput,
|
||||
SamplingMetadata, get_logprobs,
|
||||
get_pythonized_sample_results)
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
||||
Logprob, SequenceGroupMetadata, SequenceOutput)
|
||||
from vllm.worker.multi_step_model_runner import (ModelOutput,
|
||||
PythonizationCache,
|
||||
StatefulModelInput)
|
||||
|
||||
from vllm_ascend.utils import current_stream
|
||||
from vllm_ascend.worker.model_runner import (
|
||||
ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=False)
|
||||
class NPUStatefulModelInput(StatefulModelInput):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def record_step_event(self, current_stream: torch.npu.Stream):
|
||||
# record the event for the current step so that the next step can sync
|
||||
# on it. We modulo by 2 to keep the events in a circular buffer and
|
||||
# support any attn backends that may be supported in the future. ie
|
||||
# Flashinfer would want two DecodeWrappers to overlap the CPU and NPU.
|
||||
self.step_cuda_events[self.current_step & 1] = \
|
||||
torch.npu.Event(blocking=True)
|
||||
self.step_cuda_events[self.current_step & 1].record(current_stream)
|
||||
|
||||
|
||||
@dataclass(frozen=False)
|
||||
class NPUModelOutput(ModelOutput):
|
||||
|
||||
logprobs: Optional["torch.Tensor"] = None
|
||||
|
||||
def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput",
|
||||
copy_stream: torch.npu.Stream,
|
||||
pinned_sampled_token_buffer: torch.Tensor,
|
||||
blocking: bool) -> bool:
|
||||
"""
|
||||
If blocking is set, will block until the forward pass for the output is
|
||||
ready and pythonize the output. Upon completing Pythonization, erases
|
||||
self.logprobs (note that a non-blocking call that is performed when
|
||||
the sampler output is not yet ready, will not erase self.logprobs.)
|
||||
"""
|
||||
assert self.sampled_token_ids is not None
|
||||
if not blocking and not self.sampler_output_ready_event.query():
|
||||
return False
|
||||
|
||||
if blocking:
|
||||
self.sampler_output_ready_event.synchronize()
|
||||
with torch.npu.stream(copy_stream):
|
||||
_pythonize_sampler_output(input_metadata, self.sampler_output,
|
||||
pinned_sampled_token_buffer,
|
||||
self.sampled_token_ids, self.logprobs,
|
||||
self.pythonization_cache)
|
||||
|
||||
# Erase the logprobs GPU-side tensor.
|
||||
# Note that although _pythonize_sampler_output() runs in its
|
||||
# own CUDA stream, nonetheless _pythonize_sampler_output()
|
||||
# cannot return until Pythonization is complete; therefore
|
||||
# we know that by the time the CPU reaches this point,
|
||||
# `self.logprobs` is no longer needed.
|
||||
self.logprobs = None
|
||||
return True
|
||||
|
||||
|
||||
class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
|
||||
# mypy: enable-error-code=type-var
|
||||
|
||||
def __init__(self, base_model_runner: NPUModelRunnerBase, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# uses the base model runner to execute the model and wraps it with
|
||||
# multi-step logic
|
||||
self._base_model_runner: NPUModelRunnerBase = base_model_runner
|
||||
|
||||
self.is_multi_step = self.scheduler_config.is_multi_step
|
||||
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
|
||||
# Using the PythonizationCache in Pipeline-Parallel clobbers the
|
||||
# SequenceOutput and CompletionSequenceGroupOutput object.
|
||||
# When cache-reset happens at the last step of a multi-step
|
||||
# execution, there may be other on-going single-step/multi-step
|
||||
# executions. The current caching implementation does not check
|
||||
# for this.
|
||||
self.pythonization_cache = PythonizationCache() \
|
||||
if self.parallel_config.pipeline_parallel_size == 1 else None
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
@functools.cached_property
|
||||
def _copy_stream(self):
|
||||
# used to copy tensors from NPU to CPU asynchronously
|
||||
return torch.npu.Stream()
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
|
||||
model_input = (NPUStatefulModelInput.from_broadcasted_tensor_dict(
|
||||
tensor_dict,
|
||||
attn_backend=self.attn_backend,
|
||||
))
|
||||
return model_input
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> StatefulModelInput:
|
||||
frozen_model_input: ModelInputForNPUWithSamplingMetadata = \
|
||||
self._base_model_runner.prepare_model_input(
|
||||
seq_group_metadata_list,
|
||||
virtual_engine,
|
||||
finished_requests_ids)
|
||||
|
||||
assert frozen_model_input.query_lens is not None
|
||||
assert frozen_model_input.seq_lens is not None
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
num_queries = len(frozen_model_input.query_lens)
|
||||
num_seqs = len(frozen_model_input.seq_lens)
|
||||
num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills
|
||||
|
||||
model_input = NPUStatefulModelInput(
|
||||
frozen_model_input=frozen_model_input,
|
||||
num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
num_single_step_prefills=num_single_step_prefills,
|
||||
step_cuda_events=[torch.npu.Event(blocking=True)] * 2,
|
||||
)
|
||||
|
||||
return model_input
|
||||
|
||||
def _async_process_outputs(self, model_input: StatefulModelInput,
|
||||
output_proc_callback: Callable):
|
||||
# Proceed with pythonization and output_proc in order.
|
||||
# Stop on the first one that fails to pythonize
|
||||
output_proc_callback()
|
||||
|
||||
cont = True
|
||||
for step_num, model_output in enumerate(model_input.cached_outputs):
|
||||
if not model_output.pythonized:
|
||||
model_output.maybe_pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
if model_output.pythonized:
|
||||
ctx = output_proc_callback.keywords["ctx"] # type: ignore
|
||||
ctx.append_output(
|
||||
outputs=[model_output.sampler_output],
|
||||
seq_group_metadata_list=ctx.seq_group_metadata_list,
|
||||
scheduler_outputs=ctx.scheduler_outputs,
|
||||
is_async=False,
|
||||
is_last_step=False,
|
||||
is_first_step_output=step_num == 0)
|
||||
|
||||
output_proc_callback()
|
||||
else:
|
||||
cont = False
|
||||
|
||||
if not cont:
|
||||
break
|
||||
|
||||
def _final_process_outputs(
|
||||
self, model_input: StatefulModelInput,
|
||||
output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
|
||||
assert model_input.frozen_model_input is not None
|
||||
|
||||
has_async_callback = output_proc_callback is not None
|
||||
|
||||
outputs = []
|
||||
for step_num, output in enumerate(model_input.cached_outputs):
|
||||
is_last_step = step_num == len(model_input.cached_outputs) - 1
|
||||
|
||||
# For non-async case:
|
||||
# -- We simply add the outputs
|
||||
# For async case:
|
||||
# -- Invoke callback, pythonize, add to callback queue and repeat
|
||||
# -- For last output, just add to callback queue
|
||||
if has_async_callback:
|
||||
assert output_proc_callback is not None
|
||||
|
||||
# Invoke callback before pythonize (to overlap with NPU)
|
||||
output_proc_callback()
|
||||
|
||||
# Pythonize
|
||||
if not output.pythonized:
|
||||
output.pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
|
||||
# For non last step, add to callback queue to chain
|
||||
# callbacks=>pythonize pairs (for NPU overlap)
|
||||
if not is_last_step:
|
||||
ctx = output_proc_callback.keywords[ # type: ignore
|
||||
"ctx"] # type: ignore
|
||||
ctx.append_output(
|
||||
outputs=[output.sampler_output],
|
||||
seq_group_metadata_list=ctx.
|
||||
seq_group_metadata_list,
|
||||
scheduler_outputs=ctx.scheduler_outputs,
|
||||
is_async=False,
|
||||
is_last_step=False,
|
||||
is_first_step_output=step_num == 0)
|
||||
else:
|
||||
outputs.append(output.sampler_output)
|
||||
else:
|
||||
output.pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
outputs.append(output.sampler_output)
|
||||
|
||||
return outputs
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: StatefulModelInput,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
|
||||
"""
|
||||
Execute the model for a single step and update multi-step
|
||||
metadata
|
||||
"""
|
||||
assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
|
||||
# path for warm up runs
|
||||
if not model_input.is_multi_step:
|
||||
return self._base_model_runner.execute_model(
|
||||
frozen_model_input, kv_caches, intermediate_tensors, num_steps)
|
||||
|
||||
# make sure we skip the sampler on the lask rank and only pythonize
|
||||
# if CPU is ahead.
|
||||
if self.is_driver_worker and get_pp_group().is_last_rank:
|
||||
if self.pinned_sampled_token_ids is None:
|
||||
self.pinned_sampled_token_ids = torch.zeros(
|
||||
(self.scheduler_config.max_num_seqs, 1),
|
||||
dtype=torch.long,
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
|
||||
self._base_model_runner.model.sampler.include_gpu_probs_tensor = (
|
||||
True)
|
||||
if frozen_model_input.sampling_metadata:
|
||||
frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
|
||||
True)
|
||||
|
||||
# some pre-execute model logic for multi-step:
|
||||
# - if it's the first step, we need to reset the sampling tensors
|
||||
# - if it's not the first step, we need to advance the step using the
|
||||
# appended sampler output from last iteration
|
||||
# - also maybe pythonize if CPU is ahead of NPU
|
||||
|
||||
stream = current_stream()
|
||||
if not model_input.is_first_multi_step:
|
||||
# Explicitly block on the previous step's forward to make sure we
|
||||
# don't clobber any NPU tensors still in use.
|
||||
# This is not needed for flashattn backend, but for other attn
|
||||
# backends such as flashinfer that performs extra CPU operations on
|
||||
# input metadata we may need to synchronize any CPU operations that
|
||||
# might clobber enqueued forwards. (prevents CPU from running too
|
||||
# far ahead if needed)
|
||||
model_input.wait_previous_step()
|
||||
model_input = self._advance_step(
|
||||
model_input, model_input.cached_outputs[-1].sampler_output)
|
||||
|
||||
# frozen_model_input may have been updated
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
|
||||
if model_input.base_output_proc_callback is None:
|
||||
assert frozen_model_input is not None
|
||||
model_input.base_output_proc_callback = \
|
||||
frozen_model_input.async_callback
|
||||
|
||||
if frozen_model_input.async_callback is not None:
|
||||
assert model_input.base_output_proc_callback is not None
|
||||
async_callback = functools.partial(
|
||||
self._async_process_outputs,
|
||||
model_input=model_input,
|
||||
output_proc_callback=model_input.base_output_proc_callback)
|
||||
|
||||
model_input.frozen_model_input = dataclasses.replace( # type: ignore
|
||||
model_input.frozen_model_input,
|
||||
async_callback=async_callback)
|
||||
# Update the local instance
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
|
||||
# Execute the model
|
||||
output = self._base_model_runner.execute_model(frozen_model_input,
|
||||
kv_caches,
|
||||
intermediate_tensors,
|
||||
num_steps=1)
|
||||
|
||||
# record the event for the current step so that the next step can sync
|
||||
model_input.record_step_event(stream)
|
||||
|
||||
if get_pp_group().is_last_rank and self.is_driver_worker:
|
||||
assert isinstance(output, list)
|
||||
assert len(
|
||||
output
|
||||
) == 1, "MultiStepModelRunner requires single-step base_models"
|
||||
|
||||
# event for the pythonization so that we only pythonize if the
|
||||
# tensors are ready. May be able to be combined with the step event
|
||||
output_ready_event = torch.npu.Event()
|
||||
output_ready_event.record(stream)
|
||||
if self.parallel_config.pipeline_parallel_size > 1:
|
||||
output[0].sampled_token_ids_cpu = output[
|
||||
0].sampled_token_ids.cpu()
|
||||
model_input.cached_outputs.append(
|
||||
NPUModelOutput(output[0], output_ready_event,
|
||||
output[0].sampled_token_ids, False,
|
||||
output[0].logprobs, self.pythonization_cache))
|
||||
|
||||
# These NPU tensors are not required by multi-step;
|
||||
# erase them to ensure they are not pythonized or
|
||||
# transferred to CPU
|
||||
output[0].sampled_token_ids = None
|
||||
output[0].sampled_token_probs = None
|
||||
output[0].logprobs = None
|
||||
|
||||
# Pythonize the output if CPU is ahead and the previous step is
|
||||
# ready.
|
||||
if frozen_model_input.async_callback is None:
|
||||
for model_output in model_input.cached_outputs:
|
||||
model_output.maybe_pythonize(model_input,
|
||||
self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
|
||||
model_input.current_step += 1
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
# Should be IntermediateTensors
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
return output
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
# Pythonize the output and block if needed since it is the last step
|
||||
if model_input.is_last_step:
|
||||
outputs = self._final_process_outputs(
|
||||
model_input, model_input.base_output_proc_callback)
|
||||
if self.pythonization_cache:
|
||||
self.pythonization_cache.reset()
|
||||
return outputs
|
||||
|
||||
# should be [SamplerOutput]
|
||||
return output
|
||||
|
||||
def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
|
||||
num_seqs: Optional[int], num_queries: int):
|
||||
|
||||
assert sampling_metadata.num_prompts == 0
|
||||
assert len(sampling_metadata.seq_groups) == num_queries
|
||||
assert sampling_metadata.selected_token_indices.shape == (
|
||||
num_queries, )
|
||||
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
|
||||
|
||||
# Verify that all sequences are decodes
|
||||
for i in range(num_queries):
|
||||
seq_group = sampling_metadata.seq_groups[i]
|
||||
|
||||
assert seq_group.is_prompt is False # No prompt
|
||||
assert seq_group.prompt_logprob_indices == [] # No prompt
|
||||
assert seq_group.sample_indices == [i] # Simple
|
||||
assert seq_group.seq_len is None # Decode
|
||||
assert seq_group.query_len is None # Decode
|
||||
|
||||
def _advance_step(self, model_input: StatefulModelInput,
|
||||
out: SamplerOutput) -> StatefulModelInput:
|
||||
|
||||
model_input.maybe_advance_frozen_model_input(self.device,
|
||||
self.pin_memory)
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
assert frozen_model_input.input_tokens is not None
|
||||
assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
|
||||
sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
|
||||
num_seqs = model_input.num_seqs
|
||||
num_queries = model_input.num_queries
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
attn_metadata = frozen_model_input.attn_metadata
|
||||
assert attn_metadata is not None
|
||||
|
||||
turn_prefills_into_decodes: bool = model_input.current_step == 1 and \
|
||||
model_input.num_single_step_prefills != 0
|
||||
attn_metadata.advance_step(
|
||||
frozen_model_input,
|
||||
sampled_token_ids,
|
||||
self.block_size,
|
||||
num_seqs,
|
||||
num_queries,
|
||||
turn_prefills_into_decodes=turn_prefills_into_decodes)
|
||||
|
||||
return model_input
|
||||
|
||||
def load_model(self) -> None:
|
||||
self._base_model_runner.load_model()
|
||||
self.model_memory_usage = self._base_model_runner.model_memory_usage
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
pattern: Optional[str] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> None:
|
||||
return self._base_model_runner.save_sharded_state(
|
||||
path, pattern, max_size)
|
||||
|
||||
def save_tensorized_model(self,
|
||||
tensorizer_config: TensorizerConfig) -> None:
|
||||
return self._base_model_runner.save_tensorized_model(tensorizer_config)
|
||||
|
||||
def profile_run(self) -> None:
|
||||
return self._base_model_runner.profile_run()
|
||||
|
||||
def remove_all_loras(self):
|
||||
return self._base_model_runner.remove_all_loras()
|
||||
|
||||
def capture_model(self, kv_caches: List[List]) -> None:
|
||||
return self._base_model_runner.capture_model(kv_caches)
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self._base_model_runner.vocab_size
|
||||
|
||||
|
||||
DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
|
||||
Optional[List[SampleLogprobs]]]
|
||||
|
||||
|
||||
def deferred_pythonize_logprobs(
|
||||
output: SamplerOutput,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
logprobs_tensor: Optional[torch.Tensor],
|
||||
) -> DeferredLogprobsReturnType:
|
||||
"""Perform deferred logprob Pythonization.
|
||||
|
||||
1. Pythonize NPU-side sampler result tensors into CPU-side sampler result.
|
||||
2. Pythonize NPU-side logprobs tensor into CPU-side logprobs lists,
|
||||
utilizing the Pythonized sampler result computed in step 1.
|
||||
|
||||
These deferred computations are not required for single-step scheduling
|
||||
or the `profile_run()` phase of multi-step scheduling.
|
||||
|
||||
Args:
|
||||
output: sampler output (under deferred Pythonization)
|
||||
sampling_metadata
|
||||
|
||||
Returns:
|
||||
prompt_logprobs (CPU), sample_logprobs (CPU)
|
||||
"""
|
||||
|
||||
# - Deferred pythonization of sample result
|
||||
sampler_result = get_pythonized_sample_results(
|
||||
output.deferred_sample_results_args)
|
||||
|
||||
# - Erase the NPU-side deferred sample_result
|
||||
# computation args to ensure it is never
|
||||
# pythonized or transferred to CPU
|
||||
output.deferred_sample_results_args = None
|
||||
|
||||
# - Deferred pythonization of logprobs
|
||||
(
|
||||
prompt_logprobs,
|
||||
sample_logprobs,
|
||||
) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result)
|
||||
assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
|
||||
assert len(sample_logprobs) == len(sampling_metadata.seq_groups)
|
||||
|
||||
return prompt_logprobs, sample_logprobs
|
||||
|
||||
|
||||
def _pythonize_sampler_output(
|
||||
model_input: StatefulModelInput,
|
||||
output: SamplerOutput,
|
||||
pinned_sampled_token_buffer: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
logprobs_tensor: Optional[torch.Tensor],
|
||||
cache: Optional[PythonizationCache],
|
||||
) -> None:
|
||||
""" This function is only called when the output tensors are ready.
|
||||
See :class:`ModelOutput`.
|
||||
|
||||
Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
|
||||
adding a Pythonized output data structure
|
||||
(:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.
|
||||
|
||||
Args:
|
||||
model_input
|
||||
output: sampler output
|
||||
pinned_sampled_token_token_buffer: CPU-side pinned memory
|
||||
(receives copy of
|
||||
NPU-side token buffer.)
|
||||
sampled_token_ids: NPU-side token buffer
|
||||
logprobs_tensor: NPU-side tensor containing
|
||||
logprobs computed during sampling
|
||||
"""
|
||||
|
||||
assert model_input.frozen_model_input is not None
|
||||
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input.sampling_metadata is not None
|
||||
sampling_metadata = frozen_model_input.sampling_metadata
|
||||
# samples generation should have been skipped
|
||||
assert not output.outputs
|
||||
|
||||
pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]
|
||||
|
||||
# We guarantee output tensors are ready, so it is safe to
|
||||
# pythonize the sampler output & obtain CPU-side logprobs.
|
||||
#
|
||||
# However we should check whether logprobs pythonization may
|
||||
# be skipped entirely, i.e. because no logprobs were requested
|
||||
# or pythonization was not deferred. To that end,
|
||||
#
|
||||
# * `prompt_logprobs_are_requested_for_prefill` signals that
|
||||
# there are *any* prefill-phase requests which specify that
|
||||
# prompt logprobs should be returned.
|
||||
#
|
||||
# * `any_logprobs_are_requested` signals that there are any
|
||||
# requests which (1) specify that sample logprobs should be
|
||||
# returned, or (2) are in the prefill phase AND specify that
|
||||
# prompt logprobs should be returned.
|
||||
#
|
||||
# Later on, these flags cause adjustments to the pythonization
|
||||
# process to accommodate logprobs.
|
||||
|
||||
seq_groups = sampling_metadata.seq_groups
|
||||
prompt_logprobs_are_requested_for_prefill = any([
|
||||
sg.sampling_params.prompt_logprobs is not None and sg.is_prompt
|
||||
for sg in seq_groups
|
||||
])
|
||||
any_logprobs_are_requested = (
|
||||
prompt_logprobs_are_requested_for_prefill
|
||||
or any([sg.sampling_params.logprobs is not None for sg in seq_groups]))
|
||||
|
||||
if prompt_logprobs_are_requested_for_prefill:
|
||||
# CPU NPU sync, after gathering *only* sampled tokens (since
|
||||
# requesting prompt logprobs leads `sampled_token_ids` to
|
||||
# include prompt token ids in addition to sampled token ids.)
|
||||
sample_idx_tensor = torch.tensor(
|
||||
[sdx for sg in seq_groups for sdx in sg.sample_indices])
|
||||
pinned_buffer = pinned_buffer.copy_(
|
||||
sampled_token_ids[sample_idx_tensor, :], non_blocking=False)
|
||||
else:
|
||||
# CPU NPU sync
|
||||
pinned_buffer = pinned_buffer.copy_(sampled_token_ids,
|
||||
non_blocking=False)
|
||||
|
||||
# this will not block as the tensors are already on CPU
|
||||
samples_list = pinned_buffer.tolist()
|
||||
|
||||
skip_sampler_cpu_output = (
|
||||
frozen_model_input.sampling_metadata.skip_sampler_cpu_output)
|
||||
|
||||
# *Don't* skip logprobs pythonization *if*:
|
||||
# * Any requests require logprobs to be returned in this
|
||||
# iteration AND
|
||||
# * These requests are being scheduled in a fashion which
|
||||
# defers pythonization (i.e. multi-step scheduling.)
|
||||
do_pythonize_logprobs = (skip_sampler_cpu_output
|
||||
and any_logprobs_are_requested)
|
||||
(
|
||||
prompt_logprobs,
|
||||
sample_logprobs,
|
||||
) = (deferred_pythonize_logprobs(output, sampling_metadata,
|
||||
logprobs_tensor)
|
||||
if do_pythonize_logprobs else (None, None))
|
||||
|
||||
for sgdx, (seq_group,
|
||||
sample_result) in enumerate(zip(seq_groups, samples_list)):
|
||||
# Reminder: Please update docs/source/features/compatibility_matrix.md
|
||||
# If the feature combo become valid
|
||||
# (Check for Guided Decoding)
|
||||
if seq_group.sampling_params.logits_processors:
|
||||
assert len(seq_group.sampling_params.logits_processors) == 0, (
|
||||
"Logits Processors are not supported in multi-step decoding")
|
||||
|
||||
if do_pythonize_logprobs:
|
||||
assert prompt_logprobs is not None
|
||||
assert sample_logprobs is not None
|
||||
|
||||
(
|
||||
group_prompt_logprobs,
|
||||
group_sample_logprobs,
|
||||
) = ( # Utilize deferred pythonization results
|
||||
prompt_logprobs[sgdx],
|
||||
sample_logprobs[sgdx],
|
||||
)
|
||||
elif any_logprobs_are_requested:
|
||||
(
|
||||
group_prompt_logprobs,
|
||||
group_sample_logprobs,
|
||||
) = (
|
||||
# profile_run: use already-computed logprobs
|
||||
output.outputs[sgdx].prompt_logprobs,
|
||||
[sample.logprobs for sample in output.outputs[sgdx].samples])
|
||||
|
||||
seq_ids = seq_group.seq_ids
|
||||
next_token_ids = sample_result
|
||||
parent_ids = [0]
|
||||
seq_outputs: List[SequenceOutput]
|
||||
|
||||
if cache is not None:
|
||||
completion_seq_group_output: CompletionSequenceGroupOutput = \
|
||||
cache.cached_completion_seq_group_output.get_object()
|
||||
completion_seq_group_output.samples.clear()
|
||||
seq_outputs = completion_seq_group_output.samples
|
||||
else:
|
||||
seq_outputs = []
|
||||
|
||||
for tdx, (parent_id,
|
||||
next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
|
||||
if cache is not None:
|
||||
seq_output: SequenceOutput = cache.cached_seq_output.get_object(
|
||||
)
|
||||
seq_output.parent_seq_id = seq_ids[parent_id]
|
||||
seq_output.output_token = next_token_id
|
||||
|
||||
if any_logprobs_are_requested:
|
||||
seq_output.logprobs = group_sample_logprobs[tdx]
|
||||
else:
|
||||
logprobs = next(iter(seq_output.logprobs.values()))
|
||||
seq_output.logprobs.clear()
|
||||
|
||||
logprobs.logprob = float('inf')
|
||||
logprobs.rank = None
|
||||
logprobs.decoded_token = None
|
||||
|
||||
seq_output.logprobs[next_token_id] = logprobs
|
||||
|
||||
seq_outputs.append(seq_output)
|
||||
|
||||
else:
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_ids[parent_id], next_token_id,
|
||||
(group_sample_logprobs[tdx]
|
||||
if any_logprobs_are_requested else {
|
||||
next_token_id:
|
||||
Logprob(logprob=float('inf'),
|
||||
rank=None,
|
||||
decoded_token=None)
|
||||
})))
|
||||
if cache is not None:
|
||||
completion_seq_group_output.prompt_logprobs = \
|
||||
group_prompt_logprobs if any_logprobs_are_requested else None
|
||||
output.outputs.append(completion_seq_group_output)
|
||||
else:
|
||||
output.outputs.append(
|
||||
CompletionSequenceGroupOutput(
|
||||
seq_outputs, (group_prompt_logprobs
|
||||
if any_logprobs_are_requested else None)))
|
||||
|
||||
assert len(output.outputs) > 0
|
||||
194
vllm_ascend/worker/multi_step_worker.py
Normal file
194
vllm_ascend/worker/multi_step_worker.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from vllm.distributed import broadcast_tensor_dict, get_pp_group
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.worker.model_runner_base import BroadcastableModelInput
|
||||
from vllm.worker.multi_step_model_runner import StatefulModelInput
|
||||
|
||||
from vllm_ascend.worker.multi_step_runner import MultiStepModelNPURunner
|
||||
from vllm_ascend.worker.worker import NPUWorker, WorkerInput
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiStepState:
|
||||
worker_input: WorkerInput
|
||||
model_input: StatefulModelInput
|
||||
|
||||
|
||||
class MultiStepWorker(NPUWorker):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
base_model_runner = self.model_runner
|
||||
# for multi-step model, wrap the model runner with MultiStepModelRunner
|
||||
self.model_runner = MultiStepModelNPURunner(
|
||||
base_model_runner,
|
||||
vllm_config=base_model_runner.vllm_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
is_driver_worker=base_model_runner.is_driver_worker,
|
||||
)
|
||||
|
||||
pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
|
||||
self.multi_step_states: List[
|
||||
Optional[MultiStepState]] = [None] * pipeline_parallel_size
|
||||
self.temp_output = None
|
||||
|
||||
def _get_driver_input_and_broadcast(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Get the driver input and broadcast it to other workers.
|
||||
"""
|
||||
assert self.is_driver_worker
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
is_first_multi_step = execute_model_req.is_first_multi_step
|
||||
if is_first_multi_step:
|
||||
# on first step we prepare the worker input and model input normally
|
||||
worker_input: WorkerInput = self.prepare_worker_input(
|
||||
execute_model_req=execute_model_req)
|
||||
model_input: StatefulModelInput = (
|
||||
self.model_runner.prepare_model_input(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
execute_model_req.virtual_engine,
|
||||
execute_model_req.finished_requests_ids))
|
||||
|
||||
if execute_model_req.async_callback:
|
||||
model_input.frozen_model_input = dataclasses.replace( # type: ignore
|
||||
model_input.frozen_model_input,
|
||||
async_callback=execute_model_req.async_callback)
|
||||
else:
|
||||
# on subsequent steps we reuse the worker input and model input
|
||||
multi_step_state = self.multi_step_states[virtual_engine]
|
||||
worker_input = multi_step_state.worker_input
|
||||
model_input = multi_step_state.model_input
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
# clear the cached metadata so that it can be recomputed on
|
||||
# the workers.
|
||||
frozen_model_input.attn_metadata._cached_prefill_metadata = None
|
||||
frozen_model_input.attn_metadata._cached_decode_metadata = None
|
||||
|
||||
model_input.is_first_multi_step = is_first_multi_step
|
||||
model_input.is_last_step = execute_model_req.is_last_step
|
||||
|
||||
if not is_first_multi_step:
|
||||
# we broadcast the last sampled token ids to all TP workers so they
|
||||
# can update their model input metadata in-place.
|
||||
self._prepare_last_sampled_token_ids_for_tp_workers(
|
||||
execute_model_req=execute_model_req, model_input=model_input)
|
||||
|
||||
if self.do_metadata_broadcast:
|
||||
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
||||
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
|
||||
# Retuning empty dict here to keep this compatible with
|
||||
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
|
||||
return model_input, worker_input, {}
|
||||
|
||||
def _prepare_last_sampled_token_ids_for_tp_workers(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
model_input: StatefulModelInput,
|
||||
) -> None:
|
||||
"""
|
||||
Prepare the last sampled token ids for TP workers. If it's the last
|
||||
PP rank, then the last sampled token ids are already in the model_input.
|
||||
If it is NOT the last PP rank, then we need to get the last sampled
|
||||
token that is cached in the execute_model_req.
|
||||
"""
|
||||
if get_pp_group().is_last_rank:
|
||||
assert model_input.cached_outputs[
|
||||
-1].sampler_output.sampled_token_ids is None
|
||||
assert model_input.cached_outputs[-1].sampled_token_ids is not None
|
||||
model_input.last_sampled_token_ids = model_input.cached_outputs[
|
||||
-1].sampled_token_ids
|
||||
# free sampled token ids from the previous step if it has been
|
||||
# pythonized. Cannot free the last sampled token ids because
|
||||
# we need it for GPU advance_step.
|
||||
for output in model_input.cached_outputs[:-1]:
|
||||
if output.pythonized:
|
||||
output.sampled_token_ids = None
|
||||
else:
|
||||
# otherwise we need to get the cached sampled token ids from the
|
||||
# execute_model_req
|
||||
assert execute_model_req.last_sampled_token_ids is not None
|
||||
model_input.last_sampled_token_ids = (
|
||||
execute_model_req.last_sampled_token_ids.cuda())
|
||||
model_input.add_sampler_output(
|
||||
SamplerOutput(outputs=[], sampled_token_ids=None),
|
||||
model_input.last_sampled_token_ids)
|
||||
|
||||
# free sampled token ids from the previous step.
|
||||
# TODO(will) we could reuse the sampled token ids tensor from
|
||||
# the previous step instead.
|
||||
for output in model_input.cached_outputs[:-1]:
|
||||
output.sampled_token_ids = None
|
||||
assert model_input.cached_outputs[-1].sampled_token_ids is not None
|
||||
|
||||
def prepare_input(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None,
|
||||
) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
|
||||
torch.Tensor]]]:
|
||||
"""
|
||||
Depending on the current state of the request and multi step worker,
|
||||
this method may skip the normal _prepare_model_input and
|
||||
_prepare_worker_input methods and instead used cached values.
|
||||
"""
|
||||
if self.is_driver_worker:
|
||||
if execute_model_req is None:
|
||||
if self.do_metadata_broadcast:
|
||||
# This signals that there's no more requests to process for
|
||||
# now. All workers are running infinite loop with
|
||||
# broadcast_tensor_dict, and it stops the loop when the
|
||||
# driver broadcasts an empty input. Send an empty input to
|
||||
# notify all other workers to stop their execution loop.
|
||||
broadcast_tensor_dict({}, src=0)
|
||||
return None
|
||||
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
(model_input, worker_input,
|
||||
kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
if execute_model_req.is_first_multi_step:
|
||||
# cache the worker input and model input for the next steps
|
||||
self.multi_step_states[virtual_engine] = MultiStepState(
|
||||
worker_input=worker_input, model_input=model_input)
|
||||
# if TP workers
|
||||
else:
|
||||
broadcast_data = self._get_worker_input_from_broadcast()
|
||||
# if the driver has sent an empty input, we should stop the worker
|
||||
# loop
|
||||
if broadcast_data is None:
|
||||
return None
|
||||
model_input, worker_input, kwargs = broadcast_data
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
virtual_engine = worker_input.virtual_engine
|
||||
if model_input.is_first_multi_step:
|
||||
pass
|
||||
# TODO(will) Can cache the worker input and model input for the
|
||||
# next steps. See below for details
|
||||
else:
|
||||
# TODO(will) possible to also cache and reuse the cached worker
|
||||
# input and model input. The idea is essentially the delta
|
||||
# optimization for model_inputs. Where the TP workers can cache
|
||||
# the model input states and we only broadcast the delta need
|
||||
# for the next step (sampled_token_ids from the previous step)
|
||||
|
||||
assert isinstance(model_input, StatefulModelInput)
|
||||
# we need to update the last sampled token ids in the model
|
||||
# input for the workers so that they can run inplace
|
||||
# advance_step
|
||||
model_input.add_sampler_output(
|
||||
SamplerOutput(outputs=[], sampled_token_ids=None),
|
||||
model_input.last_sampled_token_ids)
|
||||
|
||||
assert model_input is not None
|
||||
assert worker_input is not None
|
||||
return model_input, worker_input, kwargs
|
||||
Reference in New Issue
Block a user