from typing import List, Optional, Tuple

import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.utils import record_proposal_lens_list


class ZeroOverheadTop1Proposer(Top1Proposer):

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results
        with the skipped sequences.
        """
        if maybe_sampler_output is None:
            # If there are no speculative tokens, the sampler output is
            # None. In that case, return empty proposals.
            proposal_tokens = torch.tensor(
                -1, dtype=torch.long,
                device=self._device).expand(batch_size, proposal_len)
            proposal_probs = torch.tensor(
                0, dtype=torch.float32,
                device=self._device).expand(batch_size, proposal_len,
                                            self._vocab_size)
            proposal_lens_tensor = torch.tensor(
                0, dtype=torch.long,
                device=self._device).expand(len(proposal_lens))
            return proposal_tokens, proposal_probs, proposal_lens_tensor

        sampler_output = maybe_sampler_output
        proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
            sampler_output, sampler_transposed)

        # Compute the per-sequence proposal lengths on the host and record
        # them, so they are available without a device synchronization.
        proposal_lens_list = [0] * batch_size
        for idx in nonzero_proposal_len_indices:
            proposal_lens_list[idx] = proposal_len
        record_proposal_lens_list(proposal_lens_list)

        # Copy the indices to the device asynchronously (via pinned host
        # memory) to avoid a blocking host-to-device transfer.
        nonzero_indices_tensor = async_tensor_h2d(
            nonzero_proposal_len_indices, torch.int32, self._device, True)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. The proposal can be empty, e.g. [-1, -1, -1].
        entire_proposal_tokens = proposal_tokens.new_full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
        )
        entire_proposal_tokens[nonzero_indices_tensor] = proposal_tokens
        entire_proposal_probs = proposal_probs.new_zeros(
            batch_size,
            *proposal_probs.shape[1:],
        )
        entire_proposal_probs[nonzero_indices_tensor] = proposal_probs

        proposal_tokens, proposal_probs = (
            entire_proposal_tokens,
            entire_proposal_probs,
        )

        # Transfer the proposal lengths asynchronously as well, rather than
        # building the tensor on-device with a synchronizing scatter.
        proposal_lens_tensor = async_tensor_h2d(proposal_lens_list,
                                                torch.long, self._device,
                                                True)
        return proposal_tokens, proposal_probs, proposal_lens_tensor
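
# ---------------------------------------------------------------------------
# Minimal standalone sketch (not part of vLLM): demonstrates the merge /
# padding step performed by _merge_outputs above, using plain CPU tensors in
# place of real sampler output. All shapes, sizes, and values here are
# illustrative assumptions, not vLLM defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, proposal_len, vocab_size = 4, 3, 8
    nonzero_proposal_len_indices = [0, 2]

    # Proposals exist only for the two sequences with nonzero proposal
    # length, mirroring the (num_nonzero, proposal_len[, vocab]) shapes
    # produced by sampler_output_to_torch.
    proposal_tokens = torch.randint(0, vocab_size, (2, proposal_len))
    proposal_probs = torch.rand(2, proposal_len, vocab_size)

    # Scatter back into full-batch tensors: skipped rows keep -1 tokens and
    # all-zero probabilities, as in the method above. torch.long is used for
    # the index here for portability across PyTorch versions.
    idx = torch.tensor(nonzero_proposal_len_indices, dtype=torch.long)
    entire_tokens = proposal_tokens.new_full((batch_size, proposal_len), -1)
    entire_tokens[idx] = proposal_tokens
    entire_probs = proposal_probs.new_zeros(batch_size, proposal_len,
                                            vocab_size)
    entire_probs[idx] = proposal_probs

    proposal_lens_list = [
        proposal_len if i in nonzero_proposal_len_indices else 0
        for i in range(batch_size)
    ]
    print(entire_tokens)       # rows 1 and 3 are all -1 (empty proposals)
    print(proposal_lens_list)  # [3, 0, 3, 0]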