diff --git a/docs/source/user_guide/suppoted_features.md b/docs/source/user_guide/suppoted_features.md
index 4c52995..a44183f 100644
--- a/docs/source/user_guide/suppoted_features.md
+++ b/docs/source/user_guide/suppoted_features.md
@@ -13,7 +13,7 @@
 | LogProbs | ✅ | | | Basic functions available | Need fully test |
 | Prompt logProbs | ✅ | | | Basic functions available | Need fully test |
 | Async output | ✅ | | | Basic functions available | Need fully test |
-| Multi step scheduler | ✅ | | | Basic functions available | Need fully test |
+| Multi step scheduler | ✅ | | | Basic functions available | Need fully test. Find more details at the [blog](https://blog.vllm.ai/2024/09/05/perf-update.html#batch-scheduling-multiple-steps-ahead-pr-7000), the [RFC](https://github.com/vllm-project/vllm/issues/6854) and the [PR](https://github.com/vllm-project/vllm/pull/7000) |
 | Best of | ✅ | | | Basic functions available | Need fully test |
 | Beam search | ✅ | | | Basic functions available | Need fully test |
 | Guided Decoding | ✅ | | | Basic functions available | Find more details at the [issue](https://github.com/vllm-project/vllm-ascend/issues/177) |
diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py
index 587d443..178d381 100644
--- a/vllm_ascend/attention/attention.py
+++ b/vllm_ascend/attention/attention.py
@@ -16,7 +16,6 @@
 #
 
 from dataclasses import dataclass
-from itertools import accumulate
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import numpy as np
@@ -216,9 +215,6 @@ class AscendMetadata(AttentionMetadata):
     # Maximum sequence length among decode batch. 0 if there are prefill
     # requests only.
     max_decode_seq_len: int
-    # (batch_size,) A tensor of context lengths (tokens that are computed
-    # so far).
-    context_lens_tensor: Optional[torch.Tensor]
 
     # (batch_size, max_blocks_per_seq).
     # Block addresses per sequence. (Seq id -> list of physical block)
@@ -234,18 +230,6 @@ class AscendMetadata(AttentionMetadata):
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
 
-    # Max number of query tokens among request in the batch.
-    max_decode_query_len: Optional[int] = None
-
-    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
-    # the batch, used to index into subquery. E.g., if the subquery length
-    # is [4, 6], it is [0, 4, 10].
-    query_start_loc: Optional[torch.Tensor] = None
-    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
-    # the batch, used to index into sequence. E.g., if the sequence length is
-    # [4, 6], it is [0, 4, 10].
-    seq_start_loc: Optional[torch.Tensor] = None
-
     # Self-attention prefill/decode metadata cache
     _cached_prefill_metadata: Optional["AscendMetadata"] = None
     _cached_decode_metadata: Optional["AscendMetadata"] = None
@@ -283,18 +267,10 @@ class AscendMetadata(AttentionMetadata):
                 or (self.encoder_seq_lens is not None))
 
         # Compute some attn_metadata fields which default to None.
-        query_start_loc = (None if self.query_start_loc is None else
-                           self.query_start_loc[:self.num_prefills + 1])
         slot_mapping = (None if self.slot_mapping is None else
                         self.slot_mapping[:self.num_prefill_tokens])
         seq_lens = (None if self.seq_lens is None else
                     self.seq_lens[:self.num_prefills])
-        seq_lens_tensor = (None if self.seq_lens_tensor is None else
-                           self.seq_lens_tensor[:self.num_prefills])
-        seq_start_loc = (None if self.seq_start_loc is None else
-                         self.seq_start_loc[:self.num_prefills + 1])
-        context_lens_tensor = (None if self.context_lens_tensor is None else
-                               self.context_lens_tensor[:self.num_prefills])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[:self.num_prefills])
 
@@ -311,11 +287,7 @@ class AscendMetadata(AttentionMetadata):
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
-            max_decode_query_len=0,
             max_decode_seq_len=0,
-            query_start_loc=query_start_loc,
-            seq_start_loc=seq_start_loc,
-            context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             # Begin encoder & cross attn fields below...
             encoder_seq_lens=self.encoder_seq_lens,
@@ -343,8 +315,6 @@ class AscendMetadata(AttentionMetadata):
             self.slot_mapping[self.num_prefill_tokens:])
         seq_lens = (None if self.seq_lens is None else
                     self.seq_lens[self.num_prefills:])
