### What this PR does / why we need it?
- Following https://github.com/vllm-project/vllm-ascend/issues/807, this PR adds a custom AscendC kernel for multi-step.
- It also fixes a bug we found in multi_step_runner.py when using multi-step on the V0 Engine.

### Does this PR introduce _any_ user-facing change?
No user-facing change.

### How was this patch tested?
We added a unit test file and an offline inference example to test the custom AscendC kernel. See test/ops/test_multi_step.py and examples/offline_multi_step.py.

---------

Signed-off-by: wan_danfeng <wonderful199082@126.com>
738 lines
31 KiB
Python
738 lines
31 KiB
Python
import dataclasses
|
|
import functools
|
|
from dataclasses import dataclass
|
|
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
|
|
Union)
|
|
|
|
import torch
|
|
from torch import nn
|
|
from vllm.distributed import get_pp_group
|
|
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
|
|
SamplerOutput,
|
|
SamplingMetadata, get_logprobs,
|
|
get_pythonized_sample_results)
|
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
|
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
|
Logprob, SequenceGroupMetadata, SequenceOutput)
|
|
from vllm.worker.model_runner_base import (
|
|
_init_attn_metadata_from_tensor_dict,
|
|
_init_frozen_model_input_from_tensor_dict,
|
|
_init_sampling_metadata_from_tensor_dict)
|
|
from vllm.worker.multi_step_model_runner import (ModelOutput,
|
|
PythonizationCache,
|
|
StatefulModelInput)
|
|
|
|
from vllm_ascend.utils import current_stream
|
|
from vllm_ascend.worker.model_runner import (
|
|
ModelInputForNPUWithSamplingMetadata, NPUModelRunnerBase)
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.attention.backends.abstract import AttentionBackend
|
|
|
|
|
|
@dataclass(frozen=False)
class StatefulModelInputForNPU(StatefulModelInput):
    """NPU variant of the multi-step ``StatefulModelInput``.

    Narrows ``frozen_model_input`` to the NPU model-input type and records
    per-step events with ``torch.npu.Event``.
    """

    # The frozen model input dataclass that is handed to the base model
    # runner.
    frozen_model_input: Optional[ModelInputForNPUWithSamplingMetadata] = None

    def __init__(self, *args, **kwargs):
        # Forward all arguments to the parent initializer unchanged.
        super().__init__(*args, **kwargs)

    def record_step_event(self, current_stream: torch.npu.Stream):
        """Record a fresh blocking event for the current step.

        The events live in a two-slot circular buffer indexed by the parity
        of ``current_step`` so that the next step can synchronize on the
        event recorded here.

        Args:
            current_stream: stream on which the new event is recorded.
        """
        slot = self.current_step & 1
        step_event = torch.npu.Event(blocking=True)
        self.step_cuda_events[slot] = step_event
        step_event.record(current_stream)

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "StatefulModelInputForNPU":
        """Rebuild an instance from a broadcast tensor dictionary.

        Args:
            tensor_dict: dictionary received from the broadcast.
            attn_backend: when given, attention metadata is also restored
                from ``tensor_dict`` using this backend.

        Returns:
            A new instance constructed from the restored dictionary.
        """
        restored = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            restored = _init_attn_metadata_from_tensor_dict(
                attn_backend, restored)
        restored = _init_frozen_model_input_from_tensor_dict(
            ModelInputForNPUWithSamplingMetadata, restored)
        return cls(**restored)

    def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool):
        """Advance ``frozen_model_input`` after the first multi-step.

        The advancement is only needed when prefills were scheduled together
        with decodes in a multi-step run: after the first step those prefills
        have become decodes, so the token and position tensors are truncated
        to ``num_seqs`` entries and the sampling metadata is advanced.

        Args:
            device: forwarded to ``maybe_advance_sampling_metadata``.
            pin_memory: forwarded to ``maybe_advance_sampling_metadata``.
        """
        prefills_became_decodes = (self.current_step == 1
                                   and self.num_single_step_prefills != 0)
        if not prefills_became_decodes:
            return

        frozen = self.frozen_model_input
        assert frozen is not None

        # Truncate input_tokens to the first num_seqs entries.
        tokens = frozen.input_tokens
        assert tokens is not None
        assert tokens.shape[0] >= self.num_seqs
        truncated_tokens: torch.Tensor = tokens[:self.num_seqs]

        # Truncate input_positions the same way.
        positions = frozen.input_positions
        assert positions is not None
        assert positions.shape[0] >= self.num_seqs
        truncated_positions: torch.Tensor = positions[:self.num_seqs]

        # Assert unsupported
        # TODO Uncomment the following codes when NPU supported
        # assert fmi.lora_mapping is None
        # assert fmi.lora_requests is not None
        # assert len(fmi.lora_requests) == 0
        # assert fmi.prompt_adapter_mapping is None
        # assert fmi.prompt_adapter_requests is not None
        # assert len(fmi.prompt_adapter_requests) == 0
        assert frozen.attn_metadata is not None
        mm_kwargs = frozen.multi_modal_kwargs
        assert mm_kwargs is not None
        assert len(mm_kwargs) == 0

        self.frozen_model_input = dataclasses.replace(
            frozen,
            input_tokens=truncated_tokens,
            input_positions=truncated_positions,
        )

        self.maybe_advance_sampling_metadata(device, pin_memory)
