[Misc][V0 Deprecation] Remove Multi-Step Model Runner (#1820)
### What this PR does / why we need it?
Remove multi-step model runner.
This PR is a part of
https://github.com/vllm-project/vllm-ascend/issues/1620.
- vLLM version: v0.9.2
- vLLM main:
34cda778a0
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -1,737 +0,0 @@
|
||||
import dataclasses
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
|
||||
Union)
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
|
||||
SamplerOutput,
|
||||
SamplingMetadata, get_logprobs,
|
||||
get_pythonized_sample_results)
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
||||
Logprob, SequenceGroupMetadata, SequenceOutput)
|
||||
from vllm.worker.model_runner_base import (
|
||||
_init_attn_metadata_from_tensor_dict,
|
||||
_init_frozen_model_input_from_tensor_dict,
|
||||
_init_sampling_metadata_from_tensor_dict)
|
||||
from vllm.worker.multi_step_model_runner import (ModelOutput,
|
||||
PythonizationCache,
|
||||
StatefulModelInput)
|
||||
|
||||
from vllm_ascend.utils import current_stream
|
||||
from vllm_ascend.worker.model_runner import (
|
||||
ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
|
||||
@dataclass(frozen=False)
|
||||
class StatefulModelInputForNPU(StatefulModelInput):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def record_step_event(self, current_stream: torch.npu.Stream):
|
||||
# record the event for the current step so that the next step can sync
|
||||
# on it. We modulo by 2 to keep the events in a circular buffer and
|
||||
# support any attn backends that may be supported in the future. ie
|
||||
# Flashinfer would want two DecodeWrappers to overlap the CPU and NPU.
|
||||
self.step_cuda_events[self.current_step & 1] = \
|
||||
torch.npu.Event(blocking=True)
|
||||
self.step_cuda_events[self.current_step & 1].record(current_stream)
|
||||
|
||||
# actual frozen model input dataclass passed to _base_model_runner
|
||||
frozen_model_input: Optional[ModelInputForNPUWithSamplingMetadata] = None
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls,
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "StatefulModelInputForNPU":
|
||||
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend, tensor_dict)
|
||||
tensor_dict = _init_frozen_model_input_from_tensor_dict(
|
||||
ModelInputForNPUWithSamplingMetadata, tensor_dict)
|
||||
return cls(**tensor_dict)
|
||||
|
||||
def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool):
|
||||
"""
|
||||
Advancing the datastructures of StatefulModelInput::frozen_model_input
|
||||
is only required when prefills are scheduled with decodes to run in
|
||||
multi-step. This advancement/correction is required to account for
|
||||
the conversion of Prefills to Decodes after the first multi-step.
|
||||
"""
|
||||
if self.current_step != 1 or self.num_single_step_prefills == 0:
|
||||
return
|
||||
|
||||
assert self.frozen_model_input is not None
|
||||
fmi = self.frozen_model_input
|
||||
|
||||
# Truncate input_tokens
|
||||
assert fmi.input_tokens is not None
|
||||
assert fmi.input_tokens.shape[0] >= self.num_seqs
|
||||
fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs]
|
||||
|
||||
# Update frozen_model_input::input_positons.
|
||||
assert fmi.input_positions is not None
|
||||
assert fmi.input_positions.shape[0] >= self.num_seqs
|
||||
fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self.
|
||||
num_seqs]
|
||||
|
||||
# Assert unsupported
|
||||
# TODO Uncomment the following codes when NPU supported
|
||||
# assert fmi.lora_mapping is None
|
||||
# assert fmi.lora_requests is not None
|
||||
# assert len(fmi.lora_requests) == 0
|
||||
# assert fmi.prompt_adapter_mapping is None
|
||||
# assert fmi.prompt_adapter_requests is not None
|
||||
# assert len(fmi.prompt_adapter_requests) == 0
|
||||
assert fmi.attn_metadata is not None
|
||||
assert fmi.multi_modal_kwargs is not None
|
||||
assert len(fmi.multi_modal_kwargs) == 0
|
||||
|
||||
self.frozen_model_input = dataclasses.replace(
|
||||
self.frozen_model_input,
|
||||
input_tokens=fmi_new_input_tokens,
|
||||
input_positions=fmi_new_input_positions)
|
||||
|
||||
self.maybe_advance_sampling_metadata(device, pin_memory)
|
||||
|
||||
|
||||
@dataclass(frozen=False)
|
||||
class NPUModelOutput(ModelOutput):
|
||||
|
||||
logprobs: Optional["torch.Tensor"] = None
|
||||
|
||||
def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput",
|
||||
copy_stream: torch.npu.Stream,
|
||||
pinned_sampled_token_buffer: torch.Tensor,
|
||||
blocking: bool) -> bool:
|
||||
"""
|
||||
If blocking is set, will block until the forward pass for the output is
|
||||
ready and pythonize the output. Upon completing Pythonization, erases
|
||||
self.logprobs (note that a non-blocking call that is performed when
|
||||
the sampler output is not yet ready, will not erase self.logprobs.)
|
||||
"""
|
||||
assert self.sampled_token_ids is not None
|
||||
if not blocking and not self.sampler_output_ready_event.query():
|
||||
return False
|
||||
|
||||
if blocking:
|
||||
self.sampler_output_ready_event.synchronize()
|
||||
with torch.npu.stream(copy_stream):
|
||||
_pythonize_sampler_output(input_metadata, self.sampler_output,
|
||||
pinned_sampled_token_buffer,
|
||||
self.sampled_token_ids, self.logprobs,
|
||||
self.pythonization_cache)
|
||||
|
||||
# Erase the logprobs GPU-side tensor.
|
||||
# Note that although _pythonize_sampler_output() runs in its
|
||||
# own CUDA stream, nonetheless _pythonize_sampler_output()
|
||||
# cannot return until Pythonization is complete; therefore
|
||||
# we know that by the time the CPU reaches this point,
|
||||
# `self.logprobs` is no longer needed.
|
||||
self.logprobs = None
|
||||
return True
|
||||
|
||||
|
||||
class MultiStepModelNPURunner(NPUModelRunnerBase[StatefulModelInputForNPU]):
|
||||
# mypy: enable-error-code=type-var
|
||||
|
||||
def __init__(self, base_model_runner: NPUModelRunnerBase, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# uses the base model runner to execute the model and wraps it with
|
||||
# multi-step logic
|
||||
self._base_model_runner: NPUModelRunnerBase = base_model_runner
|
||||
|
||||
self.is_multi_step = self.scheduler_config.is_multi_step
|
||||
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
|
||||
# Using the PythonizationCache in Pipeline-Parallel clobbers the
|
||||
# SequenceOutput and CompletionSequenceGroupOutput object.
|
||||
# When cache-reset happens at the last step of a multi-step
|
||||
# execution, there may be other on-going single-step/multi-step
|
||||
# executions. The current caching implementation does not check
|
||||
# for this.
|
||||
self.pythonization_cache = PythonizationCache() \
|
||||
if self.parallel_config.pipeline_parallel_size == 1 else None
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self._base_model_runner.get_model()
|
||||
|
||||
@functools.cached_property
|
||||
def _copy_stream(self):
|
||||
# used to copy tensors from NPU to CPU asynchronously
|
||||
return torch.npu.Stream()
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self, tensor_dict: Dict[str, Any]) -> StatefulModelInputForNPU:
|
||||
model_input = (StatefulModelInputForNPU.from_broadcasted_tensor_dict(
|
||||
tensor_dict,
|
||||
attn_backend=self.attn_backend,
|
||||
))
|
||||
return model_input
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> StatefulModelInputForNPU:
|
||||
frozen_model_input: ModelInputForNPUWithSamplingMetadata = \
|
||||
self._base_model_runner.prepare_model_input(
|
||||
seq_group_metadata_list,
|
||||
virtual_engine,
|
||||
finished_requests_ids)
|
||||
|
||||
assert frozen_model_input.query_lens is not None
|
||||
assert frozen_model_input.seq_lens is not None
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
num_queries = len(frozen_model_input.query_lens)
|
||||
num_seqs = len(frozen_model_input.seq_lens)
|
||||
num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills
|
||||
|
||||
model_input = StatefulModelInputForNPU(
|
||||
frozen_model_input=frozen_model_input,
|
||||
num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
num_single_step_prefills=num_single_step_prefills,
|
||||
step_cuda_events=[torch.npu.Event(blocking=True)] * 2,
|
||||
)
|
||||
|
||||
return model_input
|
||||
|
||||
def _async_process_outputs(self, model_input: StatefulModelInputForNPU,
|
||||
output_proc_callback: Callable):
|
||||
# Proceed with pythonization and output_proc in order.
|
||||
# Stop on the first one that fails to pythonize
|
||||
output_proc_callback()
|
||||
|
||||
cont = True
|
||||
for step_num, model_output in enumerate(model_input.cached_outputs):
|
||||
if not model_output.pythonized:
|
||||
model_output.maybe_pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
if model_output.pythonized:
|
||||
ctx = output_proc_callback.keywords["ctx"] # type: ignore
|
||||
ctx.append_output(
|
||||
outputs=[model_output.sampler_output],
|
||||
seq_group_metadata_list=ctx.seq_group_metadata_list,
|
||||
scheduler_outputs=ctx.scheduler_outputs,
|
||||
is_async=False,
|
||||
is_last_step=False,
|
||||
is_first_step_output=step_num == 0)
|
||||
|
||||
output_proc_callback()
|
||||
else:
|
||||
cont = False
|
||||
|
||||
if not cont:
|
||||
break
|
||||
|
||||
def _final_process_outputs(
|
||||
self, model_input: StatefulModelInputForNPU,
|
||||
output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
|
||||
assert model_input.frozen_model_input is not None
|
||||
|
||||
has_async_callback = output_proc_callback is not None
|
||||
|
||||
outputs = []
|
||||
for step_num, output in enumerate(model_input.cached_outputs):
|
||||
is_last_step = step_num == len(model_input.cached_outputs) - 1
|
||||
|
||||
# For non-async case:
|
||||
# -- We simply add the outputs
|
||||
# For async case:
|
||||
# -- Invoke callback, pythonize, add to callback queue and repeat
|
||||
# -- For last output, just add to callback queue
|
||||
if has_async_callback:
|
||||
assert output_proc_callback is not None
|
||||
|
||||
# Invoke callback before pythonize (to overlap with NPU)
|
||||
output_proc_callback()
|
||||
|
||||
# Pythonize
|
||||
if not output.pythonized:
|
||||
output.pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
|
||||
# For non last step, add to callback queue to chain
|
||||
# callbacks=>pythonize pairs (for NPU overlap)
|
||||
if not is_last_step:
|
||||
ctx = output_proc_callback.keywords[ # type: ignore
|
||||
"ctx"] # type: ignore
|
||||
ctx.append_output(
|
||||
outputs=[output.sampler_output],
|
||||
seq_group_metadata_list=ctx.
|
||||
seq_group_metadata_list,
|
||||
scheduler_outputs=ctx.scheduler_outputs,
|
||||
is_async=False,
|
||||
is_last_step=False,
|
||||
is_first_step_output=step_num == 0)
|
||||
else:
|
||||
outputs.append(output.sampler_output)
|
||||
else:
|
||||
output.pythonize(model_input, self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
outputs.append(output.sampler_output)
|
||||
|
||||
return outputs
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: StatefulModelInputForNPU,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
|
||||
"""
|
||||
Execute the model for a single step and update multi-step
|
||||
metadata
|
||||
"""
|
||||
assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
|
||||
# path for warm up runs
|
||||
if not model_input.is_multi_step:
|
||||
return self._base_model_runner.execute_model(
|
||||
frozen_model_input, kv_caches, intermediate_tensors, num_steps)
|
||||
|
||||
# make sure we skip the sampler on the lask rank and only pythonize
|
||||
# if CPU is ahead.
|
||||
if self.is_driver_worker and get_pp_group().is_last_rank:
|
||||
if self.pinned_sampled_token_ids is None:
|
||||
self.pinned_sampled_token_ids = torch.zeros(
|
||||
(self.scheduler_config.max_num_seqs, 1),
|
||||
dtype=torch.long,
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
self._base_model_runner.sampler.include_gpu_probs_tensor = True
|
||||
if frozen_model_input.sampling_metadata:
|
||||
frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
|
||||
True)
|
||||
|
||||
# some pre-execute model logic for multi-step:
|
||||
# - if it's the first step, we need to reset the sampling tensors
|
||||
# - if it's not the first step, we need to advance the step using the
|
||||
# appended sampler output from last iteration
|
||||
# - also maybe pythonize if CPU is ahead of NPU
|
||||
|
||||
stream = current_stream()
|
||||
if not model_input.is_first_multi_step:
|
||||
# Explicitly block on the previous step's forward to make sure we
|
||||
# don't clobber any NPU tensors still in use.
|
||||
# This is not needed for flashattn backend, but for other attn
|
||||
# backends such as flashinfer that performs extra CPU operations on
|
||||
# input metadata we may need to synchronize any CPU operations that
|
||||
# might clobber enqueued forwards. (prevents CPU from running too
|
||||
# far ahead if needed)
|
||||
model_input.wait_previous_step()
|
||||
model_input = self._advance_step(
|
||||
model_input, model_input.cached_outputs[-1].sampler_output)
|
||||
|
||||
# frozen_model_input may have been updated
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
|
||||
if model_input.base_output_proc_callback is None:
|
||||
assert frozen_model_input is not None
|
||||
model_input.base_output_proc_callback = \
|
||||
frozen_model_input.async_callback
|
||||
|
||||
if frozen_model_input.async_callback is not None:
|
||||
assert model_input.base_output_proc_callback is not None
|
||||
async_callback = functools.partial(
|
||||
self._async_process_outputs,
|
||||
model_input=model_input,
|
||||
output_proc_callback=model_input.base_output_proc_callback)
|
||||
|
||||
model_input.frozen_model_input = dataclasses.replace( # type: ignore
|
||||
model_input.frozen_model_input,
|
||||
async_callback=async_callback)
|
||||
# Update the local instance
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
|
||||
# Execute the model
|
||||
output = self._base_model_runner.execute_model(frozen_model_input,
|
||||
kv_caches,
|
||||
intermediate_tensors,
|
||||
num_steps=1)
|
||||
|
||||
# record the event for the current step so that the next step can sync
|
||||
model_input.record_step_event(stream)
|
||||
|
||||
if get_pp_group().is_last_rank and self.is_driver_worker:
|
||||
assert isinstance(output, list)
|
||||
assert len(
|
||||
output
|
||||
) == 1, "MultiStepModelRunner requires single-step base_models"
|
||||
|
||||
# event for the pythonization so that we only pythonize if the
|
||||
# tensors are ready. May be able to be combined with the step event
|
||||
output_ready_event = torch.npu.Event()
|
||||
output_ready_event.record(stream)
|
||||
if self.parallel_config.pipeline_parallel_size > 1:
|
||||
output[0].sampled_token_ids_cpu = output[
|
||||
0].sampled_token_ids.cpu()
|
||||
model_input.cached_outputs.append(
|
||||
NPUModelOutput(output[0], output_ready_event,
|
||||
output[0].sampled_token_ids, False,
|
||||
output[0].logprobs, self.pythonization_cache))
|
||||
|
||||
# These NPU tensors are not required by multi-step;
|
||||
# erase them to ensure they are not pythonized or
|
||||
# transferred to CPU
|
||||
output[0].sampled_token_ids = None
|
||||
output[0].sampled_token_probs = None
|
||||
output[0].logprobs = None
|
||||
|
||||
# Pythonize the output if CPU is ahead and the previous step is
|
||||
# ready.
|
||||
if frozen_model_input.async_callback is None:
|
||||
for model_output in model_input.cached_outputs:
|
||||
model_output.maybe_pythonize(model_input,
|
||||
self._copy_stream,
|
||||
self.pinned_sampled_token_ids)
|
||||
|
||||
model_input.current_step += 1
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
# Should be IntermediateTensors
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
return output
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
# Pythonize the output and block if needed since it is the last step
|
||||
if model_input.is_last_step:
|
||||
outputs = self._final_process_outputs(
|
||||
model_input, model_input.base_output_proc_callback)
|
||||
if self.pythonization_cache:
|
||||
self.pythonization_cache.reset()
|
||||
return outputs
|
||||
|
||||
# should be [SamplerOutput]
|
||||
return output
|
||||
|
||||
def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
|
||||
num_seqs: Optional[int], num_queries: int):
|
||||
|
||||
assert sampling_metadata.num_prompts == 0
|
||||
assert len(sampling_metadata.seq_groups) == num_queries
|
||||
assert sampling_metadata.selected_token_indices.shape == (
|
||||
num_queries, )
|
||||
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
|
||||
|
||||
# Verify that all sequences are decodes
|
||||
for i in range(num_queries):
|
||||
seq_group = sampling_metadata.seq_groups[i]
|
||||
|
||||
assert seq_group.is_prompt is False # No prompt
|
||||
assert seq_group.prompt_logprob_indices == [] # No prompt
|
||||
assert seq_group.sample_indices == [i] # Simple
|
||||
assert seq_group.seq_len is None # Decode
|
||||
assert seq_group.query_len is None # Decode
|
||||
|
||||
def _advance_step(self, model_input: StatefulModelInputForNPU,
|
||||
out: SamplerOutput) -> StatefulModelInputForNPU:
|
||||
|
||||
model_input.maybe_advance_frozen_model_input(self.device,
|
||||
self.pin_memory)
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
assert frozen_model_input.input_tokens is not None
|
||||
assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs
|
||||
assert frozen_model_input.attn_metadata is not None
|
||||
|
||||
sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
|
||||
num_seqs = model_input.num_seqs
|
||||
num_queries = model_input.num_queries
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input is not None
|
||||
attn_metadata = frozen_model_input.attn_metadata
|
||||
assert attn_metadata is not None
|
||||
|
||||
turn_prefills_into_decodes: bool = model_input.current_step == 1 and \
|
||||
model_input.num_single_step_prefills != 0
|
||||
attn_metadata.advance_step(
|
||||
frozen_model_input,
|
||||
sampled_token_ids,
|
||||
self.block_size,
|
||||
num_seqs,
|
||||
num_queries,
|
||||
turn_prefills_into_decodes=turn_prefills_into_decodes)
|
||||
|
||||
return model_input
|
||||
|
||||
def load_model(self) -> None:
|
||||
self._base_model_runner.load_model()
|
||||
self.model_memory_usage = self._base_model_runner.model_memory_usage
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
pattern: Optional[str] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> None:
|
||||
return self._base_model_runner.save_sharded_state(
|
||||
path, pattern, max_size)
|
||||
|
||||
def save_tensorized_model(self,
|
||||
tensorizer_config: TensorizerConfig) -> None:
|
||||
return self._base_model_runner.save_tensorized_model(tensorizer_config)
|
||||
|
||||
def profile_run(self) -> None:
|
||||
return self._base_model_runner.profile_run()
|
||||
|
||||
def remove_all_loras(self):
|
||||
return self._base_model_runner.remove_all_loras()
|
||||
|
||||
def capture_model(self, kv_caches: List[List]) -> None:
|
||||
return self._base_model_runner.capture_model(kv_caches)
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self._base_model_runner.vocab_size
|
||||
|
||||
|
||||
DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]],
|
||||
Optional[List[SampleLogprobs]]]
|
||||
|
||||
|
||||
def deferred_pythonize_logprobs(
|
||||
output: SamplerOutput,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
logprobs_tensor: Optional[torch.Tensor],
|
||||
) -> DeferredLogprobsReturnType:
|
||||
"""Perform deferred logprob Pythonization.
|
||||
|
||||
1. Pythonize NPU-side sampler result tensors into CPU-side sampler result.
|
||||
2. Pythonize NPU-side logprobs tensor into CPU-side logprobs lists,
|
||||
utilizing the Pythonized sampler result computed in step 1.
|
||||
|
||||
These deferred computations are not required for single-step scheduling
|
||||
or the `profile_run()` phase of multi-step scheduling.
|
||||
|
||||
Args:
|
||||
output: sampler output (under deferred Pythonization)
|
||||
sampling_metadata
|
||||
|
||||
Returns:
|
||||
prompt_logprobs (CPU), sample_logprobs (CPU)
|
||||
"""
|
||||
|
||||
# - Deferred pythonization of sample result
|
||||
sampler_result = get_pythonized_sample_results(
|
||||
output.deferred_sample_results_args)
|
||||
|
||||
# - Erase the NPU-side deferred sample_result
|
||||
# computation args to ensure it is never
|
||||
# pythonized or transferred to CPU
|
||||
output.deferred_sample_results_args = None
|
||||
|
||||
# - Deferred pythonization of logprobs
|
||||
(
|
||||
prompt_logprobs,
|
||||
sample_logprobs,
|
||||
) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result)
|
||||
assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
|
||||
assert len(sample_logprobs) == len(sampling_metadata.seq_groups)
|
||||
|
||||
return prompt_logprobs, sample_logprobs
|
||||
|
||||
|
||||
def _pythonize_sampler_output(
|
||||
model_input: StatefulModelInputForNPU,
|
||||
output: SamplerOutput,
|
||||
pinned_sampled_token_buffer: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
logprobs_tensor: Optional[torch.Tensor],
|
||||
cache: Optional[PythonizationCache],
|
||||
) -> None:
|
||||
""" This function is only called when the output tensors are ready.
|
||||
See :class:`ModelOutput`.
|
||||
|
||||
Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
|
||||
adding a Pythonized output data structure
|
||||
(:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.
|
||||
|
||||
Args:
|
||||
model_input
|
||||
output: sampler output
|
||||
pinned_sampled_token_token_buffer: CPU-side pinned memory
|
||||
(receives copy of
|
||||
NPU-side token buffer.)
|
||||
sampled_token_ids: NPU-side token buffer
|
||||
logprobs_tensor: NPU-side tensor containing
|
||||
logprobs computed during sampling
|
||||
"""
|
||||
|
||||
assert model_input.frozen_model_input is not None
|
||||
|
||||
frozen_model_input = model_input.frozen_model_input
|
||||
assert frozen_model_input.sampling_metadata is not None
|
||||
sampling_metadata = frozen_model_input.sampling_metadata
|
||||
# samples generation should have been skipped
|
||||
assert not output.outputs
|
||||
|
||||
pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]
|
||||
|
||||
# We guarantee output tensors are ready, so it is safe to
|
||||
# pythonize the sampler output & obtain CPU-side logprobs.
|
||||
#
|
||||
# However we should check whether logprobs pythonization may
|
||||
# be skipped entirely, i.e. because no logprobs were requested
|
||||
# or pythonization was not deferred. To that end,
|
||||
#
|
||||
# * `prompt_logprobs_are_requested_for_prefill` signals that
|
||||
# there are *any* prefill-phase requests which specify that
|
||||
# prompt logprobs should be returned.
|
||||
#
|
||||
# * `any_logprobs_are_requested` signals that there are any
|
||||
# requests which (1) specify that sample logprobs should be
|
||||
# returned, or (2) are in the prefill phase AND specify that
|
||||
# prompt logprobs should be returned.
|
||||
#
|
||||
# Later on, these flags cause adjustments to the pythonization
|
||||
# process to accommodate logprobs.
|
||||
|
||||
seq_groups = sampling_metadata.seq_groups
|
||||
prompt_logprobs_are_requested_for_prefill = any([
|
||||
sg.sampling_params.prompt_logprobs is not None and sg.is_prompt
|
||||
for sg in seq_groups
|
||||
])
|
||||
any_logprobs_are_requested = (
|
||||
prompt_logprobs_are_requested_for_prefill
|
||||
or any([sg.sampling_params.logprobs is not None for sg in seq_groups]))
|
||||
|
||||
if prompt_logprobs_are_requested_for_prefill:
|
||||
# CPU NPU sync, after gathering *only* sampled tokens (since
|
||||
# requesting prompt logprobs leads `sampled_token_ids` to
|
||||
# include prompt token ids in addition to sampled token ids.)
|
||||
sample_idx_tensor = torch.tensor(
|
||||
[sdx for sg in seq_groups for sdx in sg.sample_indices])
|
||||
pinned_buffer = pinned_buffer.copy_(
|
||||
sampled_token_ids[sample_idx_tensor, :], non_blocking=False)
|
||||
else:
|
||||
# CPU NPU sync
|
||||
pinned_buffer = pinned_buffer.copy_(sampled_token_ids,
|
||||
non_blocking=False)
|
||||
|
||||
# this will not block as the tensors are already on CPU
|
||||
samples_list = pinned_buffer.tolist()
|
||||
|
||||
skip_sampler_cpu_output = (
|
||||
frozen_model_input.sampling_metadata.skip_sampler_cpu_output)
|
||||
|
||||
# *Don't* skip logprobs pythonization *if*:
|
||||
# * Any requests require logprobs to be returned in this
|
||||
# iteration AND
|
||||
# * These requests are being scheduled in a fashion which
|
||||
# defers pythonization (i.e. multi-step scheduling.)
|
||||
do_pythonize_logprobs = (skip_sampler_cpu_output
|
||||
and any_logprobs_are_requested)
|
||||
(
|
||||
prompt_logprobs,
|
||||
sample_logprobs,
|
||||
) = (deferred_pythonize_logprobs(output, sampling_metadata,
|
||||
logprobs_tensor)
|
||||
if do_pythonize_logprobs else (None, None))
|
||||
|
||||
for sgdx, (seq_group,
|
||||
sample_result) in enumerate(zip(seq_groups, samples_list)):
|
||||
# Reminder: Please update docs/source/features/compatibility_matrix.md
|
||||
# If the feature combo become valid
|
||||
# (Check for Guided Decoding)
|
||||
if seq_group.sampling_params.logits_processors:
|
||||
assert len(seq_group.sampling_params.logits_processors) == 0, (
|
||||
"Logits Processors are not supported in multi-step decoding")
|
||||
|
||||
if do_pythonize_logprobs:
|
||||
assert prompt_logprobs is not None
|
||||
assert sample_logprobs is not None
|
||||
|
||||
(
|
||||
group_prompt_logprobs,
|
||||
group_sample_logprobs,
|
||||
) = ( # Utilize deferred pythonization results
|
||||
prompt_logprobs[sgdx],
|
||||
sample_logprobs[sgdx],
|
||||
)
|
||||
elif any_logprobs_are_requested:
|
||||
(
|
||||
group_prompt_logprobs,
|
||||
group_sample_logprobs,
|
||||
) = (
|
||||
# profile_run: use already-computed logprobs
|
||||
output.outputs[sgdx].prompt_logprobs,
|
||||
[sample.logprobs for sample in output.outputs[sgdx].samples])
|
||||
|
||||
seq_ids = seq_group.seq_ids
|
||||
next_token_ids = sample_result
|
||||
parent_ids = [0]
|
||||
seq_outputs: List[SequenceOutput]
|
||||
|
||||
if cache is not None:
|
||||
completion_seq_group_output: CompletionSequenceGroupOutput = \
|
||||
cache.cached_completion_seq_group_output.get_object()
|
||||
completion_seq_group_output.samples.clear()
|
||||
seq_outputs = completion_seq_group_output.samples
|
||||
else:
|
||||
seq_outputs = []
|
||||
|
||||
for tdx, (parent_id,
|
||||
next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
|
||||
if cache is not None:
|
||||
seq_output: SequenceOutput = cache.cached_seq_output.get_object(
|
||||
)
|
||||
seq_output.parent_seq_id = seq_ids[parent_id]
|
||||
seq_output.output_token = next_token_id
|
||||
|
||||
if any_logprobs_are_requested:
|
||||
seq_output.logprobs = group_sample_logprobs[tdx]
|
||||
else:
|
||||
logprobs = next(iter(seq_output.logprobs.values()))
|
||||
seq_output.logprobs.clear()
|
||||
|
||||
logprobs.logprob = float('inf')
|
||||
logprobs.rank = None
|
||||
logprobs.decoded_token = None
|
||||
|
||||
seq_output.logprobs[next_token_id] = logprobs
|
||||
|
||||
seq_outputs.append(seq_output)
|
||||
|
||||
else:
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_ids[parent_id], next_token_id,
|
||||
(group_sample_logprobs[tdx]
|
||||
if any_logprobs_are_requested else {
|
||||
next_token_id:
|
||||
Logprob(logprob=float('inf'),
|
||||
rank=None,
|
||||
decoded_token=None)
|
||||
})))
|
||||
if cache is not None:
|
||||
completion_seq_group_output.prompt_logprobs = \
|
||||
group_prompt_logprobs if any_logprobs_are_requested else None
|
||||
output.outputs.append(completion_seq_group_output)
|
||||
else:
|
||||
output.outputs.append(
|
||||
CompletionSequenceGroupOutput(
|
||||
seq_outputs, (group_prompt_logprobs
|
||||
if any_logprobs_are_requested else None)))
|
||||
|
||||
assert len(output.outputs) > 0
|
||||
Reference in New Issue
Block a user