-        seq_lens_tensor = (None if self.seq_lens_tensor is None else
-                           self.seq_lens_tensor[self.num_prefills:])
         block_tables = (None if self.block_tables is None else
                         self.block_tables[self.num_prefills:])
         seq_lens_tensor = (None if self.seq_lens_tensor is None else
@@ -357,19 +327,9 @@ class AscendMetadata(AttentionMetadata):
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
-            max_decode_query_len=self.max_decode_query_len,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
-            # Batch may be composed of prefill|decodes, adjust query start
-            # indices to refer to the start of decodes. E.g.
-            # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
-            query_start_loc=(self.query_start_loc[self.num_prefills:] -
-                             self.query_start_loc[self.num_prefills])
-            if self.query_start_loc is not None else None,
-            seq_start_loc=self.seq_start_loc[self.num_prefills:]
-            if self.seq_start_loc is not None else None,
-            context_lens_tensor=None,
             block_tables=block_tables,
             # Begin encoder & cross attn fields below...
             encoder_seq_lens=self.encoder_seq_lens,
@@ -427,14 +387,6 @@ class AscendMetadata(AttentionMetadata):
         assert self.max_query_len == 1
         assert self.max_prefill_seq_len == 0
 
-        assert self.query_start_loc is not None
-        assert self.query_start_loc.shape == (num_queries + 1, )
-        assert self.seq_start_loc is not None
-        assert self.seq_start_loc.shape == (num_seqs + 1, )
-
-        assert self.context_lens_tensor is not None
-        assert self.context_lens_tensor.shape == (num_queries, )
-
         assert self.block_tables is not None
         assert self.block_tables.shape[0] == num_seqs
 
@@ -576,11 +528,6 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
         device = self.runner.device
         max_query_len = max(query_lens)
 
-        decode_query_lens = query_lens[self.num_prefills:]
-        if len(decode_query_lens) > 0:
-            max_decode_query_len = max(decode_query_lens)
-        else:
-            max_decode_query_len = 1
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
 
@@ -592,8 +539,6 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
         else:
             self.attn_mask = None
         num_decode_tokens = self.num_decode_tokens
-        query_start_loc = list(accumulate(query_lens, initial=0))
-        seq_start_loc = list(accumulate(seq_lens, initial=0))
 
         block_tables = make_tensor_with_pad(
             self.block_tables,
@@ -604,27 +549,16 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
         assert max_query_len > 0, "query_lens: {}".format(query_lens)
 
         assert device is not None
-        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
-                                               device, self.runner.pin_memory)
         slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
                                                device, self.runner.pin_memory)
         seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                            self.runner.pin_memory)
-        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
-                                                  device,
-                                                  self.runner.pin_memory)
-        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
-                                                device, self.runner.pin_memory)
         placeholder_index_maps = {
             modality: placeholder_map.index_map()
             for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
         }
-        seq_lens_tensor = torch.tensor(seq_lens,
-                                       dtype=torch.long,
-                                       device=device)
-
         return AscendMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
@@ -635,12 +569,8 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
             enable_kv_scales_calculation=True,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
-            max_decode_query_len=max_decode_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
-            query_start_loc=query_start_loc_tensor,
-            seq_start_loc=seq_start_loc_tensor,
-            context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             attn_mask=self.attn_mask,
         )
diff --git a/vllm_ascend/worker/multi_step_runner.py b/vllm_ascend/worker/multi_step_runner.py
index 3ba39ad..65f0208 100644
--- a/vllm_ascend/worker/multi_step_runner.py
+++ b/vllm_ascend/worker/multi_step_runner.py
@@ -1,7 +1,8 @@
 import dataclasses
 import functools
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
+                    Union)
 
 import torch
 from torch import nn
@@ -14,19 +15,26 @@ from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.sequence import (CompletionSequenceGroupOutput,
                            IntermediateTensors, Logprob,
                            SequenceGroupMetadata, SequenceOutput)
+from vllm.utils import current_stream
+from vllm.worker.model_runner_base import (
+    _init_attn_metadata_from_tensor_dict,
+    _init_frozen_model_input_from_tensor_dict,
+    _init_sampling_metadata_from_tensor_dict)
 from vllm.worker.multi_step_model_runner import (ModelOutput,
                                                  PythonizationCache,
                                                  StatefulModelInput)