|
|
|
|
|
|
@dataclass(frozen=False)
class NPUModelOutput(ModelOutput):
    """``ModelOutput`` whose sampler-output pythonization runs on an NPU
    stream."""

    logprobs: Optional["torch.Tensor"] = None

    def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput",
                                  copy_stream: torch.npu.Stream,
                                  pinned_sampled_token_buffer: torch.Tensor,
                                  blocking: bool) -> bool:
        """Pythonize the sampler output if (or once) it is ready.

        With ``blocking`` set, waits for the sampler-output-ready event and
        then pythonizes. Without it, returns ``False`` immediately when the
        event has not completed yet; in that case ``self.logprobs`` is left
        untouched. After a successful pythonization ``self.logprobs`` is
        erased.

        Args:
            input_metadata: stateful model input the output belongs to.
            copy_stream: stream under which pythonization is performed.
            pinned_sampled_token_buffer: CPU-side buffer that receives the
                sampled token ids.
            blocking: whether to wait for the output to become ready.

        Returns:
            ``True`` if the output was pythonized, ``False`` otherwise.
        """
        assert self.sampled_token_ids is not None

        ready_event = self.sampler_output_ready_event
        if blocking:
            ready_event.synchronize()
        elif not ready_event.query():
            return False

        with torch.npu.stream(copy_stream):
            _pythonize_sampler_output(
                input_metadata,
                self.sampler_output,
                pinned_sampled_token_buffer,
                self.sampled_token_ids,
                self.logprobs,
                self.pythonization_cache,
            )

        # Erase the device-side logprobs tensor. Although the helper above
        # runs under its own stream, it cannot return until pythonization is
        # complete, so by the time the CPU reaches this point
        # `self.logprobs` is no longer needed.
        self.logprobs = None
        return True
