Files
2026-01-09 15:09:53 +08:00

137 lines
6.1 KiB
Python

import copy
import weakref
from typing import Dict, List, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.spec_decode.top1_proproser import ZeroOverheadTop1Proposer
from vllm.zero_overhead.utils import SpecStepKind, set_spec_step
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.worker.worker_base import DelegateWorkerBase
class ZeroOverheadMultiStepWorker(MultiStepWorker):
    """Multi-step draft worker used by the zero-overhead speculative path.

    Compared with :class:`MultiStepWorker`, this subclass:

    * installs a :class:`ZeroOverheadTop1Proposer` as ``self._proposer`` in
      ``init_device``;
    * publishes the current speculative step kind through ``set_spec_step``
      while running the CPU-prepared multi-step loop in ``sampler_output``;
    * contracts the expanded batch in ``_filter_model_output_zero_overhead``.
      It indexes the per-sequence Python objects with the host index list and
      the tensor fields with a device copy made via ``async_tensor_h2d``.
    """

    def init_device(self) -> None:
        """Initialise the wrapped worker's device, then build the proposer.

        The proposer is given a ``weakref.proxy`` to this worker rather than
        a strong reference, so the proposer alone does not keep the worker
        alive.
        """
        self.worker.init_device()
        self._proposer = ZeroOverheadTop1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the draft model forward pass ``sample_len`` times.

        The batch is first expanded for sequences that received a bonus
        token in the last step. The forward passes run on the expanded
        batch, and the outputs are then filtered back to the original
        sequences.

        Args:
            execute_model_req: The request to generate proposals for.
            sample_len: Number of forward passes to run.
            seq_ids_with_bonus_token_in_last_step: Ids of the sequences that
                received a bonus token in the previous step.

        Returns:
            A tuple ``(outputs, needs_transpose)``:

            * ``outputs`` is the list of filtered sampler outputs, one per
              model forward pass.
            * ``needs_transpose`` is always ``True``. It tells the later
              ``sampler_output_to_torch`` logic that the torch tensors in
              the sampler output must be transposed.
        """
        self._raise_if_unsupported(execute_model_req)

        # Expand the batch for sequences with a bonus token. Forward passes
        # run on the expanded batch; the response is filtered afterwards to
        # retain only the original sequences' outputs.
        expanded_request, indices_of_seq_with_bonus_tokens = \
            self._expand_execute_model_request(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)

        model_outputs: List[SamplerOutput] = []
        if current_platform.is_cuda_alike() and isinstance(
                self.model_runner, TP1DraftModelRunner
        ) and self.model_runner.supports_gpu_multi_step(expanded_request):
            # Run the draft model runner with multi-step prepare directly on
            # the GPU.
            expanded_request.num_steps = sample_len
            self.model_runner.set_indices_of_seq_with_bonus_tokens(
                indices_of_seq_with_bonus_tokens)
            model_outputs = self.execute_model(
                execute_model_req=expanded_request)
        else:
            # Run multi-step directly, preparing every step on the CPU.
            # TODO: Remove this branch once DraftModelRunner supports TP>1
            # and the other restrictions checked by DraftModelRunner's
            # supports_gpu_multi_step(..).
            #
            # The first forward pass runs under FIRST_PROPOSAL; the step kind
            # is switched to OTHER_PROPOSAL right after it, so every later
            # pass runs under OTHER_PROPOSAL.
            set_spec_step(SpecStepKind.FIRST_PROPOSAL)
            for _ in range(sample_len):
                step_output: List[SamplerOutput] = self.worker.execute_model(
                    execute_model_req=expanded_request)
                assert (len(step_output) == 1
                        ), "composing multistep workers not supported"
                model_output = step_output[0]
                set_spec_step(SpecStepKind.OTHER_PROPOSAL)

                self._append_new_tokens(
                    model_output, expanded_request.seq_group_metadata_list,
                    indices_of_seq_with_bonus_tokens)
                model_outputs.append(model_output)

        # End of the proposal phase: switch the step kind to SCORE_DECODE.
        # This is done on both paths so the step kind never stays at a
        # proposal value after this call.
        set_spec_step(SpecStepKind.SCORE_DECODE)

        filtered_model_outputs = self._filter_model_output_zero_overhead(
            model_outputs, indices_of_seq_with_bonus_tokens)
        return filtered_model_outputs, True

    def _filter_model_output_zero_overhead(
            self, expanded_batch_outputs: List[SamplerOutput],
            output_indices_to_retain: List[int]) -> List[SamplerOutput]:
        """Contract expanded-batch outputs to the specified sequences.

        Args:
            expanded_batch_outputs: Sampler outputs produced by the model
                over the expanded batch, one per forward pass.
            output_indices_to_retain: Host-side indices, into the expanded
                batch, of the outputs to keep.

        Returns:
            One ``SamplerOutput`` per entry of ``expanded_batch_outputs``,
            restricted to ``output_indices_to_retain``. Tensor fields that
            are ``None`` stay ``None``, and an empty ``outputs`` list stays
            empty.
        """
        if not expanded_batch_outputs:
            # Nothing to filter; skip the host-to-device index copy.
            return []

        # Local import keeps this fix self-contained; vllm.utils is already a
        # dependency of this module (async_tensor_h2d).
        from vllm.utils import is_pin_memory_available

        # Copy the indices to the device once per call and reuse them for
        # every tensor field of every output. Pinned host memory is only
        # requested where the platform reports it is available; hard-coding
        # True fails on platforms without pinning support.
        device_indices = async_tensor_h2d(output_indices_to_retain,
                                          torch.int32, self.device,
                                          is_pin_memory_available())
        return [
            SamplerOutput(
                # Per-sequence Python objects are picked with the host list,
                # so no device tensor is iterated on the host.
                outputs=[
                    batch_output.outputs[i] for i in output_indices_to_retain
                ] if len(batch_output.outputs) > 0 else [],
                sampled_token_probs=(
                    batch_output.sampled_token_probs[device_indices]
                    if batch_output.sampled_token_probs is not None else None),
                logprobs=(batch_output.logprobs[device_indices]
                          if batch_output.logprobs is not None else None),
                sampled_token_ids=(
                    batch_output.sampled_token_ids[device_indices]
                    if batch_output.sampled_token_ids is not None else None),
            ) for batch_output in expanded_batch_outputs
        ]