from array import array
from itertools import chain, count
from typing import Iterator, List, Optional, Tuple

import numpy as np
import torch

from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
                           ExecuteModelRequest, SequenceData,
                           SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.utils import (get_proposal_lens_list,
                                      record_proposal_token_ids)

SeqId = int
TargetSeqId = int
TokenId = int

DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()


class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
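    """BatchExpansionTop1Scorer variant for zero-overhead speculative
    decoding: the real proposal token ids are recorded on-device via
    record_proposal_token_ids(), and scoring runs on zero-valued
    placeholders, so only the shape (never the values) of the proposal
    tensor is read on the host before batch expansion.
    """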

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences.
        The target sequences have the unique continuations to be scored and
        a unique sequence ID that is different from all input sequence ids.

        If a speculative sequence length would exceed the max model length,
        then no speculation is produced for that sequence.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.
        Returns:
            SpeculativeScores: The scores of each speculative token, along
                with which sequences were ignored during scoring.
        """

        proposal_lens_list = get_proposal_lens_list()
        record_proposal_token_ids(proposals.proposal_token_ids)
        # Placeholder tokens: only the shape of the proposal tensor is read
        # here; the real token ids were recorded above, so scoring can start
        # without pulling proposal values off the GPU.
        proposal_token_ids_list = np.zeros(
            proposals.proposal_token_ids.shape, dtype=int).tolist()

        # Filter the list to ignore invalid proposals.
        proposal_token_ids_list_without_skips = [
            proposals for proposals in proposal_token_ids_list
            if VLLM_INVALID_TOKEN_ID not in proposals
        ]
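        # Note on the filter's intent: upstream, proposals containing the
        # sentinel VLLM_INVALID_TOKEN_ID (-1) are skipped, e.g.
        # [[5, 9, 2], [-1, -1, -1]] -> [[5, 9, 2]]. The zero placeholders
        # never contain the sentinel, so nothing is dropped here.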

        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))
        assert len(target_sampler_output) == 1, "expected single-step output"
        target_sampler_output = target_sampler_output[0]

        if not non_spec_indices:
            # All sequence groups in batch have spec decoding enabled
            return self._contract_batch_all_spec(
                target_sampler_output=target_sampler_output,
                proposals=proposals,
            )
        else:
            # Batch has a mix of spec decode enabled and disabled seq groups
            return self._contract_batch(
                execute_model_req.seq_group_metadata_list,
                target_sampler_output=target_sampler_output,
                proposals=proposals,
                num_scoring_tokens=num_scoring_tokens,
                non_spec_indices=non_spec_indices,
                spec_indices=spec_indices,
                k=execute_model_req.num_lookahead_slots,
            )

    def _contract_non_speculative(
            self, scores: SpeculativeScores,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
            has_prompt_log: bool) -> SpeculativeScores:
        """
        Augment the input `scores` with the outputs of non-speculative
        requests. These include decode requests with speculation turned off,
        as well as prefill requests when `enable_chunked_prefill` is set.
        For the latter, prefills are further separated into terminal and
        non-terminal chunks (from which no token is sampled).
        """
        if not non_spec_indices:
            return scores

        if has_prompt_log:
            # When prompt_logprobs is enabled, prefills yield output token
            # (and respective prob) in the last entry (prompt|out):
            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
            # With chunked prefill, non-terminal chunks have -1 on each
            # position: they're still picked, but they're discarded later.
            seq_meta = seq_group_metadata_list
            nospec_sizes = torch.tensor([
                seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
                for i in non_spec_indices
            ])
            nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
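            # Illustration: chunk sizes [3, 1, 2] give cumsum [3, 4, 6] and,
            # after the in-place -1, indices [2, 3, 5]: the last (sampled)
            # position of each request's output span.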
        else:
            # In this case only sampled tokens are returned, select all.
            nospec_sampled_token_idxs = list(
                range(len(non_spec_outputs.token_ids)))

        nospec_sampled_token_idxs = async_tensor_h2d(
            nospec_sampled_token_idxs, torch.int32, self._device, True)
        non_spec_indices = async_tensor_h2d(non_spec_indices, torch.int32,
                                            self._device, True)
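        # Assumption on async_tensor_h2d(data, dtype, device, pin_memory):
        # the final True pins the host buffer and the copy is issued
        # non-blocking, so the host is not stalled before the writes below.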
        scores.token_ids[non_spec_indices, :1] = \
            non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
        scores.probs[non_spec_indices, :1, :] = \
            non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
        scores.logprobs[non_spec_indices, :1, :] = \
            non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
        if scores.hidden_states is not None:
            assert non_spec_outputs.hidden_states is not None
            scores.hidden_states[non_spec_indices, :1, :] = \
                non_spec_outputs.hidden_states[
                    nospec_sampled_token_idxs].unsqueeze(1)
        return scores