From b005def0a5791075e0828a1ccb07fb6f11fe4598 Mon Sep 17 00:00:00 2001 From: Shanshan Shen <467638484@qq.com> Date: Wed, 16 Jul 2025 14:06:49 +0800 Subject: [PATCH] [Misc][V0 Deprecation] Remove Multi-Step Model Runner (#1820) ### What this PR does / why we need it? Remove multi-step model runner. This PR is a part of https://github.com/vllm-project/vllm-ascend/issues/1620. - vLLM version: v0.9.2 - vLLM main: https://github.com/vllm-project/vllm/commit/34cda778a091d4e1fd204cfde4a0f5e2b5616ac2 Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm_ascend/worker/multi_step_runner.py | 737 ------------------------ 1 file changed, 737 deletions(-) delete mode 100644 vllm_ascend/worker/multi_step_runner.py diff --git a/vllm_ascend/worker/multi_step_runner.py b/vllm_ascend/worker/multi_step_runner.py deleted file mode 100644 index 028bcd0..0000000 --- a/vllm_ascend/worker/multi_step_runner.py +++ /dev/null @@ -1,737 +0,0 @@ -import dataclasses -import functools -from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, - Union) - -import torch -from torch import nn -from vllm.distributed import get_pp_group -from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, - SamplerOutput, - SamplingMetadata, get_logprobs, - get_pythonized_sample_results) -from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SequenceGroupMetadata, SequenceOutput) -from vllm.worker.model_runner_base import ( - _init_attn_metadata_from_tensor_dict, - _init_frozen_model_input_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict) -from vllm.worker.multi_step_model_runner import (ModelOutput, - PythonizationCache, - StatefulModelInput) - -from vllm_ascend.utils import current_stream -from vllm_ascend.worker.model_runner import ( - ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase) - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - - -@dataclass(frozen=False) -class StatefulModelInputForNPU(StatefulModelInput): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def record_step_event(self, current_stream: torch.npu.Stream): - # record the event for the current step so that the next step can sync - # on it. We modulo by 2 to keep the events in a circular buffer and - # support any attn backends that may be supported in the future. ie - # Flashinfer would want two DecodeWrappers to overlap the CPU and NPU. - self.step_cuda_events[self.current_step & 1] = \ - torch.npu.Event(blocking=True) - self.step_cuda_events[self.current_step & 1].record(current_stream) - - # actual frozen model input dataclass passed to _base_model_runner - frozen_model_input: Optional[ModelInputForNPUWithSamplingMetadata] = None - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "StatefulModelInputForNPU": - tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - tensor_dict = _init_frozen_model_input_from_tensor_dict( - ModelInputForNPUWithSamplingMetadata, tensor_dict) - return cls(**tensor_dict) - - def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool): - """ - Advancing the datastructures of StatefulModelInput::frozen_model_input - is only required when prefills are scheduled with decodes to run in - multi-step. This advancement/correction is required to account for - the conversion of Prefills to Decodes after the first multi-step. - """ - if self.current_step != 1 or self.num_single_step_prefills == 0: - return - - assert self.frozen_model_input is not None - fmi = self.frozen_model_input - - # Truncate input_tokens - assert fmi.input_tokens is not None - assert fmi.input_tokens.shape[0] >= self.num_seqs - fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs] - - # Update frozen_model_input::input_positons. - assert fmi.input_positions is not None - assert fmi.input_positions.shape[0] >= self.num_seqs - fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self. - num_seqs] - - # Assert unsupported - # TODO Uncomment the following codes when NPU supported - # assert fmi.lora_mapping is None - # assert fmi.lora_requests is not None - # assert len(fmi.lora_requests) == 0 - # assert fmi.prompt_adapter_mapping is None - # assert fmi.prompt_adapter_requests is not None - # assert len(fmi.prompt_adapter_requests) == 0 - assert fmi.attn_metadata is not None - assert fmi.multi_modal_kwargs is not None - assert len(fmi.multi_modal_kwargs) == 0 - - self.frozen_model_input = dataclasses.replace( - self.frozen_model_input, - input_tokens=fmi_new_input_tokens, - input_positions=fmi_new_input_positions) - - self.maybe_advance_sampling_metadata(device, pin_memory) - - -@dataclass(frozen=False) -class NPUModelOutput(ModelOutput): - - logprobs: Optional["torch.Tensor"] = None - - def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", - copy_stream: torch.npu.Stream, - pinned_sampled_token_buffer: torch.Tensor, - blocking: bool) -> bool: - """ - If blocking is set, will block until the forward pass for the output is - ready and pythonize the output. Upon completing Pythonization, erases - self.logprobs (note that a non-blocking call that is performed when - the sampler output is not yet ready, will not erase self.logprobs.) - """ - assert self.sampled_token_ids is not None - if not blocking and not self.sampler_output_ready_event.query(): - return False - - if blocking: - self.sampler_output_ready_event.synchronize() - with torch.npu.stream(copy_stream): - _pythonize_sampler_output(input_metadata, self.sampler_output, - pinned_sampled_token_buffer, - self.sampled_token_ids, self.logprobs, - self.pythonization_cache) - - # Erase the logprobs GPU-side tensor. - # Note that although _pythonize_sampler_output() runs in its - # own CUDA stream, nonetheless _pythonize_sampler_output() - # cannot return until Pythonization is complete; therefore - # we know that by the time the CPU reaches this point, - # `self.logprobs` is no longer needed. - self.logprobs = None - return True - - -class MultiStepModelNPURunner(NPUModelRunnerBase[StatefulModelInputForNPU]): - # mypy: enable-error-code=type-var - - def __init__(self, base_model_runner: NPUModelRunnerBase, *args, **kwargs): - super().__init__(*args, **kwargs) - - # uses the base model runner to execute the model and wraps it with - # multi-step logic - self._base_model_runner: NPUModelRunnerBase = base_model_runner - - self.is_multi_step = self.scheduler_config.is_multi_step - self.pinned_sampled_token_ids: Optional[torch.Tensor] = None - - # Using the PythonizationCache in Pipeline-Parallel clobbers the - # SequenceOutput and CompletionSequenceGroupOutput object. - # When cache-reset happens at the last step of a multi-step - # execution, there may be other on-going single-step/multi-step - # executions. The current caching implementation does not check - # for this. - self.pythonization_cache = PythonizationCache() \ - if self.parallel_config.pipeline_parallel_size == 1 else None - - def get_model(self) -> nn.Module: - return self._base_model_runner.get_model() - - @functools.cached_property - def _copy_stream(self): - # used to copy tensors from NPU to CPU asynchronously - return torch.npu.Stream() - - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> StatefulModelInputForNPU: - model_input = (StatefulModelInputForNPU.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - )) - return model_input - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> StatefulModelInputForNPU: - frozen_model_input: ModelInputForNPUWithSamplingMetadata = \ - self._base_model_runner.prepare_model_input( - seq_group_metadata_list, - virtual_engine, - finished_requests_ids) - - assert frozen_model_input.query_lens is not None - assert frozen_model_input.seq_lens is not None - assert frozen_model_input.attn_metadata is not None - num_queries = len(frozen_model_input.query_lens) - num_seqs = len(frozen_model_input.seq_lens) - num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills - - model_input = StatefulModelInputForNPU( - frozen_model_input=frozen_model_input, - num_seqs=num_seqs, - num_queries=num_queries, - num_single_step_prefills=num_single_step_prefills, - step_cuda_events=[torch.npu.Event(blocking=True)] * 2, - ) - - return model_input - - def _async_process_outputs(self, model_input: StatefulModelInputForNPU, - output_proc_callback: Callable): - # Proceed with pythonization and output_proc in order. - # Stop on the first one that fails to pythonize - output_proc_callback() - - cont = True - for step_num, model_output in enumerate(model_input.cached_outputs): - if not model_output.pythonized: - model_output.maybe_pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) - if model_output.pythonized: - ctx = output_proc_callback.keywords["ctx"] # type: ignore - ctx.append_output( - outputs=[model_output.sampler_output], - seq_group_metadata_list=ctx.seq_group_metadata_list, - scheduler_outputs=ctx.scheduler_outputs, - is_async=False, - is_last_step=False, - is_first_step_output=step_num == 0) - - output_proc_callback() - else: - cont = False - - if not cont: - break - - def _final_process_outputs( - self, model_input: StatefulModelInputForNPU, - output_proc_callback: Optional[Callable]) -> List[SamplerOutput]: - assert model_input.frozen_model_input is not None - - has_async_callback = output_proc_callback is not None - - outputs = [] - for step_num, output in enumerate(model_input.cached_outputs): - is_last_step = step_num == len(model_input.cached_outputs) - 1 - - # For non-async case: - # -- We simply add the outputs - # For async case: - # -- Invoke callback, pythonize, add to callback queue and repeat - # -- For last output, just add to callback queue - if has_async_callback: - assert output_proc_callback is not None - - # Invoke callback before pythonize (to overlap with NPU) - output_proc_callback() - - # Pythonize - if not output.pythonized: - output.pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) - - # For non last step, add to callback queue to chain - # callbacks=>pythonize pairs (for NPU overlap) - if not is_last_step: - ctx = output_proc_callback.keywords[ # type: ignore - "ctx"] # type: ignore - ctx.append_output( - outputs=[output.sampler_output], - seq_group_metadata_list=ctx. - seq_group_metadata_list, - scheduler_outputs=ctx.scheduler_outputs, - is_async=False, - is_last_step=False, - is_first_step_output=step_num == 0) - else: - outputs.append(output.sampler_output) - else: - output.pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) - outputs.append(output.sampler_output) - - return outputs - - @torch.inference_mode() - def execute_model( - self, - model_input: StatefulModelInputForNPU, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: - """ - Execute the model for a single step and update multi-step - metadata - """ - assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1" - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None - - # path for warm up runs - if not model_input.is_multi_step: - return self._base_model_runner.execute_model( - frozen_model_input, kv_caches, intermediate_tensors, num_steps) - - # make sure we skip the sampler on the lask rank and only pythonize - # if CPU is ahead. - if self.is_driver_worker and get_pp_group().is_last_rank: - if self.pinned_sampled_token_ids is None: - self.pinned_sampled_token_ids = torch.zeros( - (self.scheduler_config.max_num_seqs, 1), - dtype=torch.long, - device="cpu", - pin_memory=True) - self._base_model_runner.sampler.include_gpu_probs_tensor = True - if frozen_model_input.sampling_metadata: - frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( - True) - - # some pre-execute model logic for multi-step: - # - if it's the first step, we need to reset the sampling tensors - # - if it's not the first step, we need to advance the step using the - # appended sampler output from last iteration - # - also maybe pythonize if CPU is ahead of NPU - - stream = current_stream() - if not model_input.is_first_multi_step: - # Explicitly block on the previous step's forward to make sure we - # don't clobber any NPU tensors still in use. - # This is not needed for flashattn backend, but for other attn - # backends such as flashinfer that performs extra CPU operations on - # input metadata we may need to synchronize any CPU operations that - # might clobber enqueued forwards. (prevents CPU from running too - # far ahead if needed) - model_input.wait_previous_step() - model_input = self._advance_step( - model_input, model_input.cached_outputs[-1].sampler_output) - - # frozen_model_input may have been updated - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None - - if model_input.base_output_proc_callback is None: - assert frozen_model_input is not None - model_input.base_output_proc_callback = \ - frozen_model_input.async_callback - - if frozen_model_input.async_callback is not None: - assert model_input.base_output_proc_callback is not None - async_callback = functools.partial( - self._async_process_outputs, - model_input=model_input, - output_proc_callback=model_input.base_output_proc_callback) - - model_input.frozen_model_input = dataclasses.replace( # type: ignore - model_input.frozen_model_input, - async_callback=async_callback) - # Update the local instance - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None - - # Execute the model - output = self._base_model_runner.execute_model(frozen_model_input, - kv_caches, - intermediate_tensors, - num_steps=1) - - # record the event for the current step so that the next step can sync - model_input.record_step_event(stream) - - if get_pp_group().is_last_rank and self.is_driver_worker: - assert isinstance(output, list) - assert len( - output - ) == 1, "MultiStepModelRunner requires single-step base_models" - - # event for the pythonization so that we only pythonize if the - # tensors are ready. May be able to be combined with the step event - output_ready_event = torch.npu.Event() - output_ready_event.record(stream) - if self.parallel_config.pipeline_parallel_size > 1: - output[0].sampled_token_ids_cpu = output[ - 0].sampled_token_ids.cpu() - model_input.cached_outputs.append( - NPUModelOutput(output[0], output_ready_event, - output[0].sampled_token_ids, False, - output[0].logprobs, self.pythonization_cache)) - - # These NPU tensors are not required by multi-step; - # erase them to ensure they are not pythonized or - # transferred to CPU - output[0].sampled_token_ids = None - output[0].sampled_token_probs = None - output[0].logprobs = None - - # Pythonize the output if CPU is ahead and the previous step is - # ready. - if frozen_model_input.async_callback is None: - for model_output in model_input.cached_outputs: - model_output.maybe_pythonize(model_input, - self._copy_stream, - self.pinned_sampled_token_ids) - - model_input.current_step += 1 - - if not get_pp_group().is_last_rank: - # Should be IntermediateTensors - assert isinstance(output, IntermediateTensors) - return output - if not self.is_driver_worker: - return [] - - # Pythonize the output and block if needed since it is the last step - if model_input.is_last_step: - outputs = self._final_process_outputs( - model_input, model_input.base_output_proc_callback) - if self.pythonization_cache: - self.pythonization_cache.reset() - return outputs - - # should be [SamplerOutput] - return output - - def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata, - num_seqs: Optional[int], num_queries: int): - - assert sampling_metadata.num_prompts == 0 - assert len(sampling_metadata.seq_groups) == num_queries - assert sampling_metadata.selected_token_indices.shape == ( - num_queries, ) - # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 - - # Verify that all sequences are decodes - for i in range(num_queries): - seq_group = sampling_metadata.seq_groups[i] - - assert seq_group.is_prompt is False # No prompt - assert seq_group.prompt_logprob_indices == [] # No prompt - assert seq_group.sample_indices == [i] # Simple - assert seq_group.seq_len is None # Decode - assert seq_group.query_len is None # Decode - - def _advance_step(self, model_input: StatefulModelInputForNPU, - out: SamplerOutput) -> StatefulModelInputForNPU: - - model_input.maybe_advance_frozen_model_input(self.device, - self.pin_memory) - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None - assert frozen_model_input.input_tokens is not None - assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs - assert frozen_model_input.attn_metadata is not None - - sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids - num_seqs = model_input.num_seqs - num_queries = model_input.num_queries - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None - attn_metadata = frozen_model_input.attn_metadata - assert attn_metadata is not None - - turn_prefills_into_decodes: bool = model_input.current_step == 1 and \ - model_input.num_single_step_prefills != 0 - attn_metadata.advance_step( - frozen_model_input, - sampled_token_ids, - self.block_size, - num_seqs, - num_queries, - turn_prefills_into_decodes=turn_prefills_into_decodes) - - return model_input - - def load_model(self) -> None: - self._base_model_runner.load_model() - self.model_memory_usage = self._base_model_runner.model_memory_usage - - def save_sharded_state( - self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, - ) -> None: - return self._base_model_runner.save_sharded_state( - path, pattern, max_size) - - def save_tensorized_model(self, - tensorizer_config: TensorizerConfig) -> None: - return self._base_model_runner.save_tensorized_model(tensorizer_config) - - def profile_run(self) -> None: - return self._base_model_runner.profile_run() - - def remove_all_loras(self): - return self._base_model_runner.remove_all_loras() - - def capture_model(self, kv_caches: List[List]) -> None: - return self._base_model_runner.capture_model(kv_caches) - - @property - def vocab_size(self) -> int: - return self._base_model_runner.vocab_size - - -DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]], - Optional[List[SampleLogprobs]]] - - -def deferred_pythonize_logprobs( - output: SamplerOutput, - sampling_metadata: SamplingMetadata, - logprobs_tensor: Optional[torch.Tensor], -) -> DeferredLogprobsReturnType: - """Perform deferred logprob Pythonization. - - 1. Pythonize NPU-side sampler result tensors into CPU-side sampler result. - 2. Pythonize NPU-side logprobs tensor into CPU-side logprobs lists, - utilizing the Pythonized sampler result computed in step 1. - - These deferred computations are not required for single-step scheduling - or the `profile_run()` phase of multi-step scheduling. - - Args: - output: sampler output (under deferred Pythonization) - sampling_metadata - - Returns: - prompt_logprobs (CPU), sample_logprobs (CPU) - """ - - # - Deferred pythonization of sample result - sampler_result = get_pythonized_sample_results( - output.deferred_sample_results_args) - - # - Erase the NPU-side deferred sample_result - # computation args to ensure it is never - # pythonized or transferred to CPU - output.deferred_sample_results_args = None - - # - Deferred pythonization of logprobs - ( - prompt_logprobs, - sample_logprobs, - ) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result) - assert len(prompt_logprobs) == len(sampling_metadata.seq_groups) - assert len(sample_logprobs) == len(sampling_metadata.seq_groups) - - return prompt_logprobs, sample_logprobs - - -def _pythonize_sampler_output( - model_input: StatefulModelInputForNPU, - output: SamplerOutput, - pinned_sampled_token_buffer: torch.Tensor, - sampled_token_ids: torch.Tensor, - logprobs_tensor: Optional[torch.Tensor], - cache: Optional[PythonizationCache], -) -> None: - """ This function is only called when the output tensors are ready. - See :class:`ModelOutput`. - - Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place, - adding a Pythonized output data structure - (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`. - - Args: - model_input - output: sampler output - pinned_sampled_token_token_buffer: CPU-side pinned memory - (receives copy of - NPU-side token buffer.) - sampled_token_ids: NPU-side token buffer - logprobs_tensor: NPU-side tensor containing - logprobs computed during sampling - """ - - assert model_input.frozen_model_input is not None - - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input.sampling_metadata is not None - sampling_metadata = frozen_model_input.sampling_metadata - # samples generation should have been skipped - assert not output.outputs - - pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries] - - # We guarantee output tensors are ready, so it is safe to - # pythonize the sampler output & obtain CPU-side logprobs. - # - # However we should check whether logprobs pythonization may - # be skipped entirely, i.e. because no logprobs were requested - # or pythonization was not deferred. To that end, - # - # * `prompt_logprobs_are_requested_for_prefill` signals that - # there are *any* prefill-phase requests which specify that - # prompt logprobs should be returned. - # - # * `any_logprobs_are_requested` signals that there are any - # requests which (1) specify that sample logprobs should be - # returned, or (2) are in the prefill phase AND specify that - # prompt logprobs should be returned. - # - # Later on, these flags cause adjustments to the pythonization - # process to accommodate logprobs. - - seq_groups = sampling_metadata.seq_groups - prompt_logprobs_are_requested_for_prefill = any([ - sg.sampling_params.prompt_logprobs is not None and sg.is_prompt - for sg in seq_groups - ]) - any_logprobs_are_requested = ( - prompt_logprobs_are_requested_for_prefill - or any([sg.sampling_params.logprobs is not None for sg in seq_groups])) - - if prompt_logprobs_are_requested_for_prefill: - # CPU NPU sync, after gathering *only* sampled tokens (since - # requesting prompt logprobs leads `sampled_token_ids` to - # include prompt token ids in addition to sampled token ids.) - sample_idx_tensor = torch.tensor( - [sdx for sg in seq_groups for sdx in sg.sample_indices]) - pinned_buffer = pinned_buffer.copy_( - sampled_token_ids[sample_idx_tensor, :], non_blocking=False) - else: - # CPU NPU sync - pinned_buffer = pinned_buffer.copy_(sampled_token_ids, - non_blocking=False) - - # this will not block as the tensors are already on CPU - samples_list = pinned_buffer.tolist() - - skip_sampler_cpu_output = ( - frozen_model_input.sampling_metadata.skip_sampler_cpu_output) - - # *Don't* skip logprobs pythonization *if*: - # * Any requests require logprobs to be returned in this - # iteration AND - # * These requests are being scheduled in a fashion which - # defers pythonization (i.e. multi-step scheduling.) - do_pythonize_logprobs = (skip_sampler_cpu_output - and any_logprobs_are_requested) - ( - prompt_logprobs, - sample_logprobs, - ) = (deferred_pythonize_logprobs(output, sampling_metadata, - logprobs_tensor) - if do_pythonize_logprobs else (None, None)) - - for sgdx, (seq_group, - sample_result) in enumerate(zip(seq_groups, samples_list)): - # Reminder: Please update docs/source/features/compatibility_matrix.md - # If the feature combo become valid - # (Check for Guided Decoding) - if seq_group.sampling_params.logits_processors: - assert len(seq_group.sampling_params.logits_processors) == 0, ( - "Logits Processors are not supported in multi-step decoding") - - if do_pythonize_logprobs: - assert prompt_logprobs is not None - assert sample_logprobs is not None - - ( - group_prompt_logprobs, - group_sample_logprobs, - ) = ( # Utilize deferred pythonization results - prompt_logprobs[sgdx], - sample_logprobs[sgdx], - ) - elif any_logprobs_are_requested: - ( - group_prompt_logprobs, - group_sample_logprobs, - ) = ( - # profile_run: use already-computed logprobs - output.outputs[sgdx].prompt_logprobs, - [sample.logprobs for sample in output.outputs[sgdx].samples]) - - seq_ids = seq_group.seq_ids - next_token_ids = sample_result - parent_ids = [0] - seq_outputs: List[SequenceOutput] - - if cache is not None: - completion_seq_group_output: CompletionSequenceGroupOutput = \ - cache.cached_completion_seq_group_output.get_object() - completion_seq_group_output.samples.clear() - seq_outputs = completion_seq_group_output.samples - else: - seq_outputs = [] - - for tdx, (parent_id, - next_token_id) in enumerate(zip(parent_ids, next_token_ids)): - if cache is not None: - seq_output: SequenceOutput = cache.cached_seq_output.get_object( - ) - seq_output.parent_seq_id = seq_ids[parent_id] - seq_output.output_token = next_token_id - - if any_logprobs_are_requested: - seq_output.logprobs = group_sample_logprobs[tdx] - else: - logprobs = next(iter(seq_output.logprobs.values())) - seq_output.logprobs.clear() - - logprobs.logprob = float('inf') - logprobs.rank = None - logprobs.decoded_token = None - - seq_output.logprobs[next_token_id] = logprobs - - seq_outputs.append(seq_output) - - else: - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, - (group_sample_logprobs[tdx] - if any_logprobs_are_requested else { - next_token_id: - Logprob(logprob=float('inf'), - rank=None, - decoded_token=None) - }))) - if cache is not None: - completion_seq_group_output.prompt_logprobs = \ - group_prompt_logprobs if any_logprobs_are_requested else None - output.outputs.append(completion_seq_group_output) - else: - output.outputs.append( - CompletionSequenceGroupOutput( - seq_outputs, (group_prompt_logprobs - if any_logprobs_are_requested else None))) - - assert len(output.outputs) > 0