init src 0.9.2
This commit is contained in:
84
vllm/zero_overhead/spec_decode/top1_proproser.py
Normal file
84
vllm/zero_overhead/spec_decode/top1_proproser.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import os
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.spec_decode.util import sampler_output_to_torch
|
||||
from vllm.utils import async_tensor_h2d
|
||||
from vllm.zero_overhead.utils import record_proposal_lens_list
|
||||
|
||||
class ZeroOverheadTop1Proposer(Top1Proposer):
|
||||
|
||||
def _merge_outputs(
|
||||
self,
|
||||
batch_size: int,
|
||||
proposal_len: int,
|
||||
maybe_sampler_output: Optional[List[SamplerOutput]],
|
||||
proposal_lens: List[int],
|
||||
nonzero_proposal_len_indices: List[int],
|
||||
sampler_transposed: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""After speculations are produced, merge the speculation results with
|
||||
the skipped sequences.
|
||||
"""
|
||||
if maybe_sampler_output is None:
|
||||
# If no speculative tokens, the sampler output will be None.
|
||||
# In this case we return empty proposals.
|
||||
proposal_tokens = torch.tensor(-1,
|
||||
dtype=torch.long,
|
||||
device=self._device).expand(
|
||||
batch_size, proposal_len)
|
||||
proposal_probs = torch.tensor(0,
|
||||
dtype=torch.float32,
|
||||
device=self._device).expand(
|
||||
batch_size, proposal_len,
|
||||
self._vocab_size)
|
||||
proposal_lens_tensor = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device=self._device).expand(
|
||||
len(proposal_lens))
|
||||
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
|
||||
sampler_output = maybe_sampler_output
|
||||
proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
|
||||
sampler_output, sampler_transposed)
|
||||
|
||||
proposal_lens_list = [0 for i in range(batch_size)]
|
||||
for indices in nonzero_proposal_len_indices:
|
||||
proposal_lens_list[indices] = proposal_len
|
||||
record_proposal_lens_list(proposal_lens_list)
|
||||
|
||||
nonzero_proposal_len_indices = async_tensor_h2d(nonzero_proposal_len_indices, torch.int32,
|
||||
self._device,
|
||||
True)
|
||||
|
||||
# Now, reformat the output GPU tensors such that each sequence has
|
||||
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
|
||||
|
||||
entire_proposal_tokens = proposal_tokens.new_full(
|
||||
size=(batch_size, *proposal_tokens.shape[1:]),
|
||||
fill_value=-1,
|
||||
)
|
||||
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
|
||||
entire_proposal_probs = proposal_probs.new_zeros(
|
||||
batch_size,
|
||||
*proposal_probs.shape[1:],
|
||||
)
|
||||
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
|
||||
|
||||
proposal_tokens, proposal_probs = (
|
||||
entire_proposal_tokens,
|
||||
entire_proposal_probs,
|
||||
)
|
||||
|
||||
proposal_lens_tensor = async_tensor_h2d(proposal_lens_list, torch.long,
|
||||
self._device,
|
||||
True)
|
||||
|
||||
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
Reference in New Issue
Block a user