Files
xc-llm-ascend/vllm_ascend/worker/multi_step_runner.py
Wan_Danfeng 5cf9ff18e9 [Performance]: Custom AscendC Kernel of Multi-Step Prepare Input (#814)
### What this PR does / why we need it?

- According to https://github.com/vllm-project/vllm-ascend/issues/807,
we pull request for customer ascendc kernel of multi-step.
- also a bug we found in multi_step_runner.py is fixed when we use
multi-step on V0 Engine.


### Does this PR introduce _any_ user-facing change?

no user-facing change


### How was this patch tested?
we add Unit Test file and offline inference file to test the custom
ascendc kernel. See test/ops/test_multi_step.py and
examples/offline_multi_step.py

---------

Signed-off-by: wan_danfeng <wonderful199082@126.com>
2025-05-20 09:31:30 +08:00

738 lines
31 KiB
Python

import dataclasses
import functools
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)
import torch
from torch import nn
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
SamplerOutput,
SamplingMetadata, get_logprobs,
get_pythonized_sample_results)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.worker.model_runner_base import (
_init_attn_metadata_from_tensor_dict,
_init_frozen_model_input_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
from vllm.worker.multi_step_model_runner import (ModelOutput,
PythonizationCache,
StatefulModelInput)
from vllm_ascend.utils import current_stream
from vllm_ascend.worker.model_runner import (
ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@dataclass(frozen=False)
class StatefulModelInputForNPU(StatefulModelInput):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def record_step_event(self, current_stream: torch.npu.Stream):
# record the event for the current step so that the next step can sync
# on it. We modulo by 2 to keep the events in a circular buffer and
# support any attn backends that may be supported in the future. ie
# Flashinfer would want two DecodeWrappers to overlap the CPU and NPU.
self.step_cuda_events[self.current_step & 1] = \
torch.npu.Event(blocking=True)
self.step_cuda_events[self.current_step & 1].record(current_stream)
# actual frozen model input dataclass passed to _base_model_runner
frozen_model_input: Optional[ModelInputForNPUWithSamplingMetadata] = None
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "StatefulModelInputForNPU":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
tensor_dict = _init_frozen_model_input_from_tensor_dict(
ModelInputForNPUWithSamplingMetadata, tensor_dict)
return cls(**tensor_dict)
def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool):
"""
Advancing the datastructures of StatefulModelInput::frozen_model_input
is only required when prefills are scheduled with decodes to run in
multi-step. This advancement/correction is required to account for
the conversion of Prefills to Decodes after the first multi-step.
"""
if self.current_step != 1 or self.num_single_step_prefills == 0:
return
assert self.frozen_model_input is not None
fmi = self.frozen_model_input
# Truncate input_tokens
assert fmi.input_tokens is not None
assert fmi.input_tokens.shape[0] >= self.num_seqs
fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs]
# Update frozen_model_input::input_positons.
assert fmi.input_positions is not None
assert fmi.input_positions.shape[0] >= self.num_seqs
fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self.
num_seqs]
# Assert unsupported
# TODO Uncomment the following codes when NPU supported
# assert fmi.lora_mapping is None
# assert fmi.lora_requests is not None
# assert len(fmi.lora_requests) == 0
# assert fmi.prompt_adapter_mapping is None
# assert fmi.prompt_adapter_requests is not None
# assert len(fmi.prompt_adapter_requests) == 0
assert fmi.attn_metadata is not None
assert fmi.multi_modal_kwargs is not None
assert len(fmi.multi_modal_kwargs) == 0
self.frozen_model_input = dataclasses.replace(
self.frozen_model_input,
input_tokens=fmi_new_input_tokens,
input_positions=fmi_new_input_positions)
self.maybe_advance_sampling_metadata(device, pin_memory)
@dataclass(frozen=False)
class NPUModelOutput(ModelOutput):
logprobs: Optional["torch.Tensor"] = None
def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput",
copy_stream: torch.npu.Stream,
pinned_sampled_token_buffer: torch.Tensor,
blocking: bool) -> bool:
"""
If blocking is set, will block until the forward pass for the output is
ready and pythonize the output. Upon completing Pythonization, erases
self.logprobs (note that a non-blocking call that is performed when
the sampler output is not yet ready, will not erase self.logprobs.)
"""
assert self.sampled_token_ids is not None
if not blocking and not self.sampler_output_ready_event.query():
return False
if blocking:
self.sampler_output_ready_event.synchronize()
with torch.npu.stream(copy_stream):
_pythonize_sampler_output(input_metadata, self.sampler_output,
pinned_sampled_token_buffer,
self.sampled_token_ids, self.logprobs,
self.pythonization_cache)
# Erase the logprobs GPU-side tensor.
# Note that although _pythonize_sampler_output() runs in its
# own CUDA stream, nonetheless _pythonize_sampler_output()
# cannot return until Pythonization is complete; therefore
# we know that by the time the CPU reaches this point,
# `self.logprobs` is no longer needed.
self.logprobs = None
return True
class MultiStepModelNPURunner(NPUModelRunnerBase[StatefulModelInputForNPU]):
# mypy: enable-error-code=type-var
def __init__(self, base_model_runner: NPUModelRunnerBase, *args, **kwargs):
super().__init__(*args, **kwargs)
# uses the base model runner to execute the model and wraps it with
# multi-step logic
self._base_model_runner: NPUModelRunnerBase = base_model_runner
self.is_multi_step = self.scheduler_config.is_multi_step
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
# Using the PythonizationCache in Pipeline-Parallel clobbers the
# SequenceOutput and CompletionSequenceGroupOutput object.
# When cache-reset happens at the last step of a multi-step
# execution, there may be other on-going single-step/multi-step
# executions. The current caching implementation does not check
# for this.
self.pythonization_cache = PythonizationCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None
def get_model(self) -> nn.Module:
return self._base_model_runner.get_model()
@functools.cached_property
def _copy_stream(self):
# used to copy tensors from NPU to CPU asynchronously
return torch.npu.Stream()
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> StatefulModelInputForNPU:
model_input = (StatefulModelInputForNPU.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
return model_input
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> StatefulModelInputForNPU:
frozen_model_input: ModelInputForNPUWithSamplingMetadata = \
self._base_model_runner.prepare_model_input(
seq_group_metadata_list,
virtual_engine,
finished_requests_ids)
assert frozen_model_input.query_lens is not None
assert frozen_model_input.seq_lens is not None
assert frozen_model_input.attn_metadata is not None
num_queries = len(frozen_model_input.query_lens)
num_seqs = len(frozen_model_input.seq_lens)
num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills
model_input = StatefulModelInputForNPU(
frozen_model_input=frozen_model_input,
num_seqs=num_seqs,
num_queries=num_queries,
num_single_step_prefills=num_single_step_prefills,
step_cuda_events=[torch.npu.Event(blocking=True)] * 2,
)
return model_input
def _async_process_outputs(self, model_input: StatefulModelInputForNPU,
output_proc_callback: Callable):
# Proceed with pythonization and output_proc in order.
# Stop on the first one that fails to pythonize
output_proc_callback()
cont = True
for step_num, model_output in enumerate(model_input.cached_outputs):
if not model_output.pythonized:
model_output.maybe_pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
if model_output.pythonized:
ctx = output_proc_callback.keywords["ctx"] # type: ignore
ctx.append_output(
outputs=[model_output.sampler_output],
seq_group_metadata_list=ctx.seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False,
is_first_step_output=step_num == 0)
output_proc_callback()
else:
cont = False
if not cont:
break
def _final_process_outputs(
self, model_input: StatefulModelInputForNPU,
output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
assert model_input.frozen_model_input is not None
has_async_callback = output_proc_callback is not None
outputs = []
for step_num, output in enumerate(model_input.cached_outputs):
is_last_step = step_num == len(model_input.cached_outputs) - 1
# For non-async case:
# -- We simply add the outputs
# For async case:
# -- Invoke callback, pythonize, add to callback queue and repeat
# -- For last output, just add to callback queue
if has_async_callback:
assert output_proc_callback is not None
# Invoke callback before pythonize (to overlap with NPU)
output_proc_callback()
# Pythonize
if not output.pythonized:
output.pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
# For non last step, add to callback queue to chain
# callbacks=>pythonize pairs (for NPU overlap)
if not is_last_step:
ctx = output_proc_callback.keywords[ # type: ignore
"ctx"] # type: ignore
ctx.append_output(
outputs=[output.sampler_output],
seq_group_metadata_list=ctx.
seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False,
is_first_step_output=step_num == 0)
else:
outputs.append(output.sampler_output)
else:
output.pythonize(model_input, self._copy_stream,
self.pinned_sampled_token_ids)
outputs.append(output.sampler_output)
return outputs
@torch.inference_mode()
def execute_model(
self,
model_input: StatefulModelInputForNPU,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
"""
Execute the model for a single step and update multi-step
metadata
"""
assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
# path for warm up runs
if not model_input.is_multi_step:
return self._base_model_runner.execute_model(
frozen_model_input, kv_caches, intermediate_tensors, num_steps)
# make sure we skip the sampler on the lask rank and only pythonize
# if CPU is ahead.
if self.is_driver_worker and get_pp_group().is_last_rank:
if self.pinned_sampled_token_ids is None:
self.pinned_sampled_token_ids = torch.zeros(
(self.scheduler_config.max_num_seqs, 1),
dtype=torch.long,
device="cpu",
pin_memory=True)
self._base_model_runner.sampler.include_gpu_probs_tensor = True
if frozen_model_input.sampling_metadata:
frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
True)
# some pre-execute model logic for multi-step:
# - if it's the first step, we need to reset the sampling tensors
# - if it's not the first step, we need to advance the step using the
# appended sampler output from last iteration
# - also maybe pythonize if CPU is ahead of NPU
stream = current_stream()
if not model_input.is_first_multi_step:
# Explicitly block on the previous step's forward to make sure we
# don't clobber any NPU tensors still in use.
# This is not needed for flashattn backend, but for other attn
# backends such as flashinfer that performs extra CPU operations on
# input metadata we may need to synchronize any CPU operations that
# might clobber enqueued forwards. (prevents CPU from running too
# far ahead if needed)
model_input.wait_previous_step()
model_input = self._advance_step(
model_input, model_input.cached_outputs[-1].sampler_output)
# frozen_model_input may have been updated
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
if model_input.base_output_proc_callback is None:
assert frozen_model_input is not None
model_input.base_output_proc_callback = \
frozen_model_input.async_callback
if frozen_model_input.async_callback is not None:
assert model_input.base_output_proc_callback is not None
async_callback = functools.partial(
self._async_process_outputs,
model_input=model_input,
output_proc_callback=model_input.base_output_proc_callback)
model_input.frozen_model_input = dataclasses.replace( # type: ignore
model_input.frozen_model_input,
async_callback=async_callback)
# Update the local instance
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
# Execute the model
output = self._base_model_runner.execute_model(frozen_model_input,
kv_caches,
intermediate_tensors,
num_steps=1)
# record the event for the current step so that the next step can sync
model_input.record_step_event(stream)
if get_pp_group().is_last_rank and self.is_driver_worker:
assert isinstance(output, list)
assert len(
output
) == 1, "MultiStepModelRunner requires single-step base_models"
# event for the pythonization so that we only pythonize if the
# tensors are ready. May be able to be combined with the step event
output_ready_event = torch.npu.Event()
output_ready_event.record(stream)
if self.parallel_config.pipeline_parallel_size > 1:
output[0].sampled_token_ids_cpu = output[
0].sampled_token_ids.cpu()
model_input.cached_outputs.append(
NPUModelOutput(output[0], output_ready_event,
output[0].sampled_token_ids, False,
output[0].logprobs, self.pythonization_cache))
# These NPU tensors are not required by multi-step;
# erase them to ensure they are not pythonized or
# transferred to CPU
output[0].sampled_token_ids = None
output[0].sampled_token_probs = None
output[0].logprobs = None
# Pythonize the output if CPU is ahead and the previous step is
# ready.
if frozen_model_input.async_callback is None:
for model_output in model_input.cached_outputs:
model_output.maybe_pythonize(model_input,
self._copy_stream,
self.pinned_sampled_token_ids)
model_input.current_step += 1
if not get_pp_group().is_last_rank:
# Should be IntermediateTensors
assert isinstance(output, IntermediateTensors)
return output
if not self.is_driver_worker:
return []
# Pythonize the output and block if needed since it is the last step
if model_input.is_last_step:
outputs = self._final_process_outputs(
model_input, model_input.base_output_proc_callback)
if self.pythonization_cache:
self.pythonization_cache.reset()
return outputs
# should be [SamplerOutput]
return output
def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
num_seqs: Optional[int], num_queries: int):
assert sampling_metadata.num_prompts == 0
assert len(sampling_metadata.seq_groups) == num_queries
assert sampling_metadata.selected_token_indices.shape == (
num_queries, )
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
# Verify that all sequences are decodes
for i in range(num_queries):
seq_group = sampling_metadata.seq_groups[i]
assert seq_group.is_prompt is False # No prompt
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple
assert seq_group.seq_len is None # Decode
assert seq_group.query_len is None # Decode
def _advance_step(self, model_input: StatefulModelInputForNPU,
out: SamplerOutput) -> StatefulModelInputForNPU:
model_input.maybe_advance_frozen_model_input(self.device,
self.pin_memory)
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
assert frozen_model_input.input_tokens is not None
assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs
assert frozen_model_input.attn_metadata is not None
sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
num_seqs = model_input.num_seqs
num_queries = model_input.num_queries
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
attn_metadata = frozen_model_input.attn_metadata
assert attn_metadata is not None
turn_prefills_into_decodes: bool = model_input.current_step == 1 and \
model_input.num_single_step_prefills != 0
attn_metadata.advance_step(
frozen_model_input,
sampled_token_ids,
self.block_size,
num_seqs,
num_queries,
turn_prefills_into_decodes=turn_prefills_into_decodes)
return model_input
def load_model(self) -> None:
self._base_model_runner.load_model()
self.model_memory_usage = self._base_model_runner.model_memory_usage
def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
return self._base_model_runner.save_sharded_state(
path, pattern, max_size)
def save_tensorized_model(self,
tensorizer_config: TensorizerConfig) -> None:
return self._base_model_runner.save_tensorized_model(tensorizer_config)
def profile_run(self) -> None:
return self._base_model_runner.profile_run()
def remove_all_loras(self):
return self._base_model_runner.remove_all_loras()
def capture_model(self, kv_caches: List[List]) -> None:
return self._base_model_runner.capture_model(kv_caches)
@property
def vocab_size(self) -> int:
return self._base_model_runner.vocab_size
DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
Optional[List[SampleLogprobs]]]
def deferred_pythonize_logprobs(
output: SamplerOutput,
sampling_metadata: SamplingMetadata,
logprobs_tensor: Optional[torch.Tensor],
) -> DeferredLogprobsReturnType:
"""Perform deferred logprob Pythonization.
1. Pythonize NPU-side sampler result tensors into CPU-side sampler result.
2. Pythonize NPU-side logprobs tensor into CPU-side logprobs lists,
utilizing the Pythonized sampler result computed in step 1.
These deferred computations are not required for single-step scheduling
or the `profile_run()` phase of multi-step scheduling.
Args:
output: sampler output (under deferred Pythonization)
sampling_metadata
Returns:
prompt_logprobs (CPU), sample_logprobs (CPU)
"""
# - Deferred pythonization of sample result
sampler_result = get_pythonized_sample_results(
output.deferred_sample_results_args)
# - Erase the NPU-side deferred sample_result
# computation args to ensure it is never
# pythonized or transferred to CPU
output.deferred_sample_results_args = None
# - Deferred pythonization of logprobs
(
prompt_logprobs,
sample_logprobs,
) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result)
assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
assert len(sample_logprobs) == len(sampling_metadata.seq_groups)
return prompt_logprobs, sample_logprobs
def _pythonize_sampler_output(
model_input: StatefulModelInputForNPU,
output: SamplerOutput,
pinned_sampled_token_buffer: torch.Tensor,
sampled_token_ids: torch.Tensor,
logprobs_tensor: Optional[torch.Tensor],
cache: Optional[PythonizationCache],
) -> None:
""" This function is only called when the output tensors are ready.
See :class:`ModelOutput`.
Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
adding a Pythonized output data structure
(:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.
Args:
model_input
output: sampler output
pinned_sampled_token_token_buffer: CPU-side pinned memory
(receives copy of
NPU-side token buffer.)
sampled_token_ids: NPU-side token buffer
logprobs_tensor: NPU-side tensor containing
logprobs computed during sampling
"""
assert model_input.frozen_model_input is not None
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input.sampling_metadata is not None
sampling_metadata = frozen_model_input.sampling_metadata
# samples generation should have been skipped
assert not output.outputs
pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]
# We guarantee output tensors are ready, so it is safe to
# pythonize the sampler output & obtain CPU-side logprobs.
#
# However we should check whether logprobs pythonization may
# be skipped entirely, i.e. because no logprobs were requested
# or pythonization was not deferred. To that end,
#
# * `prompt_logprobs_are_requested_for_prefill` signals that
# there are *any* prefill-phase requests which specify that
# prompt logprobs should be returned.
#
# * `any_logprobs_are_requested` signals that there are any
# requests which (1) specify that sample logprobs should be
# returned, or (2) are in the prefill phase AND specify that
# prompt logprobs should be returned.
#
# Later on, these flags cause adjustments to the pythonization
# process to accommodate logprobs.
seq_groups = sampling_metadata.seq_groups
prompt_logprobs_are_requested_for_prefill = any([
sg.sampling_params.prompt_logprobs is not None and sg.is_prompt
for sg in seq_groups
])
any_logprobs_are_requested = (
prompt_logprobs_are_requested_for_prefill
or any([sg.sampling_params.logprobs is not None for sg in seq_groups]))
if prompt_logprobs_are_requested_for_prefill:
# CPU NPU sync, after gathering *only* sampled tokens (since
# requesting prompt logprobs leads `sampled_token_ids` to
# include prompt token ids in addition to sampled token ids.)
sample_idx_tensor = torch.tensor(
[sdx for sg in seq_groups for sdx in sg.sample_indices])
pinned_buffer = pinned_buffer.copy_(
sampled_token_ids[sample_idx_tensor, :], non_blocking=False)
else:
# CPU NPU sync
pinned_buffer = pinned_buffer.copy_(sampled_token_ids,
non_blocking=False)
# this will not block as the tensors are already on CPU
samples_list = pinned_buffer.tolist()
skip_sampler_cpu_output = (
frozen_model_input.sampling_metadata.skip_sampler_cpu_output)
# *Don't* skip logprobs pythonization *if*:
# * Any requests require logprobs to be returned in this
# iteration AND
# * These requests are being scheduled in a fashion which
# defers pythonization (i.e. multi-step scheduling.)
do_pythonize_logprobs = (skip_sampler_cpu_output
and any_logprobs_are_requested)
(
prompt_logprobs,
sample_logprobs,
) = (deferred_pythonize_logprobs(output, sampling_metadata,
logprobs_tensor)
if do_pythonize_logprobs else (None, None))
for sgdx, (seq_group,
sample_result) in enumerate(zip(seq_groups, samples_list)):
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
# (Check for Guided Decoding)
if seq_group.sampling_params.logits_processors:
assert len(seq_group.sampling_params.logits_processors) == 0, (
"Logits Processors are not supported in multi-step decoding")
if do_pythonize_logprobs:
assert prompt_logprobs is not None
assert sample_logprobs is not None
(
group_prompt_logprobs,
group_sample_logprobs,
) = ( # Utilize deferred pythonization results
prompt_logprobs[sgdx],
sample_logprobs[sgdx],
)
elif any_logprobs_are_requested:
(
group_prompt_logprobs,
group_sample_logprobs,
) = (
# profile_run: use already-computed logprobs
output.outputs[sgdx].prompt_logprobs,
[sample.logprobs for sample in output.outputs[sgdx].samples])
seq_ids = seq_group.seq_ids
next_token_ids = sample_result
parent_ids = [0]
seq_outputs: List[SequenceOutput]
if cache is not None:
completion_seq_group_output: CompletionSequenceGroupOutput = \
cache.cached_completion_seq_group_output.get_object()
completion_seq_group_output.samples.clear()
seq_outputs = completion_seq_group_output.samples
else:
seq_outputs = []
for tdx, (parent_id,
next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
if cache is not None:
seq_output: SequenceOutput = cache.cached_seq_output.get_object(
)
seq_output.parent_seq_id = seq_ids[parent_id]
seq_output.output_token = next_token_id
if any_logprobs_are_requested:
seq_output.logprobs = group_sample_logprobs[tdx]
else:
logprobs = next(iter(seq_output.logprobs.values()))
seq_output.logprobs.clear()
logprobs.logprob = float('inf')
logprobs.rank = None
logprobs.decoded_token = None
seq_output.logprobs[next_token_id] = logprobs
seq_outputs.append(seq_output)
else:
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id,
(group_sample_logprobs[tdx]
if any_logprobs_are_requested else {
next_token_id:
Logprob(logprob=float('inf'),
rank=None,
decoded_token=None)
})))
if cache is not None:
completion_seq_group_output.prompt_logprobs = \
group_prompt_logprobs if any_logprobs_are_requested else None
output.outputs.append(completion_seq_group_output)
else:
output.outputs.append(
CompletionSequenceGroupOutput(
seq_outputs, (group_prompt_logprobs
if any_logprobs_are_requested else None)))
assert len(output.outputs) > 0