|
|
|
|
|
|
class MultiStepModelNPURunner(NPUModelRunnerBase[StatefulModelInputForNPU]):
    """Multi-step wrapper around a single-step NPU model runner.

    Model execution and most lifecycle operations are delegated to the
    wrapped base model runner; this class adds the multi-step bookkeeping
    (cached per-step outputs, step events and deferred pythonization).
    """
    # mypy: enable-error-code=type-var

    def __init__(self, base_model_runner: NPUModelRunnerBase, *args, **kwargs):
        """Store the wrapped runner and set up multi-step state.

        Args:
            base_model_runner: runner that actually executes the model.
            *args: forwarded to the parent initializer.
            **kwargs: forwarded to the parent initializer.
        """
        super().__init__(*args, **kwargs)

        # uses the base model runner to execute the model and wraps it with
        # multi-step logic
        self._base_model_runner: NPUModelRunnerBase = base_model_runner

        self.is_multi_step = self.scheduler_config.is_multi_step
        # Allocated lazily on first use in execute_model().
        self.pinned_sampled_token_ids: Optional[torch.Tensor] = None

        # Using the PythonizationCache in Pipeline-Parallel clobbers the
        # SequenceOutput and CompletionSequenceGroupOutput object.
        # When cache-reset happens at the last step of a multi-step
        # execution, there may be other on-going single-step/multi-step
        # executions. The current caching implementation does not check
        # for this.
        self.pythonization_cache = PythonizationCache() \
            if self.parallel_config.pipeline_parallel_size == 1 else None

    def get_model(self) -> nn.Module:
        """Return the model held by the base model runner."""
        return self._base_model_runner.get_model()

    @functools.cached_property
    def _copy_stream(self):
        # used to copy tensors from NPU to CPU asynchronously
        return torch.npu.Stream()

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> StatefulModelInputForNPU:
        """Rebuild a stateful model input from a broadcast tensor dict."""
        return StatefulModelInputForNPU.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> StatefulModelInputForNPU:
        """Prepare the stateful input for a multi-step execution.

        The frozen (single-step) input is produced by the base model runner
        and wrapped together with the sequence/query counts, the number of
        prefills and a two-slot event buffer.

        Args:
            seq_group_metadata_list: forwarded to the base model runner.
            virtual_engine: forwarded to the base model runner.
            finished_requests_ids: forwarded to the base model runner.

        Returns:
            The stateful model input wrapping the frozen input.
        """
        frozen_model_input: ModelInputForNPUWithSamplingMetadata = \
            self._base_model_runner.prepare_model_input(
                seq_group_metadata_list,
                virtual_engine,
                finished_requests_ids)

        assert frozen_model_input.query_lens is not None
        assert frozen_model_input.seq_lens is not None
        assert frozen_model_input.attn_metadata is not None
        num_queries = len(frozen_model_input.query_lens)
        num_seqs = len(frozen_model_input.seq_lens)
        num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills

        # Build two distinct events: `[event] * 2` would place the very same
        # Event object in both slots of the buffer.
        step_events = [torch.npu.Event(blocking=True) for _ in range(2)]

        return StatefulModelInputForNPU(
            frozen_model_input=frozen_model_input,
            num_seqs=num_seqs,
            num_queries=num_queries,
            num_single_step_prefills=num_single_step_prefills,
            step_cuda_events=step_events,
        )

    def _async_process_outputs(self, model_input: StatefulModelInputForNPU,
                               output_proc_callback: Callable):
        """Pythonize cached outputs in order and feed them to the callback.

        The callback is invoked once up front. Then each not-yet-pythonized
        cached output is given a chance to pythonize; on success it is
        appended to the callback's ``ctx`` and the callback is invoked again.
        Processing stops at the first output that could not be pythonized.

        Args:
            model_input: stateful input holding the cached outputs.
            output_proc_callback: callable exposing ``keywords["ctx"]``.
        """
        # Proceed with pythonization and output_proc in order.
        # Stop on the first one that fails to pythonize
        output_proc_callback()

        for step_num, model_output in enumerate(model_input.cached_outputs):
            if model_output.pythonized:
                continue

            model_output.maybe_pythonize(model_input, self._copy_stream,
                                         self.pinned_sampled_token_ids)
            if not model_output.pythonized:
                break

            ctx = output_proc_callback.keywords["ctx"]  # type: ignore
            ctx.append_output(
                outputs=[model_output.sampler_output],
                seq_group_metadata_list=ctx.seq_group_metadata_list,
                scheduler_outputs=ctx.scheduler_outputs,
                is_async=False,
                is_last_step=False,
                is_first_step_output=step_num == 0)

            output_proc_callback()

    def _final_process_outputs(
            self, model_input: StatefulModelInputForNPU,
            output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
        """Pythonize every cached output at the end of a multi-step run.

        Args:
            model_input: stateful input holding the cached outputs.
            output_proc_callback: async output-processing callback, or
                ``None`` when outputs are processed synchronously.

        Returns:
            The sampler outputs that were not handed to the callback queue.
        """
        assert model_input.frozen_model_input is not None

        has_async_callback = output_proc_callback is not None

        outputs = []
        for step_num, output in enumerate(model_input.cached_outputs):
            is_last_step = step_num == len(model_input.cached_outputs) - 1

            # For non-async case:
            #   -- We simply add the outputs
            # For async case:
            #   -- Invoke callback, pythonize, add to callback queue and repeat
            #   -- For last output, just add to callback queue
            if has_async_callback:
                assert output_proc_callback is not None

                # Invoke callback before pythonize (to overlap with NPU)
                output_proc_callback()

                # Pythonize
                if not output.pythonized:
                    output.pythonize(model_input, self._copy_stream,
                                     self.pinned_sampled_token_ids)

                    # For non last step, add to callback queue to chain
                    # callbacks=>pythonize pairs (for NPU overlap)
                    if not is_last_step:
                        ctx = output_proc_callback.keywords[  # type: ignore
                            "ctx"]  # type: ignore
                        ctx.append_output(
                            outputs=[output.sampler_output],
                            seq_group_metadata_list=ctx.
                            seq_group_metadata_list,
                            scheduler_outputs=ctx.scheduler_outputs,
                            is_async=False,
                            is_last_step=False,
                            is_first_step_output=step_num == 0)
                    else:
                        outputs.append(output.sampler_output)
            else:
                output.pythonize(model_input, self._copy_stream,
                                 self.pinned_sampled_token_ids)
                outputs.append(output.sampler_output)

        return outputs

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: StatefulModelInputForNPU,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """Execute the model for a single step and update multi-step
        metadata.

        Args:
            model_input: stateful multi-step input; mutated in place.
            kv_caches: forwarded to the base model runner.
            intermediate_tensors: forwarded to the base model runner.
            num_steps: must be 1.

        Returns:
            * the base runner's result directly when ``model_input`` is not
              a multi-step input (warm-up path);
            * the ``IntermediateTensors`` on a non-last pipeline rank;
            * an empty list on a non-driver worker;
            * the finalized sampler outputs on the last step;
            * otherwise the base runner's single-step output list.
        """
        assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
        frozen_model_input = model_input.frozen_model_input
        assert frozen_model_input is not None

        # path for warm up runs
        if not model_input.is_multi_step:
            return self._base_model_runner.execute_model(
                frozen_model_input, kv_caches, intermediate_tensors, num_steps)

        # make sure we skip the sampler on the last rank and only pythonize
        # if CPU is ahead.
        if self.is_driver_worker and get_pp_group().is_last_rank:
            if self.pinned_sampled_token_ids is None:
                self.pinned_sampled_token_ids = torch.zeros(
                    (self.scheduler_config.max_num_seqs, 1),
                    dtype=torch.long,
                    device="cpu",
                    pin_memory=True)

            self._base_model_runner.sampler.include_gpu_probs_tensor = True
            if frozen_model_input.sampling_metadata:
                frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
                    True)

        # some pre-execute model logic for multi-step:
        #   - if it's the first step, we need to reset the sampling tensors
        #   - if it's not the first step, we need to advance the step using the
        #   appended sampler output from last iteration
        #   - also maybe pythonize if CPU is ahead of NPU

        stream = current_stream()
        if not model_input.is_first_multi_step:
            # Explicitly block on the previous step's forward to make sure we
            # don't clobber any NPU tensors still in use.
            # This is not needed for flashattn backend, but for other attn
            # backends such as flashinfer that performs extra CPU operations on
            # input metadata we may need to synchronize any CPU operations that
            # might clobber enqueued forwards. (prevents CPU from running too
            # far ahead if needed)
            model_input.wait_previous_step()
            model_input = self._advance_step(
                model_input, model_input.cached_outputs[-1].sampler_output)

            # frozen_model_input may have been updated
            frozen_model_input = model_input.frozen_model_input
            assert frozen_model_input is not None

        if model_input.base_output_proc_callback is None:
            assert frozen_model_input is not None
            model_input.base_output_proc_callback = \
                frozen_model_input.async_callback

        if frozen_model_input.async_callback is not None:
            assert model_input.base_output_proc_callback is not None
            # Wrap the original callback so that cached outputs are
            # pythonized before it runs.
            async_callback = functools.partial(
                self._async_process_outputs,
                model_input=model_input,
                output_proc_callback=model_input.base_output_proc_callback)

            model_input.frozen_model_input = dataclasses.replace(  # type: ignore
                model_input.frozen_model_input,
                async_callback=async_callback)
            # Update the local instance
            frozen_model_input = model_input.frozen_model_input
            assert frozen_model_input is not None

        # Execute the model
        output = self._base_model_runner.execute_model(frozen_model_input,
                                                       kv_caches,
                                                       intermediate_tensors,
                                                       num_steps=1)

        # record the event for the current step so that the next step can sync
        model_input.record_step_event(stream)

        if get_pp_group().is_last_rank and self.is_driver_worker:
            assert isinstance(output, list)
            assert len(
                output
            ) == 1, "MultiStepModelRunner requires single-step base_models"

            # event for the pythonization so that we only pythonize if the
            # tensors are ready. May be able to be combined with the step event
            output_ready_event = torch.npu.Event()
            output_ready_event.record(stream)
            if self.parallel_config.pipeline_parallel_size > 1:
                output[0].sampled_token_ids_cpu = output[
                    0].sampled_token_ids.cpu()
            model_input.cached_outputs.append(
                NPUModelOutput(output[0], output_ready_event,
                               output[0].sampled_token_ids, False,
                               output[0].logprobs, self.pythonization_cache))

            # These NPU tensors are not required by multi-step;
            # erase them to ensure they are not pythonized or
            # transferred to CPU
            output[0].sampled_token_ids = None
            output[0].sampled_token_probs = None
            output[0].logprobs = None

            # Pythonize the output if CPU is ahead and the previous step is
            # ready.
            if frozen_model_input.async_callback is None:
                for model_output in model_input.cached_outputs:
                    model_output.maybe_pythonize(model_input,
                                                 self._copy_stream,
                                                 self.pinned_sampled_token_ids)

        model_input.current_step += 1

        if not get_pp_group().is_last_rank:
            # Should be IntermediateTensors
            assert isinstance(output, IntermediateTensors)
            return output
        if not self.is_driver_worker:
            return []

        # Pythonize the output and block if needed since it is the last step
        if model_input.is_last_step:
            outputs = self._final_process_outputs(
                model_input, model_input.base_output_proc_callback)
            if self.pythonization_cache:
                self.pythonization_cache.reset()
            return outputs

        # should be [SamplerOutput]
        return output

    def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
                                  num_seqs: Optional[int], num_queries: int):
        """Assert that the sampling metadata describes decode-only queries.

        This method only performs assertions; it does not modify
        ``sampling_metadata``. ``num_seqs`` is accepted but not used.
        """
        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (
            num_queries, )
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

        # Verify that all sequences are decodes
        for i, seq_group in enumerate(sampling_metadata.seq_groups):
            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple
            assert seq_group.seq_len is None  # Decode
            assert seq_group.query_len is None  # Decode

    def _advance_step(self, model_input: StatefulModelInputForNPU,
                      out: SamplerOutput) -> StatefulModelInputForNPU:
        """Advance the model input to the next step.

        The frozen model input is advanced if required and the attention
        metadata is then advanced using the sampled token ids of the most
        recently cached output.

        Args:
            model_input: stateful input to advance (mutated in place).
            out: sampler output of the previous step; not read here.

        Returns:
            The same ``model_input`` object.
        """
        model_input.maybe_advance_frozen_model_input(self.device,
                                                     self.pin_memory)
        frozen_model_input = model_input.frozen_model_input
        assert frozen_model_input is not None
        assert frozen_model_input.input_tokens is not None
        assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs

        attn_metadata = frozen_model_input.attn_metadata
        assert attn_metadata is not None

        sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
        num_seqs = model_input.num_seqs
        num_queries = model_input.num_queries

        turn_prefills_into_decodes: bool = model_input.current_step == 1 and \
            model_input.num_single_step_prefills != 0
        attn_metadata.advance_step(
            frozen_model_input,
            sampled_token_ids,
            self.block_size,
            num_seqs,
            num_queries,
            turn_prefills_into_decodes=turn_prefills_into_decodes)

        return model_input

    def load_model(self) -> None:
        """Load the model via the base runner and mirror its memory usage."""
        self._base_model_runner.load_model()
        self.model_memory_usage = self._base_model_runner.model_memory_usage

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Delegate to the base model runner."""
        return self._base_model_runner.save_sharded_state(
            path, pattern, max_size)

    def save_tensorized_model(self,
                              tensorizer_config: TensorizerConfig) -> None:
        """Delegate to the base model runner."""
        return self._base_model_runner.save_tensorized_model(tensorizer_config)

    def profile_run(self) -> None:
        """Delegate to the base model runner."""
        return self._base_model_runner.profile_run()

    def remove_all_loras(self):
        """Delegate to the base model runner."""
        return self._base_model_runner.remove_all_loras()

    def capture_model(self, kv_caches: List[List]) -> None:
        """Delegate to the base model runner."""
        return self._base_model_runner.capture_model(kv_caches)

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the base model runner."""
        return self._base_model_runner.vocab_size