-from vllm_ascend.utils import current_stream
 from vllm_ascend.worker.model_runner import (
     ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
 
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+
 logger = init_logger(__name__)
 
 
 @dataclass(frozen=False)
-class NPUStatefulModelInput(StatefulModelInput):
+class StatefulModelInputForNPU(StatefulModelInput):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -40,6 +48,66 @@ class NPUStatefulModelInput(StatefulModelInput):
             torch.npu.Event(blocking=True)
         self.step_cuda_events[self.current_step & 1].record(current_stream)
 
+    # actual frozen model input dataclass passed to _base_model_runner
+    frozen_model_input: Optional[ModelInputForNPUWithSamplingMetadata] = None
+
+    @classmethod
+    def from_broadcasted_tensor_dict(
+        cls,
+        tensor_dict: Dict[str, Any],
+        attn_backend: Optional["AttentionBackend"] = None,
+    ) -> "StatefulModelInputForNPU":
+        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
+        if attn_backend is not None:
+            tensor_dict = _init_attn_metadata_from_tensor_dict(
+                attn_backend, tensor_dict)
+        tensor_dict = _init_frozen_model_input_from_tensor_dict(
+            ModelInputForNPUWithSamplingMetadata, tensor_dict)
+        return cls(**tensor_dict)
+
+    def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool):
+        """
+        Advancing the data structures of StatefulModelInput::frozen_model_input
+        is only required when prefills are scheduled with decodes to run in
+        multi-step. This advancement/correction accounts for the conversion
+        of prefills to decodes after the first multi-step iteration.
+        """
+        if self.current_step != 1 or self.num_single_step_prefills == 0:
+            return
+
+        assert self.frozen_model_input is not None
+        fmi = self.frozen_model_input
+
+        # Truncate input_tokens
+        assert fmi.input_tokens is not None
+        assert fmi.input_tokens.shape[0] >= self.num_seqs
+        fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs]
+
+        # Update frozen_model_input::input_positions.
+        assert fmi.input_positions is not None
+        assert fmi.input_positions.shape[0] >= self.num_seqs
+        fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self.
+                                                                     num_seqs]
+
+        # Assert unsupported
+        # TODO: Uncomment the following asserts once supported on NPU.
+        # assert fmi.lora_mapping is None
+        # assert fmi.lora_requests is not None
+        # assert len(fmi.lora_requests) == 0
+        # assert fmi.prompt_adapter_mapping is None
+        # assert fmi.prompt_adapter_requests is not None
+        # assert len(fmi.prompt_adapter_requests) == 0
+        assert fmi.attn_metadata is not None
+        assert fmi.multi_modal_kwargs is not None
+        assert len(fmi.multi_modal_kwargs) == 0
+
+        self.frozen_model_input = dataclasses.replace(
+            self.frozen_model_input,
+            input_tokens=fmi_new_input_tokens,
+            input_positions=fmi_new_input_positions)
+
+        self.maybe_advance_sampling_metadata(device, pin_memory)
+
 
 @dataclass(frozen=False)
 class NPUModelOutput(ModelOutput):
@@ -78,7 +146,7 @@ class NPUModelOutput(ModelOutput):
         return True
 
 
-class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
+class MultiStepModelNPURunner(NPUModelRunnerBase[StatefulModelInputForNPU]):
     # mypy: enable-error-code=type-var
 
     def __init__(self, base_model_runner: NPUModelRunnerBase, *args, **kwargs):
@@ -101,7 +169,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
             if self.parallel_config.pipeline_parallel_size == 1 else None
 
     def get_model(self) -> nn.Module:
-        return self.model
+        return self._base_model_runner.get_model()
 
     @functools.cached_property
     def _copy_stream(self):
@@ -109,8 +177,8 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
         return torch.npu.Stream()
 
     def make_model_input_from_broadcasted_tensor_dict(
-            self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
-        model_input = (NPUStatefulModelInput.from_broadcasted_tensor_dict(
+            self, tensor_dict: Dict[str, Any]) -> StatefulModelInputForNPU:
+        model_input = (StatefulModelInputForNPU.from_broadcasted_tensor_dict(
             tensor_dict,
             attn_backend=self.attn_backend,
         ))
@@ -121,7 +189,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
         seq_group_metadata_list: List[SequenceGroupMetadata],
         virtual_engine: int = 0,
         finished_requests_ids: Optional[List[str]] = None
-    ) -> StatefulModelInput:
+    ) -> StatefulModelInputForNPU:
         frozen_model_input: ModelInputForNPUWithSamplingMetadata = \
             self._base_model_runner.prepare_model_input(
                 seq_group_metadata_list,
@@ -135,7 +203,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
         num_seqs = len(frozen_model_input.seq_lens)
         num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills
 
-        model_input = NPUStatefulModelInput(
+        model_input = StatefulModelInputForNPU(
             frozen_model_input=frozen_model_input,
             num_seqs=num_seqs,
             num_queries=num_queries,
@@ -145,7 +213,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
 
         return model_input
 
-    def _async_process_outputs(self, model_input: StatefulModelInput,
+    def _async_process_outputs(self, model_input: StatefulModelInputForNPU,
                                output_proc_callback: Callable):
         # Proceed with pythonization and output_proc in order.
         # Stop on the first one that fails to pythonize
@@ -174,7 +242,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
                     break
 
     def _final_process_outputs(
-            self, model_input: StatefulModelInput,
+            self, model_input: StatefulModelInputForNPU,
             output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
         assert model_input.frozen_model_input is not None
 
@@ -225,7 +293,7 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
     @torch.inference_mode()
     def execute_model(
         self,
-        model_input: StatefulModelInput,
+        model_input: StatefulModelInputForNPU,
         kv_caches: List[torch.Tensor],
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
@@ -382,8 +450,8 @@ class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]):
             assert seq_group.seq_len is None  # Decode
             assert seq_group.query_len is None  # Decode
 
-    def _advance_step(self, model_input: StatefulModelInput,
-                      out: SamplerOutput) -> StatefulModelInput:
+    def _advance_step(self, model_input: StatefulModelInputForNPU,
+                      out: SamplerOutput) -> StatefulModelInputForNPU:
         model_input.maybe_advance_frozen_model_input(self.device,
                                                      self.pin_memory)
 
@@ -491,7 +559,7 @@ def deferred_pythonize_logprobs(
 
 
 def _pythonize_sampler_output(
-    model_input: StatefulModelInput,
+    model_input: StatefulModelInputForNPU,
     output: SamplerOutput,
     pinned_sampled_token_buffer: torch.Tensor,
    sampled_token_ids: torch.Tensor,