import copy
import weakref
from typing import Dict, List, Set, Tuple

import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                           SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.spec_decode.top1_proproser import (
    ZeroOverheadTop1Proposer)
from vllm.zero_overhead.utils import SpecStepKind, set_spec_step

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.worker.worker_base import DelegateWorkerBase


class ZeroOverheadMultiStepWorker(MultiStepWorker):

    def init_device(self) -> None:
        self.worker.init_device()

        self._proposer = ZeroOverheadTop1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass sample_len times.

        Returns the list of sampler outputs, one per model forward pass,
        along with an indicator of whether the torch tensors in the sampler
        output need to be transposed in the later sampler_output_to_torch
        logic. For the multi-step worker, this indicator is always True.
        """
        self._raise_if_unsupported(execute_model_req)

        # Expand the batch for sequences with a bonus token.
        # Perform a forward pass on the expanded batch and filter the
        # response to retain only the original sequences' responses.
        expanded_request, indices_of_seq_with_bonus_tokens = \
            self._expand_execute_model_request(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)

        # Run the model sample_len times.
        model_outputs: List[SamplerOutput] = []
        if current_platform.is_cuda_alike() and isinstance(
                self.model_runner, TP1DraftModelRunner
        ) and self.model_runner.supports_gpu_multi_step(expanded_request):
            # Run the draft_model_runner with multi-step preparation
            # directly on the GPU.
            expanded_request.num_steps = sample_len
            self.model_runner.set_indices_of_seq_with_bonus_tokens(
                indices_of_seq_with_bonus_tokens)
            model_outputs = self.execute_model(
                execute_model_req=expanded_request)
        else:
            # Run multi-step directly, with every step prepared on the CPU.
            # TODO: Remove this branch once DraftModelRunner supports TP>1
            # and the other restrictions that are part of DraftModelRunner's
            # supports_gpu_multi_step(..).
            # The first loop iteration is the initial proposal step; every
            # subsequent iteration is marked as a follow-up proposal step.
            set_spec_step(SpecStepKind.FIRST_PROPOSAL)
            for _ in range(sample_len):
                model_output: List[SamplerOutput] = self.worker.execute_model(
                    execute_model_req=expanded_request)
                assert (len(model_output) == 1
                        ), "composing multistep workers not supported"
                model_output = model_output[0]
                set_spec_step(SpecStepKind.OTHER_PROPOSAL)

                self._append_new_tokens(
                    model_output, expanded_request.seq_group_metadata_list,
                    indices_of_seq_with_bonus_tokens)
                model_outputs.append(model_output)

        set_spec_step(SpecStepKind.SCORE_DECODE)
        filtered_model_outputs = self._filter_model_output_zero_overhead(
            model_outputs, indices_of_seq_with_bonus_tokens)
        return filtered_model_outputs, True

    def _filter_model_output_zero_overhead(
            self, expanded_batch_outputs: List[SamplerOutput],
            output_indices_to_retain: List[int]) -> List[SamplerOutput]:
        """
        Filters the model output to include only the specified sequence
        outputs. This method contracts the expanded batch output from the
        model to retain the outputs of only those sequences indicated by the
        provided indices.

        Args:
            expanded_batch_outputs (List[SamplerOutput]): The expanded output
                batch from the model.
            output_indices_to_retain (List[int]): Indices of the model
                outputs to retain.

        Returns:
            List[SamplerOutput]: A list containing the filtered model
            outputs for the specified indices.
        """
        # Copy the retained indices to the device asynchronously so the
        # tensor fields can be gathered without a host-side synchronization.
        indices_to_retain_tensor = async_tensor_h2d(output_indices_to_retain,
                                                    torch.int32, self.device,
                                                    True)
        return [
            SamplerOutput(
                outputs=[
                    expanded_batch_output.outputs[i]
                    for i in output_indices_to_retain
                ] if len(expanded_batch_output.outputs) > 0 else [],
                sampled_token_probs=(
                    expanded_batch_output.
                    sampled_token_probs[indices_to_retain_tensor]
                    if expanded_batch_output.sampled_token_probs is not None
                    else None),
                logprobs=(
                    expanded_batch_output.logprobs[indices_to_retain_tensor]
                    if expanded_batch_output.logprobs is not None else None),
                sampled_token_ids=(
                    expanded_batch_output.
                    sampled_token_ids[indices_to_retain_tensor]
                    if expanded_batch_output.sampled_token_ids is not None
                    else None))
            for expanded_batch_output in expanded_batch_outputs
        ]
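
# Illustrative sketch only: the helper below is not used by the worker. It
# shows, on a toy tensor, the device-side gather that
# _filter_model_output_zero_overhead relies on: an int32 index tensor (the
# same dtype uploaded via async_tensor_h2d above) selects the rows of the
# expanded-batch tensors that belong to the retained sequences. The shapes
# and names here are hypothetical.
def _filter_indexing_sketch() -> None:
    # Pretend there are 4 expanded rows with a vocab of 8; keep rows 0 and 2.
    sampled_token_probs = torch.rand(4, 8)
    indices_to_retain = torch.tensor([0, 2], dtype=torch.int32)
    filtered = sampled_token_probs[indices_to_retain]
    assert filtered.shape == (2, 8)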