|
|
|
|
|
|
# (prompt logprobs, sample logprobs) pair returned by deferred pythonization.
DeferredLogprobsReturnType = Tuple[
    Optional[List[Optional[PromptLogprobs]]],
    Optional[List[SampleLogprobs]],
]


def deferred_pythonize_logprobs(
    output: SamplerOutput,
    sampling_metadata: SamplingMetadata,
    logprobs_tensor: Optional[torch.Tensor],
) -> DeferredLogprobsReturnType:
    """Run the logprob pythonization that was deferred during sampling.

    Two steps are performed:

    1. The deferred sample-result arguments stored on ``output`` are turned
       into pythonized sample results, and the stored arguments are then
       erased from ``output``.
    2. ``logprobs_tensor`` is converted into prompt and sample logprobs
       using the sample results from step 1.

    These deferred computations are not required for single-step scheduling
    or the `profile_run()` phase of multi-step scheduling.

    Args:
        output: sampler output (under deferred Pythonization); its
            ``deferred_sample_results_args`` is set to ``None``.
        sampling_metadata: sampling metadata for the step.
        logprobs_tensor: logprobs tensor computed during sampling.

    Returns:
        prompt_logprobs, sample_logprobs
    """
    # Pythonize the deferred sample result first ...
    deferred_args = output.deferred_sample_results_args
    sample_results = get_pythonized_sample_results(deferred_args)

    # ... then drop the deferred arguments so they are never pythonized or
    # transferred to CPU again.
    output.deferred_sample_results_args = None

    # Deferred pythonization of the logprobs themselves.
    prompt_logprobs, sample_logprobs = get_logprobs(logprobs_tensor,
                                                    sampling_metadata,
                                                    sample_results)

    # One entry per sequence group is expected in both lists.
    group_count = len(sampling_metadata.seq_groups)
    assert len(prompt_logprobs) == group_count
    assert len(sample_logprobs) == group_count

    return prompt_logprobs, sample_logprobs
