84 lines
3.7 KiB
Python
84 lines
3.7 KiB
Python
import os
|
|
from typing import List, Optional, Set, Tuple
|
|
|
|
import torch
|
|
|
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
|
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
|
|
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
|
SpeculativeProposer)
|
|
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
|
from vllm.spec_decode.top1_proposer import Top1Proposer
|
|
from vllm.spec_decode.util import sampler_output_to_torch
|
|
from vllm.utils import async_tensor_h2d
|
|
from vllm.zero_overhead.utils import record_proposal_lens_list
|
|
|
|
class ZeroOverheadTop1Proposer(Top1Proposer):
    """Top1Proposer variant that records the per-sequence proposal lengths
    (via record_proposal_lens_list) and performs the host-to-device copy of
    the nonzero-proposal index list asynchronously."""

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.

        Args:
            batch_size: Total number of sequences in the batch, including
                those that skipped speculation.
            proposal_len: Number of speculative tokens proposed per sequence
                that ran speculation.
            maybe_sampler_output: Speculation results, or None when no
                sequence produced speculative tokens.
            proposal_lens: Per-sequence proposal lengths; only its length is
                read here (for the empty-proposal early return).
            nonzero_proposal_len_indices: Batch indices of the sequences that
                actually ran speculation.
            sampler_transposed: Layout flag forwarded to
                sampler_output_to_torch.

        Returns:
            Tuple of (proposal_tokens, proposal_probs, proposal_lens_tensor)
            covering every sequence in the batch. Skipped sequences get token
            id -1, all-zero probabilities, and proposal length 0.
        """
        if maybe_sampler_output is None:
            # No speculative tokens were produced, so the sampler output is
            # None. Return placeholder proposals; expand() yields broadcast
            # views, so no full-size tensors are materialized.
            proposal_tokens = torch.tensor(
                -1, dtype=torch.long,
                device=self._device).expand(batch_size, proposal_len)
            proposal_probs = torch.tensor(
                0, dtype=torch.float32,
                device=self._device).expand(batch_size, proposal_len,
                                            self._vocab_size)
            proposal_lens_tensor = torch.tensor(
                0, dtype=torch.long,
                device=self._device).expand(len(proposal_lens))
            return proposal_tokens, proposal_probs, proposal_lens_tensor

        sampler_output = maybe_sampler_output
        proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
            sampler_output, sampler_transposed)

        # Sequences that ran speculation get the full proposal_len; the rest
        # keep length 0.
        proposal_lens_list = [0] * batch_size
        for idx in nonzero_proposal_len_indices:
            proposal_lens_list[idx] = proposal_len
        record_proposal_lens_list(proposal_lens_list)

        # Copy the index list to the device without blocking the host
        # (pinned host memory + non_blocking copy). Kept as a separate name
        # instead of reassigning the parameter, which held a Python list.
        nonzero_indices_tensor = async_tensor_h2d(
            nonzero_proposal_len_indices, torch.int32, self._device, True)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. The proposal can be empty, e.g. [-1, -1, -1].
        entire_proposal_tokens = proposal_tokens.new_full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
        )
        entire_proposal_tokens[nonzero_indices_tensor] = proposal_tokens
        entire_proposal_probs = proposal_probs.new_zeros(
            batch_size,
            *proposal_probs.shape[1:],
        )
        entire_proposal_probs[nonzero_indices_tensor] = proposal_probs

        proposal_tokens, proposal_probs = (
            entire_proposal_tokens,
            entire_proposal_probs,
        )

        proposal_lens_tensor = async_tensor_h2d(proposal_lens_list,
                                                torch.long, self._device,
                                                True)

        return proposal_tokens, proposal_probs, proposal_lens_tensor