diff --git a/vllm_ascend/attention.py b/vllm_ascend/attention.py index 4cc9301..8456cb8 100644 --- a/vllm_ascend/attention.py +++ b/vllm_ascend/attention.py @@ -16,6 +16,7 @@ # from dataclasses import dataclass +from itertools import accumulate from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import numpy as np @@ -38,7 +39,8 @@ from vllm.attention.backends.utils import (CommonAttentionState, from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: - from vllm_ascend.worker.model_runner import ModelInputForNPUBuilder + from vllm_ascend.worker.model_runner import ( + ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata) def generate_attn_mask(max_seq_len: int, dtype=torch.float16): @@ -197,26 +199,52 @@ class AscendMetadata(AttentionMetadata): # FIXME: It is for flash attn. # Maximum sequence length among prefill batch. 0 if there are decoding + # Avoid mypy error + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + # requests only. max_prefill_seq_len: int # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) block_tables: Optional[torch.Tensor] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] = None - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] = None - # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] = None + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + # Self-attention prefill/decode metadata cache _cached_prefill_metadata: Optional["AscendMetadata"] = None _cached_decode_metadata: Optional["AscendMetadata"] = None @@ -254,10 +282,18 @@ class AscendMetadata(AttentionMetadata): or (self.encoder_seq_lens is not None)) # Compute some attn_metadata fields which default to None. + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) slot_mapping = (None if self.slot_mapping is None else self.slot_mapping[:self.num_prefill_tokens]) seq_lens = (None if self.seq_lens is None else self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) block_tables = (None if self.block_tables is None else self.block_tables[:self.num_prefills]) @@ -274,7 +310,11 @@ class AscendMetadata(AttentionMetadata): seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, @@ -302,6 +342,8 @@ class AscendMetadata(AttentionMetadata): self.slot_mapping[self.num_prefill_tokens:]) seq_lens = (None if self.seq_lens is None else self.seq_lens[self.num_prefills:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) block_tables = (None if self.block_tables is None else self.block_tables[self.num_prefills:]) seq_lens_tensor = (None if self.seq_lens_tensor is None else @@ -314,8 +356,19 @@ class AscendMetadata(AttentionMetadata): slot_mapping=slot_mapping, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, block_tables=block_tables, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, @@ -328,6 +381,98 @@ class AscendMetadata(AttentionMetadata): enable_kv_scales_calculation=False) return self._cached_decode_metadata + def advance_step(self, + model_input: "ModelInputForNPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + + if turn_prefills_into_decodes: + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + # TODO optimize these codes using ascendc just like flash attention backend using cuda + + # update input_tokens + sampled_token_ids_list = sampled_token_ids[: + num_queries].squeeze( # type: ignore + -1) + model_input.input_tokens[: + num_queries] = sampled_token_ids_list # type: ignore + + # get seq_lens and input_positions + seq_lens = self.seq_lens_tensor[:num_queries] + next_seq_lens = seq_lens + 1 + next_input_pos = next_seq_lens - 1 + + # update seq_lens and input_positions + self.seq_lens_tensor[:num_queries] = next_seq_lens + model_input.input_positions[: + num_queries] = next_input_pos # type: ignore + + # 计算 block index 和 offset + block_idx = next_input_pos // block_size + block_offset = next_input_pos % block_size + + current_block_table = self.block_tables.gather( + 1, block_idx.unsqueeze(-1)).squeeze(-1) + slot_num = current_block_table * block_size + block_offset + + # update slot_mapping + self.slot_mapping[:num_queries] = slot_num + class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): @@ -430,6 +575,11 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): device = self.runner.device max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) @@ -440,6 +590,9 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): self.input_builder.runner.device) else: self.attn_mask = None + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) block_tables = make_tensor_with_pad( self.block_tables, @@ -450,9 +603,17 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): assert max_query_len > 0, "query_lens: {}".format(query_lens) assert device is not None - + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32, device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) placeholder_index_maps = { modality: placeholder_map.index_map() for modality, placeholder_map in @@ -466,15 +627,19 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): return AscendMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=False, num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=self.num_decode_tokens, + num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, attn_mask=self.attn_mask, ) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index ce2dd90..727750a 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -105,7 +105,11 @@ class NPUPlatform(Platform): def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config if parallel_config.worker_cls == "auto": - parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker" + if vllm_config.scheduler_config.is_multi_step: + parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker" + else: + parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker" + cache_config = vllm_config.cache_config if cache_config and cache_config.block_size is None: cache_config.block_size = 128 diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index d12e72c..5d4c6be 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import torch from vllm.logger import init_logger logger = init_logger(__name__) @@ -33,3 +33,23 @@ def try_register_lib(lib_name: str, lib_info: str = ""): logger.info(lib_info) except Exception: pass + + +_current_stream = None + + +def current_stream() -> torch.npu.Stream: + """ + replace `torch.npu.current_stream()` with `vllm.utils.current_stream()`. + it turns out that `torch.npu.current_stream()` is quite expensive, + as it will construct a new stream object at each call. + here we patch `torch.npu.set_stream` to keep track of the current stream + directly, so that we can avoid calling `torch.npu.current_stream()`. + + """ + global _current_stream + if _current_stream is None: + # when this function is called before any stream is set, + # we return the default stream. + _current_stream = torch.npu.current_stream() + return _current_stream diff --git a/vllm_ascend/worker/multi_step_runner.py b/vllm_ascend/worker/multi_step_runner.py new file mode 100644 index 0000000..3ba39ad --- /dev/null +++ b/vllm_ascend/worker/multi_step_runner.py @@ -0,0 +1,674 @@ +import dataclasses +import functools +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, + SamplerOutput, + SamplingMetadata, get_logprobs, + get_pythonized_sample_results) +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SequenceGroupMetadata, SequenceOutput) +from vllm.worker.multi_step_model_runner import (ModelOutput, + PythonizationCache, + StatefulModelInput) + +from vllm_ascend.utils import current_stream +from vllm_ascend.worker.model_runner import ( + ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase) + +logger = init_logger(__name__) + + +@dataclass(frozen=False) +class NPUStatefulModelInput(StatefulModelInput): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def record_step_event(self, current_stream: torch.npu.Stream): + # record the event for the current step so that the next step can sync + # on it. We modulo by 2 to keep the events in a circular buffer and + # support any attn backends that may be supported in the future. ie + # Flashinfer would want two DecodeWrappers to overlap the CPU and NPU. + self.step_cuda_events[self.current_step & 1] = \ + torch.npu.Event(blocking=True) + self.step_cuda_events[self.current_step & 1].record(current_stream) + + +@dataclass(frozen=False) +class NPUModelOutput(ModelOutput): + + logprobs: Optional["torch.Tensor"] = None + + def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", + copy_stream: torch.npu.Stream, + pinned_sampled_token_buffer: torch.Tensor, + blocking: bool) -> bool: + """ + If blocking is set, will block until the forward pass for the output is + ready and pythonize the output. Upon completing Pythonization, erases + self.logprobs (note that a non-blocking call that is performed when + the sampler output is not yet ready, will not erase self.logprobs.) + """ + assert self.sampled_token_ids is not None + if not blocking and not self.sampler_output_ready_event.query(): + return False + + if blocking: + self.sampler_output_ready_event.synchronize() + with torch.npu.stream(copy_stream): + _pythonize_sampler_output(input_metadata, self.sampler_output, + pinned_sampled_token_buffer, + self.sampled_token_ids, self.logprobs, + self.pythonization_cache) + + # Erase the logprobs GPU-side tensor. + # Note that although _pythonize_sampler_output() runs in its + # own CUDA stream, nonetheless _pythonize_sampler_output() + # cannot return until Pythonization is complete; therefore + # we know that by the time the CPU reaches this point, + # `self.logprobs` is no longer needed. + self.logprobs = None + return True + + +class MultiStepModelNPURunner(NPUModelRunnerBase[NPUStatefulModelInput]): + # mypy: enable-error-code=type-var + + def __init__(self, base_model_runner: NPUModelRunnerBase, *args, **kwargs): + super().__init__(*args, **kwargs) + + # uses the base model runner to execute the model and wraps it with + # multi-step logic + self._base_model_runner: NPUModelRunnerBase = base_model_runner + + self.is_multi_step = self.scheduler_config.is_multi_step + self.pinned_sampled_token_ids: Optional[torch.Tensor] = None + + # Using the PythonizationCache in Pipeline-Parallel clobbers the + # SequenceOutput and CompletionSequenceGroupOutput object. + # When cache-reset happens at the last step of a multi-step + # execution, there may be other on-going single-step/multi-step + # executions. The current caching implementation does not check + # for this. + self.pythonization_cache = PythonizationCache() \ + if self.parallel_config.pipeline_parallel_size == 1 else None + + def get_model(self) -> nn.Module: + return self.model + + @functools.cached_property + def _copy_stream(self): + # used to copy tensors from NPU to CPU asynchronously + return torch.npu.Stream() + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> StatefulModelInput: + model_input = (NPUStatefulModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) + return model_input + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> StatefulModelInput: + frozen_model_input: ModelInputForNPUWithSamplingMetadata = \ + self._base_model_runner.prepare_model_input( + seq_group_metadata_list, + virtual_engine, + finished_requests_ids) + + assert frozen_model_input.query_lens is not None + assert frozen_model_input.seq_lens is not None + assert frozen_model_input.attn_metadata is not None + num_queries = len(frozen_model_input.query_lens) + num_seqs = len(frozen_model_input.seq_lens) + num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills + + model_input = NPUStatefulModelInput( + frozen_model_input=frozen_model_input, + num_seqs=num_seqs, + num_queries=num_queries, + num_single_step_prefills=num_single_step_prefills, + step_cuda_events=[torch.npu.Event(blocking=True)] * 2, + ) + + return model_input + + def _async_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: Callable): + # Proceed with pythonization and output_proc in order. + # Stop on the first one that fails to pythonize + output_proc_callback() + + cont = True + for step_num, model_output in enumerate(model_input.cached_outputs): + if not model_output.pythonized: + model_output.maybe_pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + if model_output.pythonized: + ctx = output_proc_callback.keywords["ctx"] # type: ignore + ctx.append_output( + outputs=[model_output.sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False, + is_first_step_output=step_num == 0) + + output_proc_callback() + else: + cont = False + + if not cont: + break + + def _final_process_outputs( + self, model_input: StatefulModelInput, + output_proc_callback: Optional[Callable]) -> List[SamplerOutput]: + assert model_input.frozen_model_input is not None + + has_async_callback = output_proc_callback is not None + + outputs = [] + for step_num, output in enumerate(model_input.cached_outputs): + is_last_step = step_num == len(model_input.cached_outputs) - 1 + + # For non-async case: + # -- We simply add the outputs + # For async case: + # -- Invoke callback, pythonize, add to callback queue and repeat + # -- For last output, just add to callback queue + if has_async_callback: + assert output_proc_callback is not None + + # Invoke callback before pythonize (to overlap with NPU) + output_proc_callback() + + # Pythonize + if not output.pythonized: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + # For non last step, add to callback queue to chain + # callbacks=>pythonize pairs (for NPU overlap) + if not is_last_step: + ctx = output_proc_callback.keywords[ # type: ignore + "ctx"] # type: ignore + ctx.append_output( + outputs=[output.sampler_output], + seq_group_metadata_list=ctx. + seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False, + is_first_step_output=step_num == 0) + else: + outputs.append(output.sampler_output) + else: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + outputs.append(output.sampler_output) + + return outputs + + @torch.inference_mode() + def execute_model( + self, + model_input: StatefulModelInput, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + """ + Execute the model for a single step and update multi-step + metadata + """ + assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1" + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + # path for warm up runs + if not model_input.is_multi_step: + return self._base_model_runner.execute_model( + frozen_model_input, kv_caches, intermediate_tensors, num_steps) + + # make sure we skip the sampler on the lask rank and only pythonize + # if CPU is ahead. + if self.is_driver_worker and get_pp_group().is_last_rank: + if self.pinned_sampled_token_ids is None: + self.pinned_sampled_token_ids = torch.zeros( + (self.scheduler_config.max_num_seqs, 1), + dtype=torch.long, + device="cpu", + pin_memory=True) + + self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( + True) + if frozen_model_input.sampling_metadata: + frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( + True) + + # some pre-execute model logic for multi-step: + # - if it's the first step, we need to reset the sampling tensors + # - if it's not the first step, we need to advance the step using the + # appended sampler output from last iteration + # - also maybe pythonize if CPU is ahead of NPU + + stream = current_stream() + if not model_input.is_first_multi_step: + # Explicitly block on the previous step's forward to make sure we + # don't clobber any NPU tensors still in use. + # This is not needed for flashattn backend, but for other attn + # backends such as flashinfer that performs extra CPU operations on + # input metadata we may need to synchronize any CPU operations that + # might clobber enqueued forwards. (prevents CPU from running too + # far ahead if needed) + model_input.wait_previous_step() + model_input = self._advance_step( + model_input, model_input.cached_outputs[-1].sampler_output) + + # frozen_model_input may have been updated + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + if model_input.base_output_proc_callback is None: + assert frozen_model_input is not None + model_input.base_output_proc_callback = \ + frozen_model_input.async_callback + + if frozen_model_input.async_callback is not None: + assert model_input.base_output_proc_callback is not None + async_callback = functools.partial( + self._async_process_outputs, + model_input=model_input, + output_proc_callback=model_input.base_output_proc_callback) + + model_input.frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=async_callback) + # Update the local instance + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + # Execute the model + output = self._base_model_runner.execute_model(frozen_model_input, + kv_caches, + intermediate_tensors, + num_steps=1) + + # record the event for the current step so that the next step can sync + model_input.record_step_event(stream) + + if get_pp_group().is_last_rank and self.is_driver_worker: + assert isinstance(output, list) + assert len( + output + ) == 1, "MultiStepModelRunner requires single-step base_models" + + # event for the pythonization so that we only pythonize if the + # tensors are ready. May be able to be combined with the step event + output_ready_event = torch.npu.Event() + output_ready_event.record(stream) + if self.parallel_config.pipeline_parallel_size > 1: + output[0].sampled_token_ids_cpu = output[ + 0].sampled_token_ids.cpu() + model_input.cached_outputs.append( + NPUModelOutput(output[0], output_ready_event, + output[0].sampled_token_ids, False, + output[0].logprobs, self.pythonization_cache)) + + # These NPU tensors are not required by multi-step; + # erase them to ensure they are not pythonized or + # transferred to CPU + output[0].sampled_token_ids = None + output[0].sampled_token_probs = None + output[0].logprobs = None + + # Pythonize the output if CPU is ahead and the previous step is + # ready. + if frozen_model_input.async_callback is None: + for model_output in model_input.cached_outputs: + model_output.maybe_pythonize(model_input, + self._copy_stream, + self.pinned_sampled_token_ids) + + model_input.current_step += 1 + + if not get_pp_group().is_last_rank: + # Should be IntermediateTensors + assert isinstance(output, IntermediateTensors) + return output + if not self.is_driver_worker: + return [] + + # Pythonize the output and block if needed since it is the last step + if model_input.is_last_step: + outputs = self._final_process_outputs( + model_input, model_input.base_output_proc_callback) + if self.pythonization_cache: + self.pythonization_cache.reset() + return outputs + + # should be [SamplerOutput] + return output + + def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata, + num_seqs: Optional[int], num_queries: int): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + assert seq_group.seq_len is None # Decode + assert seq_group.query_len is None # Decode + + def _advance_step(self, model_input: StatefulModelInput, + out: SamplerOutput) -> StatefulModelInput: + + model_input.maybe_advance_frozen_model_input(self.device, + self.pin_memory) + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.input_tokens is not None + assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs + assert frozen_model_input.attn_metadata is not None + + sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids + num_seqs = model_input.num_seqs + num_queries = model_input.num_queries + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + attn_metadata = frozen_model_input.attn_metadata + assert attn_metadata is not None + + turn_prefills_into_decodes: bool = model_input.current_step == 1 and \ + model_input.num_single_step_prefills != 0 + attn_metadata.advance_step( + frozen_model_input, + sampled_token_ids, + self.block_size, + num_seqs, + num_queries, + turn_prefills_into_decodes=turn_prefills_into_decodes) + + return model_input + + def load_model(self) -> None: + self._base_model_runner.load_model() + self.model_memory_usage = self._base_model_runner.model_memory_usage + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + return self._base_model_runner.save_sharded_state( + path, pattern, max_size) + + def save_tensorized_model(self, + tensorizer_config: TensorizerConfig) -> None: + return self._base_model_runner.save_tensorized_model(tensorizer_config) + + def profile_run(self) -> None: + return self._base_model_runner.profile_run() + + def remove_all_loras(self): + return self._base_model_runner.remove_all_loras() + + def capture_model(self, kv_caches: List[List]) -> None: + return self._base_model_runner.capture_model(kv_caches) + + @property + def vocab_size(self) -> int: + return self._base_model_runner.vocab_size + + +DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]], + Optional[List[SampleLogprobs]]] + + +def deferred_pythonize_logprobs( + output: SamplerOutput, + sampling_metadata: SamplingMetadata, + logprobs_tensor: Optional[torch.Tensor], +) -> DeferredLogprobsReturnType: + """Perform deferred logprob Pythonization. + + 1. Pythonize NPU-side sampler result tensors into CPU-side sampler result. + 2. Pythonize NPU-side logprobs tensor into CPU-side logprobs lists, + utilizing the Pythonized sampler result computed in step 1. + + These deferred computations are not required for single-step scheduling + or the `profile_run()` phase of multi-step scheduling. + + Args: + output: sampler output (under deferred Pythonization) + sampling_metadata + + Returns: + prompt_logprobs (CPU), sample_logprobs (CPU) + """ + + # - Deferred pythonization of sample result + sampler_result = get_pythonized_sample_results( + output.deferred_sample_results_args) + + # - Erase the NPU-side deferred sample_result + # computation args to ensure it is never + # pythonized or transferred to CPU + output.deferred_sample_results_args = None + + # - Deferred pythonization of logprobs + ( + prompt_logprobs, + sample_logprobs, + ) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result) + assert len(prompt_logprobs) == len(sampling_metadata.seq_groups) + assert len(sample_logprobs) == len(sampling_metadata.seq_groups) + + return prompt_logprobs, sample_logprobs + + +def _pythonize_sampler_output( + model_input: StatefulModelInput, + output: SamplerOutput, + pinned_sampled_token_buffer: torch.Tensor, + sampled_token_ids: torch.Tensor, + logprobs_tensor: Optional[torch.Tensor], + cache: Optional[PythonizationCache], +) -> None: + """ This function is only called when the output tensors are ready. + See :class:`ModelOutput`. + + Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place, + adding a Pythonized output data structure + (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`. + + Args: + model_input + output: sampler output + pinned_sampled_token_token_buffer: CPU-side pinned memory + (receives copy of + NPU-side token buffer.) + sampled_token_ids: NPU-side token buffer + logprobs_tensor: NPU-side tensor containing + logprobs computed during sampling + """ + + assert model_input.frozen_model_input is not None + + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input.sampling_metadata is not None + sampling_metadata = frozen_model_input.sampling_metadata + # samples generation should have been skipped + assert not output.outputs + + pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries] + + # We guarantee output tensors are ready, so it is safe to + # pythonize the sampler output & obtain CPU-side logprobs. + # + # However we should check whether logprobs pythonization may + # be skipped entirely, i.e. because no logprobs were requested + # or pythonization was not deferred. To that end, + # + # * `prompt_logprobs_are_requested_for_prefill` signals that + # there are *any* prefill-phase requests which specify that + # prompt logprobs should be returned. + # + # * `any_logprobs_are_requested` signals that there are any + # requests which (1) specify that sample logprobs should be + # returned, or (2) are in the prefill phase AND specify that + # prompt logprobs should be returned. + # + # Later on, these flags cause adjustments to the pythonization + # process to accommodate logprobs. + + seq_groups = sampling_metadata.seq_groups + prompt_logprobs_are_requested_for_prefill = any([ + sg.sampling_params.prompt_logprobs is not None and sg.is_prompt + for sg in seq_groups + ]) + any_logprobs_are_requested = ( + prompt_logprobs_are_requested_for_prefill + or any([sg.sampling_params.logprobs is not None for sg in seq_groups])) + + if prompt_logprobs_are_requested_for_prefill: + # CPU NPU sync, after gathering *only* sampled tokens (since + # requesting prompt logprobs leads `sampled_token_ids` to + # include prompt token ids in addition to sampled token ids.) + sample_idx_tensor = torch.tensor( + [sdx for sg in seq_groups for sdx in sg.sample_indices]) + pinned_buffer = pinned_buffer.copy_( + sampled_token_ids[sample_idx_tensor, :], non_blocking=False) + else: + # CPU NPU sync + pinned_buffer = pinned_buffer.copy_(sampled_token_ids, + non_blocking=False) + + # this will not block as the tensors are already on CPU + samples_list = pinned_buffer.tolist() + + skip_sampler_cpu_output = ( + frozen_model_input.sampling_metadata.skip_sampler_cpu_output) + + # *Don't* skip logprobs pythonization *if*: + # * Any requests require logprobs to be returned in this + # iteration AND + # * These requests are being scheduled in a fashion which + # defers pythonization (i.e. multi-step scheduling.) + do_pythonize_logprobs = (skip_sampler_cpu_output + and any_logprobs_are_requested) + ( + prompt_logprobs, + sample_logprobs, + ) = (deferred_pythonize_logprobs(output, sampling_metadata, + logprobs_tensor) + if do_pythonize_logprobs else (None, None)) + + for sgdx, (seq_group, + sample_result) in enumerate(zip(seq_groups, samples_list)): + # Reminder: Please update docs/source/features/compatibility_matrix.md + # If the feature combo become valid + # (Check for Guided Decoding) + if seq_group.sampling_params.logits_processors: + assert len(seq_group.sampling_params.logits_processors) == 0, ( + "Logits Processors are not supported in multi-step decoding") + + if do_pythonize_logprobs: + assert prompt_logprobs is not None + assert sample_logprobs is not None + + ( + group_prompt_logprobs, + group_sample_logprobs, + ) = ( # Utilize deferred pythonization results + prompt_logprobs[sgdx], + sample_logprobs[sgdx], + ) + elif any_logprobs_are_requested: + ( + group_prompt_logprobs, + group_sample_logprobs, + ) = ( + # profile_run: use already-computed logprobs + output.outputs[sgdx].prompt_logprobs, + [sample.logprobs for sample in output.outputs[sgdx].samples]) + + seq_ids = seq_group.seq_ids + next_token_ids = sample_result + parent_ids = [0] + seq_outputs: List[SequenceOutput] + + if cache is not None: + completion_seq_group_output: CompletionSequenceGroupOutput = \ + cache.cached_completion_seq_group_output.get_object() + completion_seq_group_output.samples.clear() + seq_outputs = completion_seq_group_output.samples + else: + seq_outputs = [] + + for tdx, (parent_id, + next_token_id) in enumerate(zip(parent_ids, next_token_ids)): + if cache is not None: + seq_output: SequenceOutput = cache.cached_seq_output.get_object( + ) + seq_output.parent_seq_id = seq_ids[parent_id] + seq_output.output_token = next_token_id + + if any_logprobs_are_requested: + seq_output.logprobs = group_sample_logprobs[tdx] + else: + logprobs = next(iter(seq_output.logprobs.values())) + seq_output.logprobs.clear() + + logprobs.logprob = float('inf') + logprobs.rank = None + logprobs.decoded_token = None + + seq_output.logprobs[next_token_id] = logprobs + + seq_outputs.append(seq_output) + + else: + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + (group_sample_logprobs[tdx] + if any_logprobs_are_requested else { + next_token_id: + Logprob(logprob=float('inf'), + rank=None, + decoded_token=None) + }))) + if cache is not None: + completion_seq_group_output.prompt_logprobs = \ + group_prompt_logprobs if any_logprobs_are_requested else None + output.outputs.append(completion_seq_group_output) + else: + output.outputs.append( + CompletionSequenceGroupOutput( + seq_outputs, (group_prompt_logprobs + if any_logprobs_are_requested else None))) + + assert len(output.outputs) > 0 diff --git a/vllm_ascend/worker/multi_step_worker.py b/vllm_ascend/worker/multi_step_worker.py new file mode 100644 index 0000000..ba83f6b --- /dev/null +++ b/vllm_ascend/worker/multi_step_worker.py @@ -0,0 +1,194 @@ +import dataclasses +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch +from vllm.distributed import broadcast_tensor_dict, get_pp_group +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.worker.model_runner_base import BroadcastableModelInput +from vllm.worker.multi_step_model_runner import StatefulModelInput + +from vllm_ascend.worker.multi_step_runner import MultiStepModelNPURunner +from vllm_ascend.worker.worker import NPUWorker, WorkerInput + + +@dataclass +class MultiStepState: + worker_input: WorkerInput + model_input: StatefulModelInput + + +class MultiStepWorker(NPUWorker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + base_model_runner = self.model_runner + # for multi-step model, wrap the model runner with MultiStepModelRunner + self.model_runner = MultiStepModelNPURunner( + base_model_runner, + vllm_config=base_model_runner.vllm_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=base_model_runner.is_driver_worker, + ) + + pipeline_parallel_size = self.parallel_config.pipeline_parallel_size + self.multi_step_states: List[ + Optional[MultiStepState]] = [None] * pipeline_parallel_size + self.temp_output = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: + """ + Get the driver input and broadcast it to other workers. + """ + assert self.is_driver_worker + virtual_engine = execute_model_req.virtual_engine + is_first_multi_step = execute_model_req.is_first_multi_step + if is_first_multi_step: + # on first step we prepare the worker input and model input normally + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: StatefulModelInput = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input.frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=execute_model_req.async_callback) + else: + # on subsequent steps we reuse the worker input and model input + multi_step_state = self.multi_step_states[virtual_engine] + worker_input = multi_step_state.worker_input + model_input = multi_step_state.model_input + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.attn_metadata is not None + # clear the cached metadata so that it can be recomputed on + # the workers. + frozen_model_input.attn_metadata._cached_prefill_metadata = None + frozen_model_input.attn_metadata._cached_decode_metadata = None + + model_input.is_first_multi_step = is_first_multi_step + model_input.is_last_step = execute_model_req.is_last_step + + if not is_first_multi_step: + # we broadcast the last sampled token ids to all TP workers so they + # can update their model input metadata in-place. + self._prepare_last_sampled_token_ids_for_tp_workers( + execute_model_req=execute_model_req, model_input=model_input) + + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + + # Retuning empty dict here to keep this compatible with + # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` + return model_input, worker_input, {} + + def _prepare_last_sampled_token_ids_for_tp_workers( + self, + execute_model_req: ExecuteModelRequest, + model_input: StatefulModelInput, + ) -> None: + """ + Prepare the last sampled token ids for TP workers. If it's the last + PP rank, then the last sampled token ids are already in the model_input. + If it is NOT the last PP rank, then we need to get the last sampled + token that is cached in the execute_model_req. + """ + if get_pp_group().is_last_rank: + assert model_input.cached_outputs[ + -1].sampler_output.sampled_token_ids is None + assert model_input.cached_outputs[-1].sampled_token_ids is not None + model_input.last_sampled_token_ids = model_input.cached_outputs[ + -1].sampled_token_ids + # free sampled token ids from the previous step if it has been + # pythonized. Cannot free the last sampled token ids because + # we need it for GPU advance_step. + for output in model_input.cached_outputs[:-1]: + if output.pythonized: + output.sampled_token_ids = None + else: + # otherwise we need to get the cached sampled token ids from the + # execute_model_req + assert execute_model_req.last_sampled_token_ids is not None + model_input.last_sampled_token_ids = ( + execute_model_req.last_sampled_token_ids.cuda()) + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + # free sampled token ids from the previous step. + # TODO(will) we could reuse the sampled token ids tensor from + # the previous step instead. + for output in model_input.cached_outputs[:-1]: + output.sampled_token_ids = None + assert model_input.cached_outputs[-1].sampled_token_ids is not None + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str, + torch.Tensor]]]: + """ + Depending on the current state of the request and multi step worker, + this method may skip the normal _prepare_model_input and + _prepare_worker_input methods and instead used cached values. + """ + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) + return None + + virtual_engine = execute_model_req.virtual_engine + (model_input, worker_input, + kwargs) = self._get_driver_input_and_broadcast(execute_model_req) + assert isinstance(model_input, StatefulModelInput) + if execute_model_req.is_first_multi_step: + # cache the worker input and model input for the next steps + self.multi_step_states[virtual_engine] = MultiStepState( + worker_input=worker_input, model_input=model_input) + # if TP workers + else: + broadcast_data = self._get_worker_input_from_broadcast() + # if the driver has sent an empty input, we should stop the worker + # loop + if broadcast_data is None: + return None + model_input, worker_input, kwargs = broadcast_data + assert isinstance(model_input, StatefulModelInput) + virtual_engine = worker_input.virtual_engine + if model_input.is_first_multi_step: + pass + # TODO(will) Can cache the worker input and model input for the + # next steps. See below for details + else: + # TODO(will) possible to also cache and reuse the cached worker + # input and model input. The idea is essentially the delta + # optimization for model_inputs. Where the TP workers can cache + # the model input states and we only broadcast the delta need + # for the next step (sampled_token_ids from the previous step) + + assert isinstance(model_input, StatefulModelInput) + # we need to update the last sampled token ids in the model + # input for the workers so that they can run inplace + # advance_step + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + assert model_input is not None + assert worker_input is not None + return model_input, worker_input, kwargs