|
|
|
|
|
|
def _pythonize_sampler_output(
    model_input: StatefulModelInputForNPU,
    output: SamplerOutput,
    pinned_sampled_token_buffer: torch.Tensor,
    sampled_token_ids: torch.Tensor,
    logprobs_tensor: Optional[torch.Tensor],
    cache: Optional[PythonizationCache],
) -> None:
    """ This function is only called when the output tensors are ready.
    See :class:`ModelOutput`.

    Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
    adding a Pythonized output data structure
    (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`.

    Args:
      model_input: stateful model input whose frozen input carries the
                   sampling metadata
      output: sampler output; `output.outputs` must be empty on entry
      pinned_sampled_token_buffer: CPU-side pinned memory
                                   (receives copy of
                                   NPU-side token buffer.)
      sampled_token_ids: NPU-side token buffer
      logprobs_tensor: NPU-side tensor containing
                       logprobs computed during sampling
      cache: optional pythonization cache; when given, output objects are
             obtained from it instead of being newly constructed
    """

    assert model_input.frozen_model_input is not None

    frozen_model_input = model_input.frozen_model_input
    assert frozen_model_input.sampling_metadata is not None
    sampling_metadata = frozen_model_input.sampling_metadata
    # samples generation should have been skipped
    assert not output.outputs

    pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]

    # We guarantee output tensors are ready, so it is safe to
    # pythonize the sampler output & obtain CPU-side logprobs.
    #
    # However we should check whether logprobs pythonization may
    # be skipped entirely, i.e. because no logprobs were requested
    # or pythonization was not deferred. To that end,
    #
    # * `prompt_logprobs_are_requested_for_prefill` signals that
    #   there are *any* prefill-phase requests which specify that
    #   prompt logprobs should be returned.
    #
    # * `any_logprobs_are_requested` signals that there are any
    #   requests which (1) specify that sample logprobs should be
    #   returned, or (2) are in the prefill phase AND specify that
    #   prompt logprobs should be returned.
    #
    # Later on, these flags cause adjustments to the pythonization
    # process to accommodate logprobs.

    seq_groups = sampling_metadata.seq_groups
    # Generator expressions let any() stop at the first match instead of
    # materializing a throwaway list.
    prompt_logprobs_are_requested_for_prefill = any(
        sg.sampling_params.prompt_logprobs is not None and sg.is_prompt
        for sg in seq_groups)
    any_logprobs_are_requested = (
        prompt_logprobs_are_requested_for_prefill
        or any(sg.sampling_params.logprobs is not None for sg in seq_groups))

    if prompt_logprobs_are_requested_for_prefill:
        # CPU NPU sync, after gathering *only* sampled tokens (since
        # requesting prompt logprobs leads `sampled_token_ids` to
        # include prompt token ids in addition to sampled token ids.)
        sample_idx_tensor = torch.tensor(
            [sdx for sg in seq_groups for sdx in sg.sample_indices])
        pinned_buffer = pinned_buffer.copy_(
            sampled_token_ids[sample_idx_tensor, :], non_blocking=False)
    else:
        # CPU NPU sync
        pinned_buffer = pinned_buffer.copy_(sampled_token_ids,
                                            non_blocking=False)

    # this will not block as the tensors are already on CPU
    samples_list = pinned_buffer.tolist()

    skip_sampler_cpu_output = (
        frozen_model_input.sampling_metadata.skip_sampler_cpu_output)

    # *Don't* skip logprobs pythonization *if*:
    # * Any requests require logprobs to be returned in this
    # iteration AND
    # * These requests are being scheduled in a fashion which
    # defers pythonization (i.e. multi-step scheduling.)
    do_pythonize_logprobs = (skip_sampler_cpu_output
                             and any_logprobs_are_requested)
    (
        prompt_logprobs,
        sample_logprobs,
    ) = (deferred_pythonize_logprobs(output, sampling_metadata,
                                     logprobs_tensor)
         if do_pythonize_logprobs else (None, None))

    for sgdx, (seq_group,
               sample_result) in enumerate(zip(seq_groups, samples_list)):
        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # If the feature combo become valid
        # (Check for Guided Decoding)
        if seq_group.sampling_params.logits_processors:
            assert len(seq_group.sampling_params.logits_processors) == 0, (
                "Logits Processors are not supported in multi-step decoding")

        if do_pythonize_logprobs:
            assert prompt_logprobs is not None
            assert sample_logprobs is not None

            # Utilize deferred pythonization results
            group_prompt_logprobs = prompt_logprobs[sgdx]
            group_sample_logprobs = sample_logprobs[sgdx]
        elif any_logprobs_are_requested:
            # profile_run: use already-computed logprobs
            # NOTE(review): `output.outputs` is asserted empty on entry and
            # only grows by one entry per iteration of this loop, so
            # `output.outputs[sgdx]` looks out of range here — confirm
            # whether this branch is reachable.
            group_prompt_logprobs = output.outputs[sgdx].prompt_logprobs
            group_sample_logprobs = [
                sample.logprobs for sample in output.outputs[sgdx].samples
            ]
        else:
            # No logprobs requested; these are never read below in that case.
            group_prompt_logprobs = None
            group_sample_logprobs = None

        seq_ids = seq_group.seq_ids
        next_token_ids = sample_result
        parent_ids = [0]
        seq_outputs: List[SequenceOutput]

        if cache is not None:
            # Reuse a cached group-output object and its samples list.
            completion_seq_group_output: CompletionSequenceGroupOutput = \
                cache.cached_completion_seq_group_output.get_object()
            completion_seq_group_output.samples.clear()
            seq_outputs = completion_seq_group_output.samples
        else:
            seq_outputs = []

        for tdx, (parent_id,
                  next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
            if cache is not None:
                seq_output: SequenceOutput = cache.cached_seq_output.get_object(
                )
                seq_output.parent_seq_id = seq_ids[parent_id]
                seq_output.output_token = next_token_id

                if any_logprobs_are_requested:
                    seq_output.logprobs = group_sample_logprobs[tdx]
                else:
                    # Recycle the Logprob object already stored in the cached
                    # output as a placeholder keyed by the new token id.
                    logprobs = next(iter(seq_output.logprobs.values()))
                    seq_output.logprobs.clear()

                    logprobs.logprob = float('inf')
                    logprobs.rank = None
                    logprobs.decoded_token = None

                    seq_output.logprobs[next_token_id] = logprobs

                seq_outputs.append(seq_output)

            else:
                seq_outputs.append(
                    SequenceOutput(seq_ids[parent_id], next_token_id,
                                   (group_sample_logprobs[tdx]
                                    if any_logprobs_are_requested else {
                                        next_token_id:
                                        Logprob(logprob=float('inf'),
                                                rank=None,
                                                decoded_token=None)
                                    })))
        if cache is not None:
            completion_seq_group_output.prompt_logprobs = \
                group_prompt_logprobs if any_logprobs_are_requested else None
            output.outputs.append(completion_seq_group_output)
        else:
            output.outputs.append(
                CompletionSequenceGroupOutput(
                    seq_outputs, (group_prompt_logprobs
                                  if any_logprobs_are_requested else None)))

    assert len(output.outputs) > 0
|