[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
0
vllm/spec_decode/__init__.py
Normal file
0
vllm/spec_decode/__init__.py
Normal file
506
vllm/spec_decode/batch_expansion.py
Normal file
506
vllm/spec_decode/batch_expansion.py
Normal file
@@ -0,0 +1,506 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from array import array
|
||||
from itertools import chain, count
|
||||
from typing import Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
ExecuteModelRequest, SequenceData,
|
||||
SequenceGroupMetadata, get_all_seq_ids)
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
|
||||
|
||||
SeqId = int
|
||||
TargetSeqId = int
|
||||
TokenId = int
|
||||
|
||||
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
|
||||
|
||||
|
||||
class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
"""Implements a speculative scorer that uses batch expansion to get
|
||||
probabilities of speculative tokens according to the scoring model.
|
||||
|
||||
Batch expansion converts a list of sequences and multiple query positions
|
||||
to a new batch of sequences, each with a single query position. This allows
|
||||
for MQA-like scoring in speculative decoding without requiring an MQA
|
||||
kernel.
|
||||
|
||||
It is strictly less efficient than MQA scoring.
|
||||
|
||||
It only supports scoring the top1 proposal tokens of the proposer, instead
|
||||
of topk/tree.
|
||||
"""
|
||||
|
||||
    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences. The
        target sequences have the unique continuations to be scored and a
        unique sequence ID that is different from all input sequence ids.

        If a speculative sequence length would exceed the max model length, then
        no speculation is produced for that sequence.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.
        Returns:
            SpeculativeScores: The scores of each speculative token, along with
                which sequences were ignored during scoring.
        """

        # TODO(cade) perform this on GPU to remove blocking call.
        proposal_lens_list = proposals.proposal_lens.tolist()
        proposal_token_ids_list = proposals.proposal_token_ids.tolist()

        # Filter the list to ignore invalid proposals: any row containing
        # VLLM_INVALID_TOKEN_ID is dropped from scoring.
        proposal_token_ids_list_without_skips = [
            proposals for proposals in proposal_token_ids_list
            if VLLM_INVALID_TOKEN_ID not in proposals
        ]

        # Expand the batch: one target sequence per (sequence, prefix) pair.
        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )

        # Run the scorer model once over the expanded batch.
        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))
        assert len(target_sampler_output) == 1, "expected single-step output"
        target_sampler_output = target_sampler_output[0]

        if not non_spec_indices:
            # All sequence groups in batch have spec decoding enabled
            return self._contract_batch_all_spec(
                target_sampler_output=target_sampler_output,
                proposals=proposals,
            )
        else:
            # Batch has a mix of spec decode enabled and disabled seq groups
            return self._contract_batch(
                execute_model_req.seq_group_metadata_list,
                target_sampler_output=target_sampler_output,
                proposals=proposals,
                num_scoring_tokens=num_scoring_tokens,
                non_spec_indices=non_spec_indices,
                spec_indices=spec_indices,
                k=execute_model_req.num_lookahead_slots,
            )
|
||||
|
||||
    def _expand_batch(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids_list: List[List[TokenId]],
        proposal_lens_list: List[int],
    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
        """Given the input sequences and potentially multiple corresponding
        proposal tokens, create a new batch where each sequence has a single
        query token.

        Returns:
            A tuple (spec_indices, non_spec_indices,
            target_seq_group_metadata_list, num_scoring_tokens). The index
            lists refer to positions in the original (contracted) batch.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
            split_batch_by_proposal_len(
                seq_group_metadata_list, proposal_lens_list)

        spec_expanded_seqs = self._create_scoring_model_input(
            seq_group_metadata_list=spec_seqs,
            proposal_token_ids=proposal_token_ids_list,
            # NOTE: We determine the seq ids in the expanded batch using the
            # full seq_group_metadata_list, instead of only spec_seqs.
            target_seq_ids_iter=self._create_target_seq_id_iterator(
                seq_ids=get_all_seq_ids(seq_group_metadata_list)),
        )

        # Each expanded sequence scores exactly one query token.
        num_scoring_tokens = len(spec_expanded_seqs)
        # Batch speculative and non-speculative (e.g. chunked prefill) requests
        # but make sure order is prefill|decode due to backend requirement.
        target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs

        return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
                num_scoring_tokens)
|
||||
|
||||
    def _contract_non_speculative(
            self, scores: SpeculativeScores,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
            has_prompt_log: bool) -> SpeculativeScores:
        """
        Augment input `scores` with non-speculative requests outputs.
        This includes decode requests with speculation turned off, as well
        as prefill requests when `enable_chunked_prefill` is set.
        For the latter, prefills are further separated into terminal and
        non-terminal chunks (from which no token is sampled).

        NOTE(review): mutates `scores` in place and returns it.
        """
        if not non_spec_indices:
            return scores

        if has_prompt_log:
            # When prompt_logprobs is enabled, prefills yield output token
            # (and respective prob) in the last entry (prompt|out):
            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
            # With chunked prefill, non-terminal chunks have -1 on each
            # position: they're still picked, but they're discarded later.
            seq_meta = seq_group_metadata_list
            # Per-request output size: full chunk for prompts, 1 for decodes.
            nospec_sizes = torch.tensor([
                seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
                for i in non_spec_indices
            ])
            # Cumulative sum - 1 gives the index of each request's last
            # (i.e. sampled) token inside the flattened output.
            nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
        else:
            # In this case only sampled tokens are returned, select all.
            nospec_sampled_token_idxs = list(
                range(len(non_spec_outputs.token_ids)))

        # Write the single sampled token (position 0 of k+1) for each
        # non-speculative request back into the contracted tensors.
        scores.token_ids[non_spec_indices, :1] = \
            non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
        scores.probs[non_spec_indices, :1, :] = \
            non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
        scores.logprobs[non_spec_indices, :1, :] = \
            non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
        if scores.hidden_states is not None:
            assert non_spec_outputs.hidden_states is not None
            scores.hidden_states[non_spec_indices, :1, :] = \
                non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1)
        return scores
|
||||
|
||||
    def _contract_batch(
            self,
            contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
            target_sampler_output: SamplerOutput,
            proposals: SpeculativeProposals, num_scoring_tokens: int,
            non_spec_indices: List[int], spec_indices: List[int],
            k: int) -> SpeculativeScores:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

        contracted_bs is the original batch size, and the batch size that the
        target_sampler_output will be contracted to.
        """
        contracted_bs = len(contracted_seq_group_metadata_list)
        # Split the flat sampler output into speculative vs non-speculative
        # slices (non-spec entries come first: prefill|decode order).
        (target_token_ids, target_probs, target_logprobs, target_hidden_states,
         non_spec_target_token_ids, non_spec_target_probs,
         non_spec_target_logprobs,
         non_spec_target_hidden_states) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # Map distinct sequences used to score each token
        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
        expanded_batch_size, k = proposals.proposal_token_ids.shape

        # The number of tokens in the expanded batch used for speculation is
        # equal to the total expanded batch size minus the number of samples for
        # non-speculative sequences, prefill chunks with no out tokens included
        non_spec_expanded_bs = len(non_spec_indices)
        spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

        target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
        target_probs = target_probs.reshape(*target_token_ids.shape,
                                            self._vocab_size)
        target_logprobs = target_logprobs.reshape(target_probs.shape)

        if target_hidden_states is not None:
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        # Allocate full-size output tensors; slots for sequences without
        # speculation keep the fill values (-1 tokens / zero probs / -inf
        # logprobs) until _contract_non_speculative overwrites position 0.
        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                               fill_value=-1)
        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                fill_value=-float("inf"))

        if target_sampler_output.hidden_states is not None:
            all_hidden_states = target_hidden_states.new_zeros(
                size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
        else:
            all_hidden_states = None

        has_prompt_log = any((sg.sampling_params.prompt_logprobs
                              and sg.sampling_params.prompt_logprobs > 0)
                             for sg in contracted_seq_group_metadata_list)
        # When prompt logprobs is enabled, lens of returned tensors go from
        # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
        # We adjust stride accordingly to get the generated tokens and
        # their probs, but pass on prompt_logprobs as is.
        prompt_logprobs = None
        if (not self._scorer_worker.model_runner.disable_logprobs\
            and has_prompt_log):
            prompt_logprobs = [
                o.prompt_logprobs for o in target_sampler_output.outputs
            ]
        elif not has_prompt_log:
            # When prompt logprobs are not to be returned,
            # we can ignore non-terminal chunks (no out token).
            non_spec_indices = [
                idx for idx in non_spec_indices
                if contracted_seq_group_metadata_list[idx].do_sample
            ]

        # "Contract" speculative.
        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
            all_probs[spec_indices] = target_probs
            all_logprobs[spec_indices] = target_logprobs
            if all_hidden_states is not None:
                all_hidden_states[spec_indices] = target_hidden_states

        spec_scores = SpeculativeScores(probs=all_probs,
                                        token_ids=all_tokens,
                                        logprobs=all_logprobs,
                                        hidden_states=all_hidden_states,
                                        prompt_logprobs=prompt_logprobs)

        non_spec_outputs = SpeculativeScores(
            probs=non_spec_target_probs,
            token_ids=non_spec_target_token_ids,
            logprobs=non_spec_target_logprobs,
            hidden_states=non_spec_target_hidden_states)
        # Contract remaining nonspec entries based on non_spec_indices, if any.
        return self._contract_non_speculative(
            spec_scores, contracted_seq_group_metadata_list, non_spec_indices,
            non_spec_outputs, has_prompt_log)
|
||||
|
||||
def _contract_batch_all_spec(
|
||||
self,
|
||||
target_sampler_output: SamplerOutput,
|
||||
proposals: SpeculativeProposals,
|
||||
) -> SpeculativeScores:
|
||||
"""Contract the expanded batch back into its original size.
|
||||
This maps the scores of speculative tokens back to their original
|
||||
sequences.
|
||||
|
||||
It assumes all sequences in the batch were previously expanded.
|
||||
"""
|
||||
|
||||
# Map distinct sequences used to score each token
|
||||
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
|
||||
contracted_bs, k = proposals.proposal_token_ids.shape
|
||||
|
||||
# Reshape tensors to original batch size
|
||||
target_token_ids = target_sampler_output.sampled_token_ids.reshape(
|
||||
contracted_bs, k + 1)
|
||||
target_probs = target_sampler_output.sampled_token_probs.reshape(
|
||||
*target_token_ids.shape, self._vocab_size)
|
||||
target_logprobs = target_sampler_output.logprobs.reshape(
|
||||
target_probs.shape)
|
||||
target_hidden_states = target_sampler_output.hidden_states
|
||||
if target_hidden_states is not None:
|
||||
target_hidden_states = target_hidden_states.reshape(
|
||||
*target_token_ids.shape, target_hidden_states.shape[-1])
|
||||
|
||||
return SpeculativeScores(probs=target_probs,
|
||||
token_ids=target_token_ids,
|
||||
logprobs=target_logprobs,
|
||||
hidden_states=target_hidden_states,
|
||||
prompt_logprobs=None)
|
||||
|
||||
def _create_scoring_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
|
||||
target_seq_ids_iter: Iterator[TargetSeqId],
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
"""Given the original input sequences and proposed tokens from the draft
|
||||
model, create a list of target sequences that can be used for scoring.
|
||||
|
||||
target_seq_ids_iter provides sequence ids for the expanded batch,
|
||||
fulfilling the requirement that no seq id in the expanded batch is equal
|
||||
to the seq id in the original batch.
|
||||
"""
|
||||
|
||||
if not seq_group_metadata_list:
|
||||
return []
|
||||
|
||||
target_seq_group_metadata = list(
|
||||
chain.from_iterable(
|
||||
self._create_target_seq_group_metadata(
|
||||
seq_group_metadata,
|
||||
proposal_token_ids,
|
||||
i,
|
||||
target_seq_ids_iter,
|
||||
) for i, seq_group_metadata in enumerate(
|
||||
seq_group_metadata_list)))
|
||||
|
||||
return target_seq_group_metadata
|
||||
|
||||
def _create_target_seq_group_metadata(
|
||||
self,
|
||||
input_seq_group_metadata: SequenceGroupMetadata,
|
||||
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
|
||||
batch_index: int,
|
||||
target_seq_ids_iter: Iterator[TargetSeqId],
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
"""Given an input sequence group metadata and a list of draft tokens,
|
||||
create a list of target SequenceGroupMetadata, one for each
|
||||
token id that needs to be scored.
|
||||
|
||||
Naive speculative decoding requires K target model scores, one for each
|
||||
draft model token. However one can add a bonus token such that if each
|
||||
token is accepted, then a final token may be sampled from the model.
|
||||
This function creates K+1 target SequenceGroupMetadata to take
|
||||
advantage of the bonus token.
|
||||
"""
|
||||
assert len(input_seq_group_metadata.seq_data) == 1, (
|
||||
"Beam search "
|
||||
"not supported in speculative decoding")
|
||||
input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))
|
||||
|
||||
token_ids_to_score = self._get_token_ids_to_score(
|
||||
proposal_token_ids[batch_index])
|
||||
|
||||
sampling_params = input_seq_group_metadata.sampling_params
|
||||
target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
for i, token_ids in enumerate(token_ids_to_score):
|
||||
target_seq_group_metadata_list.append(
|
||||
self._create_single_target_seq_group_metadata(
|
||||
input_seq_group_metadata,
|
||||
input_seq_id,
|
||||
next(target_seq_ids_iter),
|
||||
token_ids,
|
||||
sampling_params=sampling_params,
|
||||
))
|
||||
|
||||
return target_seq_group_metadata_list
|
||||
|
||||
    @staticmethod
    def _create_single_target_seq_group_metadata(
        seq_group_metadata: SequenceGroupMetadata,
        seq_id: SeqId,
        target_seq_id: TargetSeqId,
        token_ids: List[TokenId],
        sampling_params: SamplingParams,
    ) -> SequenceGroupMetadata:
        """Create a single target SequenceGroupMetadata.

        Args:
            seq_group_metadata: The metadata for the input sequence.
            seq_id: The input sequence ID.
            target_seq_id: The corresponding target sequence ID.
            token_ids: The list of token ids that are to be appended to the
                input sequence.
            sampling_params: Sampling params to attach to the target group.
        """
        seq_data = seq_group_metadata.seq_data[seq_id]
        # NOTE(review): the prompt token array is passed through without
        # copying — the target SequenceData presumably shares it with the
        # source sequence; confirm it is treated as read-only downstream.
        prompt_token_ids = seq_data.prompt_token_ids_array
        new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
        mrope_position_delta = seq_data.mrope_position_delta

        new_seq_data_dict = {
            target_seq_id:
            SequenceData(
                prompt_token_ids,
                _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                        new_output_token_ids),
            ),
        }
        # This is a hack. Technically, spec decoding should compute
        # num_lookahead slots at one shot, but instead, it expands the batch
        # and evaluate one by one right now. context_len is seq_len - 1 because
        # the kv cache is filled by a previous batch in the batch expansion.
        for data in new_seq_data_dict.values():
            data.update_num_computed_tokens(data.get_len() - 1)
            data.mrope_position_delta = mrope_position_delta

        # Target group reuses the source request id and block tables; only
        # the seq id is fresh, and exactly one token is scored per step.
        return SequenceGroupMetadata(
            request_id=seq_group_metadata.request_id,
            is_prompt=seq_group_metadata.is_prompt,
            seq_data=new_seq_data_dict,
            sampling_params=sampling_params,
            block_tables={
                target_seq_id: seq_group_metadata.block_tables[seq_id],
            },
            lora_request=None,
            token_chunk_size=1,
        )
|
||||
|
||||
@staticmethod
|
||||
def _split_scoring_output(
|
||||
sampler_output: SamplerOutput, num_scoring_tokens: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
|
||||
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
||||
torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Split the target model output into speculative and non-speculative
|
||||
output.
|
||||
"""
|
||||
|
||||
# vLLM currently only supports proposal lens equal to zero or the batch
|
||||
# proposal len. This adds some complexity (splitting the batch into spec
|
||||
# and non spec sequences) and should be removed in the future. It can be
|
||||
# done by supporting per-sequence proposal lens.
|
||||
#
|
||||
# First samples are non-speculative, latter samples are from speculative
|
||||
# scoring (prefill|decode order).
|
||||
split_sizes = (sampler_output.sampled_token_ids.numel() -
|
||||
num_scoring_tokens, num_scoring_tokens)
|
||||
(non_spec_probs,
|
||||
spec_probs) = sampler_output.sampled_token_probs.split(split_sizes)
|
||||
(non_spec_sampled_tokens, spec_sampled_tokens
|
||||
) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
|
||||
(non_spec_logprobs,
|
||||
spec_logprobs) = sampler_output.logprobs.split(split_sizes)
|
||||
|
||||
if sampler_output.hidden_states is not None:
|
||||
(non_spec_hidden_states, spec_hidden_states
|
||||
) = sampler_output.hidden_states.split(split_sizes)
|
||||
else:
|
||||
non_spec_hidden_states, spec_hidden_states = None, None
|
||||
|
||||
return (spec_sampled_tokens, spec_probs, spec_logprobs,
|
||||
spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
|
||||
non_spec_logprobs, non_spec_hidden_states)
|
||||
|
||||
@staticmethod
|
||||
def _create_target_seq_id_iterator(
|
||||
seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
|
||||
"""Create an iterator for creating target sequence ids.
|
||||
Target sequence ids are distinct from sequence ids because we create a
|
||||
distinct target sequence id for each proposal token to be scored.
|
||||
|
||||
This implementation increments a counter starting at 1 + max of all
|
||||
provided input sequence ids.
|
||||
"""
|
||||
return count(start=max(seq_ids) + 1)
|
||||
|
||||
@staticmethod
|
||||
def _get_token_ids_to_score(
|
||||
full_spec_token_ids: List[TokenId] # shape: [k]
|
||||
) -> List[List[TokenId]]:
|
||||
"""Given an int tensor of proposal token ids, return a list of
|
||||
token ids that should be scored.
|
||||
|
||||
Returns k+1 output lists. The additional one is used for generating the
|
||||
bonus token.
|
||||
|
||||
Example:
|
||||
Input: [0, 1, 2, 3] (k=4)
|
||||
Output: (k+1 lists)
|
||||
[]
|
||||
[0]
|
||||
[0, 1]
|
||||
[0, 1, 2]
|
||||
[0, 1, 2, 3]
|
||||
"""
|
||||
empty_token_ids: List[TokenId] = []
|
||||
|
||||
token_ids_to_score = [empty_token_ids]
|
||||
token_ids_to_score.extend(full_spec_token_ids[:i + 1]
|
||||
for i in range(len(full_spec_token_ids)))
|
||||
return token_ids_to_score
|
||||
349
vllm/spec_decode/draft_model_runner.py
Normal file
349
vllm/spec_decode/draft_model_runner.py
Normal file
@@ -0,0 +1,349 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
|
||||
try:
|
||||
try:
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
# vllm_flash_attn is not installed, try the ROCm FA metadata
|
||||
from vllm.attention.backends.rocm_flash_attn import (
|
||||
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
|
||||
except (ModuleNotFoundError, ImportError) as err:
|
||||
raise RuntimeError(
|
||||
"Draft model speculative decoding currently only supports "
|
||||
"CUDA and ROCm flash attention backend.") from err
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||
from vllm.worker.model_runner_base import (ModelRunnerBase,
|
||||
ModelRunnerInputBase,
|
||||
ModelRunnerWrapperBase)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# A flag to enable debug prints for the updated input tensors
|
||||
# before each step.
|
||||
debug_advance_input = False
|
||||
# A flag to allow GPU advance step for draft model runner.
|
||||
# Set to False for debugging.
|
||||
allow_gpu_advance_step = True
|
||||
|
||||
|
||||
class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
||||
"""Specialized model runner for speculative decoding draft model.
|
||||
Since the draft model always execute k forward passes consecutively to
|
||||
generate k speculative tokens in a single speculative decoding step,
|
||||
we could get rid of most CPU-GPU synchronization and data transfer
|
||||
overheads by keeping model input and output tensors on GPU all the time.
|
||||
|
||||
TODOs:
|
||||
1. Currently supports only flash-attn, add support for other attn_backends.
|
||||
2. Support TP > 1 (this requires some designs because we do not expect
|
||||
any broadcasting inside execute_model).
|
||||
"""
|
||||
|
||||
    def __init__(self, model_runner: ModelRunnerBase):
        """Wrap the given model runner for single-TP draft-model execution."""
        super().__init__(model_runner)

        # Indices of sequences carrying a bonus token; populated later via
        # set_indices_of_seq_with_bonus_tokens().
        self.indices_of_seq_with_bonus_tokens = None
|
||||
|
||||
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
|
||||
num_queries):
|
||||
|
||||
assert sampling_metadata.num_prompts == 0
|
||||
assert len(sampling_metadata.seq_groups) == num_queries
|
||||
assert sampling_metadata.selected_token_indices.shape == (
|
||||
num_queries, )
|
||||
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
|
||||
|
||||
# Verify that all sequences are decodes
|
||||
for i in range(num_queries):
|
||||
seq_group = sampling_metadata.seq_groups[i]
|
||||
|
||||
assert seq_group.is_prompt is False # No prompt
|
||||
assert seq_group.prompt_logprob_indices == [] # No prompt
|
||||
assert seq_group.sample_indices == [i] # Simple
|
||||
|
||||
    def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
                          last_output: SamplerOutput) -> ModelRunnerInputBase:
        """Advance the model input to the next decode step entirely on GPU,
        using the previous step's sampled tokens, without a CPU round-trip.

        Args:
            model_input: The decode-mode input of the step just executed.
            last_output: Sampler output of that step; its sampled_token_ids
                must live on GPU.
        Returns:
            A new model input ready for the next forward pass.
        """
        # Currently, we expect "decode mode" only
        assert not model_input.is_prompt

        # Get num_seqs
        num_seqs = len(model_input.seq_lens)
        num_queries = len(model_input.query_lens)

        # Get output tokens GPU tensor
        sampled_token_ids = last_output.sampled_token_ids
        assert sampled_token_ids is not None

        # Update attn_metadata in place; only flash-attn metadata supports
        # advance_step, hence the isinstance check.
        attn_metadata = model_input.attn_metadata
        assert isinstance(attn_metadata, FlashAttentionMetadata)

        attn_metadata.advance_step(model_input, sampled_token_ids,
                                   self.block_size, num_seqs, num_queries)

        # Update sampling_metadata (pure sanity checks for decode-only layout)
        sampling_metadata = model_input.sampling_metadata
        self._update_sampling_metadata(sampling_metadata, num_seqs,
                                       num_queries)

        # Create new input, reusing the (now-advanced) tensors of the old one.
        new_model_input = self._model_input_cls(
            input_tokens=model_input.input_tokens,
            input_positions=model_input.input_positions,
            attn_metadata=attn_metadata,
            seq_lens=attn_metadata.seq_lens,
            query_lens=model_input.query_lens,
            lora_mapping=model_input.lora_mapping,
            lora_requests=model_input.lora_requests,
            multi_modal_kwargs=model_input.multi_modal_kwargs,
            sampling_metadata=model_input.sampling_metadata,
            is_prompt=False,
        )

        # Ensure we skip CPU samples
        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
        # We can reuse sampling tensors since every decode iteration is the same
        new_model_input.sampling_metadata.reuse_sampling_tensors = True

        if debug_advance_input:
            logger.debug("NEW INPUT: ")
            logger.debug("  input_tokens = %s", new_model_input.input_tokens)
            logger.debug("  input_positions = %s",
                         new_model_input.input_positions)
            logger.debug("  seq_lens = %d", new_model_input.seq_lens)
            logger.debug("  query_lens = %d", new_model_input.query_lens)
            logger.debug("  attn_metadata:")
            logger.debug("    seq_lens_tensor: %s",
                         attn_metadata.seq_lens_tensor)
            logger.debug("    slot_mapping: %s", attn_metadata.slot_mapping)
            logger.debug("    block_tables: %s", attn_metadata.block_tables)

        return new_model_input
|
||||
|
||||
def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
|
||||
"""Determines if draft_model_runner GPU multi-step can be used.
|
||||
Currently required conditions are:
|
||||
1. Only decodes
|
||||
2. Only flash-attn
|
||||
3. No LORA
|
||||
4. No prompt_adapter_config
|
||||
"""
|
||||
if not allow_gpu_advance_step:
|
||||
return False
|
||||
|
||||
# We allow multi-step GPU only in decode mode
|
||||
for seq_group in execute_model_req.seq_group_metadata_list:
|
||||
if seq_group.is_prompt:
|
||||
return False
|
||||
|
||||
# TODO: Add support for other attn backends
|
||||
if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
|
||||
return False
|
||||
|
||||
# TODO: Add support for LORA
|
||||
if self.lora_config:
|
||||
return False
|
||||
|
||||
# TODO: Add soft-tuning prompt adapter support
|
||||
return not self.prompt_adapter_config
|
||||
|
||||
    def set_indices_of_seq_with_bonus_tokens(self,
                                             indices_of_seq_with_bonus_tokens):
        """Record which sequences carry a bonus token for the current step."""
        self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens
|
||||
|
||||
@torch.inference_mode()
def execute_model(
    self,
    model_input: ModelRunnerInputBase,
    kv_caches: List[torch.Tensor],
    previous_hidden_states: Optional[torch.Tensor] = None,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    num_steps: int = 1,
    **kwargs,
) -> Optional[List[SamplerOutput]]:
    """Executes num_steps forward passes with advancement of input tensors
    on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.

    Optimizations used:
        1. Input tensors are updated on the GPU directly
        2. Skips GPU=>CPU serialization of sampler outputs (we don't need
            them since we do batch expansion later that uses GPU outputs)
        3. Reuses sampling tensors (since we run only decodes and they have
            a repeating sampling logic)

    Returns one SamplerOutput per step, or [] on non-driver workers.
    """

    # When num_steps == 1, we execute the fallback here for the GPU
    # advance_step, which runs prepare_inputs on CPU and for each spec
    # iteration invokes this function only once
    # (Look at multi-step-worker code)
    is_fallback = num_steps == 1
    if not is_fallback:
        # Multi-step path: strictly validate the pre-conditions that
        # supports_gpu_multi_step(..) is meant to guarantee.
        # Since we do not broadcast data inside execute_model anymore,
        # we need to figure out the best way to support TP > 1 in this
        # case, because we will at least need to broadcast the sampled
        # tokens to all workers.
        if not self.is_driver_worker:
            raise ValueError("TP1DraftModelRunner only supports TP=1.")

        # Sanity
        if self.lora_config is not None:
            raise ValueError("TP1DraftModelRunner has no support for LORA")
        if self.prompt_adapter_config is not None:
            raise ValueError("TP1DraftModelRunner has no support for "
                             "prompt_adapter_config")
        if model_input.inputs_embeds is not None:
            raise ValueError("TP1DraftModelRunner has no support for "
                             "inputs_embeds")
        if model_input.multi_modal_kwargs:
            raise ValueError(
                "TP1DraftModelRunner has no support for multi_modal_kwargs"
            )
    else:
        # Fallback (single-step) path: LoRA / prompt adapters are allowed
        # here, so activate them before the forward pass.
        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        self.attn_state.begin_forward(model_input)

    # Detect exec mode
    assert model_input.attn_metadata is not None
    use_cuda_graph = False
    if model_input.attn_metadata.num_prefills > 0:
        # In this case, execute_model(..) was called directly
        if num_steps > 1:
            raise ValueError(
                "execute_model(..) of draft_model_runner can be called "
                "directly only with a single-step prefill")
    else:
        # We can skip CPU samples for spec token generation.
        # (We do allow CPU samples for num_steps == 1 to support the
        # fallback case, where supports_gpu_multi_step(..) does not pass)
        model_input.sampling_metadata.skip_sampler_cpu_output = (
            not is_fallback)

        # Attn attr defines if we use cuda graphs
        use_cuda_graph = model_input.attn_metadata.use_cuda_graph

    # Get model: for CUDA graphs, look up the captured graph runner keyed
    # by (batch size, uses-inputs-embeds); otherwise run eagerly.
    if use_cuda_graph:
        if model_input.inputs_embeds is None:
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = (
                self.graph_runners[model_input.virtual_engine][(
                    graph_batch_size, False)])
        else:
            graph_batch_size = model_input.inputs_embeds.shape[0]
            model_executable = (
                self.graph_runners[model_input.virtual_engine][(
                    graph_batch_size, True)])

        if previous_hidden_states is not None:
            # Pad previous hidden states up to the captured graph batch
            # size; the padded rows are uninitialized (torch.empty) and
            # presumably unread beyond the real batch — TODO confirm.
            hidden_states = torch.cat([
                previous_hidden_states,
                torch.empty([
                    graph_batch_size - previous_hidden_states.shape[0],
                    *previous_hidden_states.shape[1:]
                ],
                            dtype=previous_hidden_states.dtype,
                            device=previous_hidden_states.device)
            ])
        else:
            hidden_states = None
    else:
        model_executable = self.model
        hidden_states = previous_hidden_states

    outputs: List[SamplerOutput] = []
    for step in range(num_steps):
        multi_modal_kwargs = model_input.multi_modal_kwargs or {}

        # Only pass previous_hidden_states when the caller supplied them;
        # some draft models do not accept the kwarg.
        model_execute_kwargs = {"previous_hidden_states": hidden_states} \
            if previous_hidden_states is not None else {}

        compute_logits_kwargs = {}
        # Run model
        if hasattr(self.model.config, "num_nextn_predict_layers"):
            # for DeepSeek MTP only to use the corresponding layer for
            # each step
            spec_step_idx = kwargs.get("spec_step_idx", step)
            model_execute_kwargs["spec_step_idx"] = spec_step_idx
            compute_logits_kwargs["spec_step_idx"] = spec_step_idx
        with set_forward_context(model_input.attn_metadata,
                                 self.vllm_config):
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                inputs_embeds=None,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(
                    multi_modal_kwargs,
                    device=self.device,
                ),
                **model_execute_kwargs,
            )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata,
                                           **compute_logits_kwargs)
        if not self.is_driver_worker:
            return []
        # Sample the next token.
        output = self.model_runner.sampler(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        outputs.append(output)

        if self.return_hidden_states and is_fallback:
            if use_cuda_graph:
                # CUDA-graph batch may be padded; keep only the rows that
                # correspond to selected tokens.
                indices = model_input.sampling_metadata\
                    .selected_token_indices
                output.hidden_states = hidden_states[:len(indices)]
            else:
                output.hidden_states = hidden_states

        if model_input.attn_metadata.num_prefills == 0 \
            and self.indices_of_seq_with_bonus_tokens is not None:
            assert output.sampled_token_ids is not None
            # output.sampled_token_ids should be of shape (num_seqs, 1)
            nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
            assert num_tokens_per_seq == 1
            # For sequences that were NOT assigned a bonus token, overwrite
            # the sampled token with the input token of the corresponding
            # bonus sequence. `count` walks the (sorted) bonus-index list in
            # lockstep with the batch.
            count = 0
            for i in range(nums_seqs):
                bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
                    count]
                if i != bonus_seq_idx:
                    # The following might cause a cpu->gpu sync
                    # However, the performance impact is negligible as we
                    # benchmarked on H100.
                    output.sampled_token_ids[
                        i, :] = model_input.input_tokens[bonus_seq_idx]
                else:
                    count += 1

        # Prepare inputs for the next step
        if step != num_steps - 1:
            model_input = self._gpu_advance_step(model_input, outputs[-1])

    return outputs
|
||||
99
vllm/spec_decode/interfaces.py
Normal file
99
vllm/spec_decode/interfaces.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import ExecuteModelRequest, PromptLogprobs
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
|
||||
@dataclass
class SpeculativeProposals:
    """Proposal tokens emitted by a draft proposer, together with the
    proposer's probabilities and the number of valid speculative tokens
    available for each sequence.
    """

    # Token ids proposed by the draft model.
    proposal_token_ids: torch.Tensor

    # Proposer-assigned probability for each proposed token.
    proposal_probs: torch.Tensor

    # Number of valid proposal tokens per sequence; may be zero.
    proposal_lens: torch.Tensor

    # Marks a batch for which the proposer produced no proposals at all.
    no_proposals: bool = False

    def __repr__(self):
        # Probs are summarized by shape only; the full tensor would be huge.
        return ("SpeculativeProposals("
                f"proposal_token_ids={self.proposal_token_ids}, "
                f"proposal_probs={self.proposal_probs.shape}, "
                f"proposal_lens={self.proposal_lens})")
|
||||
|
||||
|
||||
@dataclass
class SpeculativeScores:
    """Scores assigned to speculative tokens by the (larger) scoring model,
    plus the tokens the scoring model itself sampled.
    """

    # Scoring-model probability for each speculative token.
    probs: torch.Tensor

    # Scoring-model log-probabilities; used to build the Logprob objects
    # that are surfaced to the user.
    logprobs: torch.Tensor

    # Tokens sampled by the scoring model — covers both speculative bonus
    # tokens and ordinary (non-speculative) decoding.
    token_ids: torch.Tensor

    # Last hidden states from the scoring model, when available.
    hidden_states: Optional[torch.Tensor] = None

    # Per-request prompt logprobs, returned by the scoring model when
    # chunked prefill is enabled.
    prompt_logprobs: Optional[List[PromptLogprobs]] = None

    def __repr__(self):
        # Tensors are summarized by shape to keep the repr compact.
        return ("SpeculativeScores("
                f"probs={self.probs.shape}, "
                f"token_ids={self.token_ids.shape})")
|
||||
|
||||
|
||||
class SpeculativeProposer(ABC):
    """Abstract interface for a component that produces speculative token
    proposals for a batch of sequences.
    """

    @abstractmethod
    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        # If set, this contains all sequence IDs that were assigned
        # bonus tokens in their last forward pass.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Return speculative proposals for the given batch.

        Implemented by concrete proposers (e.g. Top1Proposer).
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class SpeculativeScorer(ABC):
    """Abstract interface for a component that scores speculative proposals
    using the (larger) scoring model.
    """

    def __init__(self, scorer_worker: WorkerBase,
                 device: Union[torch.device, str], vocab_size: int):
        # Worker that actually runs the scoring model.
        self._scorer_worker = scorer_worker
        # Normalize torch.device objects to their type string (e.g. "cuda").
        if isinstance(device, torch.device):
            device = device.type
        self._device = device
        self._vocab_size = vocab_size

    @abstractmethod
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the given proposals; implemented by concrete scorers
        (e.g. batch expansion or MQA based).
        """
        raise NotImplementedError
|
||||
138
vllm/spec_decode/medusa_worker.py
Normal file
138
vllm/spec_decode/medusa_worker.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import weakref
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker_base import DelegateWorkerBase
|
||||
|
||||
|
||||
class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
    """Worker for Medusa.

    Delegates execution to the wrapped worker and produces top-1 speculative
    proposals through a lazily-initialized Top1Proposer.
    """

    def __init__(self, *args, **kwargs):
        DelegateWorkerBase.__init__(self, *args, **kwargs)
        # Lazy initialization list.
        self._proposer: Top1Proposer

    def init_device(self):
        self.worker.init_device()

        # weakref.proxy avoids a reference cycle between this worker and
        # its proposer.
        self._proposer = Top1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    def set_include_gpu_probs_tensor(self):
        # Intentionally a no-op for Medusa.
        pass

    def set_should_modify_greedy_probs_inplace(self):
        # Intentionally a no-op for Medusa.
        pass

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.
        Returns the list of sampler output, one per layer, along with indicator
        of whether torch tensor in sampler output need to be transposed in
        latter sampler_output_to_torch logic.

        For medusa worker, this indicator shall be False.
        """
        self._raise_if_unsupported(execute_model_req)

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        seq_lens, query_lens = self._prepare_input_tensors(
            seq_group_metadata_list)

        generators = self.model_runner.get_generators(
            execute_model_req.finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory, generators)

        # Medusa heads consume only the previous hidden states (no token
        # ids), hence no input_ids argument here.
        model_outputs = self.model_runner.model.generate_proposals(
            previous_hidden_states=execute_model_req.previous_hidden_states.
            hidden_states,
            sampling_metadata=sampling_metadata)

        return model_outputs, False

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[List[int], List[int]]:
        """Compute per-sequence lengths and query lengths for the batch.

        Prompts may be chunked (chunked prefill), so the query covers only
        the tokens of the current chunk; decodes query a single token.
        """
        if not seq_group_metadata_list:
            return [], []

        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    seq_lens.append(seq_len)
                    query_lens.append(seq_len - context_len)
                else:
                    seq_lens.append(seq_data_len)
                    query_lens.append(1)

        return seq_lens, query_lens

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.
        """

        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """MedusaWorker does not yet implement support for cache swap
        operations or beam search.

        Raises:
            NotImplementedError: if the request includes cache operations
                or more than one sequence per group (beam search).
        """
        # Idiom fix: avoid materializing a throwaway list inside any().
        if any((execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy)):
            raise NotImplementedError(
                "MedusaWorker does not support cache operations")

        # Idiom fix: len(d) instead of len(d.keys()).
        if any(
                len(seq_group_metadata.seq_data) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MedusaWorker does not support beam search.")
|
||||
213
vllm/spec_decode/metrics.py
Normal file
213
vllm/spec_decode/metrics.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeBaseSampler)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
class SpecDecodeWorkerMetrics(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """Dataclass holding metrics emitted from the spec decode worker.

    Serialized as a msgspec Struct (array-like encoding, defaults omitted)
    so it can be shipped between processes cheaply.
    """

    # The empirical acceptance rate of the proposal method on a per-token basis.
    # This is useful for evaluating how well the proposal method aligns with the
    # scoring method.
    draft_acceptance_rate: float

    # The empirical efficiency, measured as the number of tokens emitted by the
    # system divided by the number of tokens that could be emitted by the system
    # if the proposal method were perfect.
    system_efficiency: float

    # The number of speculative tokens produced by the proposal method.
    draft_tokens: int

    # The number of tokens emitted by the entire system.
    emitted_tokens: int

    # The number of tokens accepted by the scoring model and verification
    # routine, e.g. Llama2-70B and lossless rejection sampling.
    #
    # NOTE: Any token accepted by the verification routine is considered
    # accepted (regardless of if the speculative prefix is also accepted). The
    # user will usually see less accepted tokens. This metric is helpful when
    # evaluating alignment of the proposal method with the scoring model.
    accepted_tokens: int

    # The number of speculative tokens per sequence.
    num_spec_tokens: int
|
||||
|
||||
|
||||
Timer = Callable[[], float]
|
||||
|
||||
|
||||
class AsyncMetricsCollector:
    """Class which copies rejection/typical-acceptance sampler metrics
    from the device to CPU on a non-default Torch stream.

    The copy is double-buffered across calls: one call kicks off an async
    device->CPU copy, a later call collects the finished copy and builds a
    SpecDecodeWorkerMetrics object.
    """

    def __init__(self,
                 spec_decode_sampler: SpecDecodeBaseSampler,
                 timer: Optional[Timer] = None,
                 collect_interval_s: float = 5.0):
        self.spec_decode_sampler = spec_decode_sampler
        # Injectable timer for testing; wall-clock time by default.
        self._timer = time.time if timer is None else timer

        self._rank: Optional[int] = None

        # We don't have a device set yet.
        self._copy_stream: Optional[torch.cuda.Stream] = None

        # Event recorded when an async copy was started and not yet consumed.
        self._in_flight_copy: Optional[torch.cuda.Event] = None

        # Pinned CPU tensors allow true async (non_blocking) copies.
        pin_memory = is_pin_memory_available()
        self._aggregate_num_accepted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_emitted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_draft_tokens = 0

        self._rejsample_metrics_collect_interval_s = collect_interval_s
        self._last_metrics_collect_time = self._timer()

    def init_gpu_tensors(self, rank: int) -> None:
        # CUDA-only initialization path.
        self._rank = rank
        self._copy_stream = torch.cuda.Stream()

    def init_tensors(self,
                     rank: int,
                     device_type: Union[torch.device, str] = 'cuda') -> None:
        # Platform-agnostic initialization path; uses the platform's stream
        # type when one is available.
        self._rank = rank
        if isinstance(device_type, torch.device):
            device_type = device_type.type
        stream = current_platform.Stream
        if stream is not None:
            self._copy_stream = stream()

    def maybe_collect_rejsample_metrics(
            self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
        """Return finished metrics if an async copy completed, possibly
        starting a new copy; otherwise return None.
        """
        # Skip for any platform that doesn't have device Event
        if current_platform.Event is None:
            return None

        # If a copy was initiated in the previous call, collect and return.
        if self._in_flight_copy is not None:
            ready_event = self._in_flight_copy
            self._in_flight_copy = None
            return self._collect_rejsample_metrics(k, ready_event)

        # Otherwise, check if we should start a new copy.
        if self._should_collect_rejsample_metrics(self._timer()):
            assert self._in_flight_copy is None
            self._in_flight_copy = self._copy_rejsample_metrics_async()

        return None

    def _should_collect_rejsample_metrics(self, now: float) -> bool:
        """Return whether or not this iteration should print sampling
        metrics.
        """
        # Only rank 0 reports metrics.
        if self._rank != 0:
            return False

        return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s  # noqa: E501

    def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
        """Copy rejection/typical-acceptance sampling metrics
        (number of accepted tokens, etc) to CPU asynchronously.

        Returns a device event recording when the copy is complete.
        """
        assert self._copy_stream is not None
        # Ensure the copy stream observes all prior work on the current
        # stream before reading the sampler's counters.
        self._copy_stream.wait_stream(current_platform.current_stream())

        with current_platform.stream(self._copy_stream):
            self._aggregate_num_accepted_tokens.copy_(
                self.spec_decode_sampler.num_accepted_tokens,
                non_blocking=True)
            self._aggregate_num_emitted_tokens.copy_(
                self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
            # Number of draft tokens is calculated on CPU, so no copy is
            # required.
            self._aggregate_num_draft_tokens = (
                self.spec_decode_sampler.num_draft_tokens)

        aggregate_metrics_ready = current_platform.Event()
        aggregate_metrics_ready.record(self._copy_stream)

        return aggregate_metrics_ready

    def _collect_rejsample_metrics(
            self, k: int,
            ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
        """Create metrics object from statistics copied asynchronously.

        Args:
            k: int. The number of speculative tokens; used to determine system
                efficiency.
            ready_event: torch.cuda.Event. The CUDA event recording when the
                async GPU->CPU copy is complete.
        """

        # Block until the async copy started earlier has finished.
        ready_event.synchronize()

        # update time of last collection
        self._last_metrics_collect_time = self._timer()

        accepted_tokens = self._aggregate_num_accepted_tokens.item()
        emitted_tokens = self._aggregate_num_emitted_tokens.item()
        draft_tokens = self._aggregate_num_draft_tokens

        max_num_emitted_tokens = self.get_max_num_emitted_tokens(
            draft_tokens, k)

        # NaN signals "no data yet" to downstream consumers.
        if draft_tokens > 0:
            draft_acceptance_rate = accepted_tokens / draft_tokens
        else:
            draft_acceptance_rate = float("nan")

        if max_num_emitted_tokens > 0:
            system_efficiency = emitted_tokens / max_num_emitted_tokens
        else:
            system_efficiency = float("nan")

        return SpecDecodeWorkerMetrics(
            num_spec_tokens=k,
            draft_acceptance_rate=draft_acceptance_rate,
            system_efficiency=system_efficiency,
            accepted_tokens=accepted_tokens,
            draft_tokens=draft_tokens,
            emitted_tokens=emitted_tokens,
        )

    @staticmethod
    def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
        """Calculate the number of emitted tokens, assuming all tokens are
        accepted.

        This is equal to the number of sequences that have been speculated on,
        times (speculation len + 1). The +1 comes from the bonus token.
        """
        # Determine the number of sequences that have been speculated on. Since
        # the batch size can be variable, we divide by k.
        assert draft_tokens % k == 0
        total_num_spec_seqs = draft_tokens // k

        # A single sequence may emit k accepted tokens and one bonus token in
        # the best case.
        num_emitted_per_seq_if_all_accepted = k + 1

        # The max num of emitted tokens is the number of speculated sequences
        # times the max emitted per seq.
        return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
|
||||
94
vllm/spec_decode/mlp_speculator_worker.py
Normal file
94
vllm/spec_decode/mlp_speculator_worker.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
|
||||
|
||||
|
||||
class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
    """Worker for MLPSpeculator models.

    Not currently compatible with LoRA or chunked prefill.
    """

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter. MLPSpeculatorWorker does not use the KV Cache and
        # therefore does not need this parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.
        Returns the list of sampler output, one per layer, along with indicator
        of whether torch tensor in sampler output need to be transposed in
        latter sampler_output_to_torch logic.

        For mlp spec worker, this indicator shall be True.
        """
        self._raise_if_unsupported(execute_model_req)

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        (input_tokens, seq_lens,
         query_lens) = self._prepare_input_tensors(seq_group_metadata_list)

        generators = self.model_runner.get_generators(
            execute_model_req.finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory, generators)

        # The MLP speculator consumes both the last token ids and the
        # previous hidden states to predict `sample_len` future tokens.
        model_outputs = self.model_runner.model.generate_proposals(
            input_ids=input_tokens,
            previous_hidden_states=execute_model_req.previous_hidden_states.
            hidden_states,
            num_predict_tokens=sample_len,
            sampling_metadata=sampling_metadata)

        # One SamplerOutput per predicted position.
        assert len(model_outputs) == sample_len

        return model_outputs, True

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, List[int], List[int]]:
        """Flatten the batch into input token ids plus per-sequence seq/query
        lengths. Prompts contribute the tokens of their current chunk;
        decodes contribute only their last token.
        """
        if not seq_group_metadata_list:
            return torch.empty(0, device=self.device), [], []

        input_tokens: List[int] = []
        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    # Current prefill chunk: [context_len, seq_len).
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    tokens = seq_data.get_token_ids()[context_len:seq_len]
                    seq_lens.append(seq_len)
                    input_tokens.extend(tokens)
                    query_lens.append(seq_len - context_len)
                else:
                    # Decode: single-token query.
                    seq_lens.append(seq_data_len)
                    input_tokens.append(seq_data.get_last_token_id())
                    query_lens.append(1)

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        return input_tokens_tensor, seq_lens, query_lens
|
||||
160
vllm/spec_decode/mqa_scorer.py
Normal file
160
vllm/spec_decode/mqa_scorer.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.sequence import (ExecuteModelRequest, SequenceData,
|
||||
SequenceGroupMetadata, get_all_seq_ids)
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
|
||||
SeqId = int
|
||||
TargetSeqId = int
|
||||
|
||||
|
||||
class MQAScorer(SpeculativeScorer):
    """Scorer that verifies all proposal tokens of a sequence in a single
    multi-query forward pass of the scoring model, instead of expanding the
    batch to one query per token.
    """

    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Build target sequences that append the proposal tokens, run the
        scorer worker over them once, then reshape the outputs into
        (batch, k + 1, ...) SpeculativeScores.
        """
        target_seq_group_metadata_list = []
        # Fresh seq ids for the target sequences, above every existing id.
        target_seq_id_start = max(
            get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
        all_proposal_tokens = proposals.proposal_token_ids.tolist()
        all_proposal_lengths = proposals.proposal_lens.tolist()
        for i, seq_group_metadata in enumerate(
                execute_model_req.seq_group_metadata_list):
            if all_proposal_lengths[i] == 0:
                # Keep prompt seqs untouched (keep computed_tokens for chunks).
                target_seq_group_metadata_list.append(seq_group_metadata)
                continue

            seq_data_dict = seq_group_metadata.seq_data
            assert len(seq_data_dict) == 1
            seq_id = next(iter(seq_data_dict.keys()))

            seq_data: SequenceData = seq_data_dict[seq_id]
            prompt_token_ids = seq_data.get_prompt_token_ids()
            output_token_ids = seq_data.get_output_token_ids()
            proposal_token_ids = all_proposal_tokens[
                i][:all_proposal_lengths[i]]
            # Target sequence = original outputs + proposal tokens to verify.
            new_output_token_ids = [*output_token_ids, *proposal_token_ids]

            target_seq_id = target_seq_id_start + i
            new_seq_data = SequenceData.from_seqs(
                prompt_token_ids=prompt_token_ids,
                output_token_ids=new_output_token_ids,
            )
            # Mark everything up to (but excluding) the last original output
            # token as computed, so the proposal tokens form the query.
            new_seq_data.update_num_computed_tokens(
                len(prompt_token_ids) + len(output_token_ids) - 1)

            # Ensure that the new decode sequence has at least one token.
            assert len(output_token_ids) >= 1
            new_seq_data_dict = {target_seq_id: new_seq_data}

            new_seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group_metadata.request_id,
                is_prompt=seq_group_metadata.is_prompt,
                seq_data=new_seq_data_dict,
                sampling_params=seq_group_metadata.sampling_params,
                block_tables={
                    target_seq_id: seq_group_metadata.block_tables[seq_id],
                },
                lora_request=None,
            )
            target_seq_group_metadata_list.append(new_seq_group_metadata)

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))

        target_sampler_output = target_sampler_output[0]

        k = execute_model_req.num_lookahead_slots
        bs = len(execute_model_req.seq_group_metadata_list)
        target_token_ids = target_sampler_output.sampled_token_ids
        target_probs = target_sampler_output.sampled_token_probs
        target_logprobs = target_sampler_output.logprobs
        prompt_logprobs = None

        # If all requests have the same number of query tokens, we can avoid
        # the for loop to build output for better performance.
        if min(all_proposal_lengths) == k:
            # Regular decodes only.
            # NOTE(review): this assertion is vacuous as written — the
            # `if sg.is_prompt` filter makes the generator always empty/true.
            # Left unchanged here; flagging for a follow-up fix.
            assert all(not sg.is_prompt
                       for sg in target_seq_group_metadata_list
                       if sg.is_prompt)
            bs, _ = proposals.proposal_token_ids.shape
            all_tokens = target_token_ids.reshape(bs, k + 1)
            all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
            all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
        else:
            # We either have decodes with different lens or prefill+decodes.
            # Pre-fill outputs with sentinel values (-1 tokens, zero probs,
            # -inf logprobs) for positions shorter than k + 1.
            all_tokens = target_token_ids.new_full(size=(bs, k + 1),
                                                   fill_value=-1)
            all_probs = target_probs.new_zeros(*all_tokens.shape,
                                               self._vocab_size)
            all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                    fill_value=-float("inf"))
            target_token_ids = target_token_ids.flatten()

            # When prompt logprobs is enabled, lens of returned tensors go from
            # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
            # We adjust stride accordingly to get the generated tokens and
            # their probs, but pass on prompt_logprobs as is, since it may be
            # that n_prompts >> K.
            has_prompt_log = any((sg.sampling_params.prompt_logprobs
                                  and sg.sampling_params.prompt_logprobs > 0)
                                 for sg in target_seq_group_metadata_list)
            # TODO (NickLucche) we should surface `disable_logprobs` as to not
            # break abstraction to get its value.
            if (not self._scorer_worker.model_runner.disable_logprobs\
                and has_prompt_log):
                prompt_logprobs = [
                    o.prompt_logprobs for o in target_sampler_output.outputs
                ]

            # Split loop into prefill|decode for readability.
            # Prefills come first; each consumes either its whole chunk
            # (when prompt logprobs are returned) or one slot.
            start_loc, i = 0, 0
            while i < len(target_seq_group_metadata_list
                          ) and target_seq_group_metadata_list[i].is_prompt:
                seq_meta = target_seq_group_metadata_list[i]
                end_loc = start_loc
                if has_prompt_log:
                    end_loc += seq_meta.token_chunk_size
                elif seq_meta.do_sample:
                    end_loc += 1

                # Skip chunks with no output tokens.
                if seq_meta.do_sample:
                    # Get sampled token (last position in chunk) and its prob.
                    all_tokens[i, 0] = target_token_ids[end_loc - 1]
                    all_probs[i, 0] = target_probs[end_loc - 1]
                    all_logprobs[i, 0] = target_logprobs[end_loc - 1]

                i += 1
                start_loc = end_loc
            # Decodes.
            # Each decode consumes proposed_len + 1 slots (proposals plus
            # the bonus token).
            while i < len(target_seq_group_metadata_list):
                proposed_len, seq_meta = all_proposal_lengths[
                    i], target_seq_group_metadata_list[i]
                output_len = proposed_len + 1
                end_loc = start_loc + output_len
                all_tokens[
                    i, :output_len] = target_token_ids[start_loc:end_loc]
                all_probs[i, :output_len] = target_probs[start_loc:end_loc]
                all_logprobs[
                    i, :output_len] = target_logprobs[start_loc:end_loc]
                start_loc = end_loc
                i += 1

        hidden_states = None
        if target_sampler_output.hidden_states is not None:
            hidden_states = target_sampler_output.hidden_states.reshape(
                bs, (k + 1), -1)

        return SpeculativeScores(probs=all_probs,
                                 token_ids=all_tokens,
                                 logprobs=all_logprobs,
                                 hidden_states=hidden_states,
                                 prompt_logprobs=prompt_logprobs)
|
||||
423
vllm/spec_decode/multi_step_worker.py
Normal file
423
vllm/spec_decode/multi_step_worker.py
Normal file
@@ -0,0 +1,423 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
import weakref
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
|
||||
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker_base import DelegateWorkerBase
|
||||
|
||||
|
||||
class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
    """The MultiStepWorker is equivalent to a Worker except that it allows
    multiple forward passes in a single call, assuming the scheduler has
    allocated enough space to store the additional KV. This reduces overhead
    by invoking the scheduler less.

    The MultiStepWorker does not support cache swap operations, or beam search.
    Cache swap operations do not require large modifications. On the other hand,
    beam search requires memory allocations during sequence forks and thus
    requires more thought for MultiStepWorker support.
    """

    def __init__(self, *args, **kwargs):
        # Construction is fully delegated to the wrapped worker.
        DelegateWorkerBase.__init__(self, *args, **kwargs)
        # Lazy initialization list.
        self._proposer: SpeculativeProposer

    def init_device(self) -> None:
        # Initialize the wrapped worker first: Top1Proposer reads
        # self.device / self.vocab_size, which the worker provides.
        self.worker.init_device()
        self._proposer = Top1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    def set_include_gpu_probs_tensor(self) -> None:
        # Need include_gpu_probs_tensor for MultiStepWorker
        self.model_runner.sampler.include_gpu_probs_tensor = True
        # Some models carry their own sampler; keep it in sync when present.
        if hasattr(self.model_runner.model, "sampler"):
            (self.model_runner.model.sampler.include_gpu_probs_tensor) = True

    def set_should_modify_greedy_probs_inplace(self) -> None:
        self.model_runner.sampler.should_modify_greedy_probs_inplace = True
        # Some models carry their own sampler; keep it in sync when present.
        if hasattr(self.model_runner.model, "sampler"):
            (self.model_runner.model.sampler.should_modify_greedy_probs_inplace
             ) = True

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass sample_len times. Returns the list of
        sampler output, one per model forward pass, along with indicator of
        whether torch tensor in sampler output need to be transposed in latter
        sampler_output_to_torch logic.

        For multi step worker, this indicator shall be True.
        """
        self._raise_if_unsupported(execute_model_req)
        # Expand the batch for sequences with a bonus token.
        # Perform a forward pass on the expanded batch and filter the
        # response to retain only the original sequences' responses.
        expanded_request, indices_of_seq_with_bonus_tokens =\
            self._expand_execute_model_request(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)

        # Run model sample_len times.
        model_outputs: List[SamplerOutput] = []
        if current_platform.is_cuda_alike() and isinstance(
                self.model_runner, TP1DraftModelRunner
        ) and self.model_runner.supports_gpu_multi_step(expanded_request):
            # Here we run the draft_model_runner with multi-step prepare
            # on the GPU directly
            expanded_request.num_steps = sample_len
            self.model_runner.set_indices_of_seq_with_bonus_tokens(
                indices_of_seq_with_bonus_tokens)
            model_outputs = self.execute_model(
                execute_model_req=expanded_request)
        else:
            # Here we run multi-step directly, with every step prepared
            # on the CPU.
            # TODO: Remove this branch once DraftModelRunner supports TP>1
            # and other restrictions that are part of DraftModelRunner's
            # supports_gpu_multi_step(..)
            if expanded_request.previous_hidden_states is not None:
                self.worker.model_runner.return_hidden_states = True
            for _ in range(sample_len):
                model_output: List[SamplerOutput] = self.worker.execute_model(
                    execute_model_req=expanded_request)
                assert (len(model_output) == 1
                        ), "composing multistep workers not supported"
                model_output = model_output[0]
                self._maybe_update_previous_hidden_states(
                    model_output, expanded_request)

                # Feed this step's sampled tokens back into the request so
                # the next forward pass continues from them.
                self._append_new_tokens(
                    model_output, expanded_request.seq_group_metadata_list,
                    indices_of_seq_with_bonus_tokens)
                model_outputs.append(model_output)

        # move indices to device to avoid stream sync
        indices_of_seq_with_bonus_tokens = torch.tensor(
            indices_of_seq_with_bonus_tokens, device=self.device)
        filtered_model_outputs = self._filter_model_output(
            model_outputs, indices_of_seq_with_bonus_tokens)
        return filtered_model_outputs, True

    @staticmethod
    def _maybe_update_previous_hidden_states(
            model_output: SamplerOutput,
            expanded_request: ExecuteModelRequest) -> None:
        """
        Updates the previous hidden states in an expanded request
        in-place with the hidden states from the model output.
        """
        if expanded_request.previous_hidden_states is not None:
            expanded_request.previous_hidden_states = HiddenStates(
                model_output.hidden_states,
                expanded_request.seq_group_metadata_list)

    @staticmethod
    def _expand_execute_model_request(
        execute_model_req: ExecuteModelRequest,
        seq_with_bonus_token_in_last_step: set,
    ) -> Tuple[ExecuteModelRequest, List[int]]:
        """
        Expands the execute model request based on sequences with bonus
        tokens.

        For each sequence with a bonus token, this method creates a new
        sequence without the bonus token and adds it to the execute model
        request. The original sequence groups are also retained. The indices
        of the original sequence groups are returned for further processing.

        Args:
            execute_model_req (ExecuteModelRequest): The original execute
                model request.
            seq_with_bonus_token_in_last_step (set): Set of sequence IDs that
                contain bonus tokens.

        Returns:
            Tuple[ExecuteModelRequest, List[int]]: The updated execute model
            request with expanded sequences and a list of indices corresponding
            to the original sequence groups.
        """
        updated_seq_group_metadata_list: List[SequenceGroupMetadata] = []
        updated_execute_model_req = execute_model_req.clone(
            updated_seq_group_metadata_list)
        indices_of_original_sequence_groups = []
        for seq_group in execute_model_req.seq_group_metadata_list:
            seq_group_has_bonus_tokens = False
            for seq_id, _ in seq_group.seq_data.items():
                # Identify sequences with bonus tokens in the sequence group.
                if seq_id in seq_with_bonus_token_in_last_step:
                    seq_group_has_bonus_tokens = True
                    break
            if seq_group_has_bonus_tokens:
                # Create new sequences without the last bonus token. These new
                # sequence have the same sequence id as the original sequence.
                # We create a new sequence group and add them there.
                updated_seq_group_without_bonus_token = \
                    MultiStepWorker._copy_seq_metadata_excluding_last_token(
                        seq_group, seq_with_bonus_token_in_last_step)
                updated_seq_group_metadata_list.append(
                    updated_seq_group_without_bonus_token)
            # Add the original sequence group.
            updated_seq_group_metadata_list.append(
                MultiStepWorker._shallow_copy_seq_group_metadata(seq_group))
            # Record the index of the original sequence group.
            indices_of_original_sequence_groups.append(
                len(updated_seq_group_metadata_list) - 1)

        updated_execute_model_req.seq_group_metadata_list =\
            updated_seq_group_metadata_list

        if isinstance(updated_execute_model_req.previous_hidden_states,
                      HiddenStates):
            updated_execute_model_req.previous_hidden_states\
                .expand_with_bonus_tokens(seq_with_bonus_token_in_last_step)

        return updated_execute_model_req, indices_of_original_sequence_groups

    @staticmethod
    def _filter_model_output(
            expanded_batch_outputs: List[SamplerOutput],
            output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]:
        """
        Filters the model output to include only the specified sequence
        outputs. This method contracts the expanded batch output from the
        model to retain the outputs of only those sequences indicated by the
        provided indices.

        Args:
            expanded_batch_outputs (List[SamplerOutput]): The expanded output
                batches from the model, one per forward pass.
            output_indices_to_retain (torch.Tensor): Indices of the model
                outputs to retain.

        Returns:
            List[SamplerOutput]: A list containing the filtered model
            outputs for the specified indices.
        """
        return [
            SamplerOutput(
                outputs=[
                    expanded_batch_output.outputs[i]
                    for i in output_indices_to_retain
                ] if len(expanded_batch_output.outputs) > 0 else [],
                sampled_token_probs=(
                    expanded_batch_output.
                    sampled_token_probs[output_indices_to_retain]
                    if expanded_batch_output.sampled_token_probs is not None
                    else None),
                logprobs=(
                    expanded_batch_output.logprobs[output_indices_to_retain]
                    if expanded_batch_output.logprobs is not None else None),
                sampled_token_ids=(expanded_batch_output.
                                   sampled_token_ids[output_indices_to_retain]
                                   if expanded_batch_output.sampled_token_ids
                                   is not None else None))
            for expanded_batch_output in expanded_batch_outputs
        ]

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: set,
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.
        """
        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    @staticmethod
    def _append_new_tokens(
            model_output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata],
            indices_of_seq_with_bonus_tokens: List[int]) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done outside of the worker, but it is
        required if the worker is to perform multiple forward passes.
        """
        # `count` walks through indices_of_seq_with_bonus_tokens, which holds
        # the positions of the *original* sequence groups in the expanded
        # batch (see _expand_execute_model_request).
        count = 0
        for index, (seq_group_metadata, sequence_group_outputs) in enumerate(
                zip(seq_group_metadata_list, model_output)):
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
                # NOTE: Beam search is not supported, so we can assume that
                # parent_seq_id == seq_id.
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]
                # Determine the actual token ID to be generated,
                # considering bonus tokens
                if index != indices_of_seq_with_bonus_tokens[count]:
                    # This group is the bonus-stripped copy inserted right
                    # before its original group by
                    # _expand_execute_model_request; use the latest output
                    # token of that original group instead of this copy's
                    # own sample so the two stay in sync.
                    bonus_seq_metadata = seq_group_metadata_list[
                        indices_of_seq_with_bonus_tokens[count]]
                    _, bonus_token_seq_data = next(
                        iter(bonus_seq_metadata.seq_data.items()))
                    token_id = bonus_token_seq_data.output_token_ids[-1]
                else:
                    count += 1

                seq.append_token_id(token_id, token_logprob.logprob,
                                    seq_output.output_embed)
                seq.update_num_computed_tokens(1)

    @staticmethod
    def _shallow_copy_seq_group_metadata(
        seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata:
        """Copy input data structures to remove side-effects when input data
        structures are shared with other modules.

        Helpful when the vLLM scheduler runs in the same process as the worker.
        The alternative is deep-copying (or other form of deep copy); this has
        performance downsides.
        """
        # Shallow-copy the SequenceGroupMetadata. This allows us to
        # append tokens and change is_prompt without external side-effects.
        # We must shallow-copy seq_group_metadata as is_prompt could change.
        new_seq_group_metadata = copy.copy(seq_group_metadata)

        # We must shallow-copy seq_data as we will append token ids
        new_seq_data: Dict[int, SequenceData] = {}
        for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
            new_seq_data[seq_id] = copy.copy(old_seq_data)
            # Copy the token-id list too, so appends don't leak back into
            # the shared SequenceData.
            new_seq_data[seq_id].output_token_ids =\
                old_seq_data.output_token_ids[:]

        new_seq_group_metadata.seq_data = new_seq_data
        return new_seq_group_metadata

    @staticmethod
    def _copy_seq_metadata_excluding_last_token(
        seq_group_metadata: SequenceGroupMetadata,
        seq_ids_to_copy: Set[int],
    ) -> SequenceGroupMetadata:
        """
        Creates a shallow copy of the given SequenceGroupMetadata, retaining
        only the sequence IDs specified in seq_ids_to_copy. For each of these
        sequence IDs, all output_token_ids except the last one are copied.
        Sequence IDs not in seq_ids_to_copy are excluded from the copy.

        Parameters:
        seq_group_metadata (SequenceGroupMetadata): The original sequence
            group metadata.
        seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the
            copy.

        Returns:
        SequenceGroupMetadata: A shallow copy of the sequence group metadata
            with the specified modifications.
        """
        # Shallow-copy the SequenceGroupMetadata.
        new_seq_group_metadata = copy.copy(seq_group_metadata)
        # Shallow-copy seq_data and modify the output_token_ids.
        new_seq_data: Dict[int, SequenceData] = {}
        for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
            if (seq_id in seq_ids_to_copy):
                new_seq_data[seq_id] = copy.copy(old_seq_data)
                # Copy all the output token ids except the last.
                # Also reduce num_computed_tokens by 1 since we are not
                # including the last output token.
                # NOTE: num_computed_tokens is not directly used by the
                # speculative decoding workers, as it is only relevant for
                # chunked prefill, which is disabled for speculative decoding.
                # However, to maintain consistency in num_computed_tokens,
                # we update it here.
                new_seq_data[seq_id].output_token_ids =\
                    old_seq_data.output_token_ids[:-1]
                new_seq_data[seq_id].update_num_computed_tokens(-1)
        new_seq_group_metadata.seq_data = new_seq_data
        return new_seq_group_metadata

    def _assert_enough_kv_space(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            num_steps: int) -> None:
        """Assert there are enough physical blocks per sequence to store the
        current KV plus additional KV from num_steps tokens.

        Raises:
            ValueError: if any sequence lacks the KV slots for num_steps
                additional tokens.
        """
        assert self.model_runner.block_size is not None
        for seq_group_metadata in seq_group_metadata_list:
            # Only one seq_id is guaranteed because there is no beam search.
            seq_id = list(seq_group_metadata.seq_data.keys())[0]
            seq = seq_group_metadata.seq_data[seq_id]

            # After num_steps, the seq len will be the current seq len
            # plus one token per step.
            final_seq_len = seq.get_len() + num_steps

            # We will have final_seq_len - 1 KV because vLLM saves KV for a
            # token in the iteration after the token was generated.
            required_num_kv_slots = final_seq_len - 1

            # The allocated number of kv slots is the number of allocated blocks
            # times the number of slots of block.
            number_physical_blocks = len(
                seq_group_metadata.block_tables[seq_id])
            allocated_kv_slots = (number_physical_blocks *
                                  self.model_runner.block_size)

            if required_num_kv_slots > allocated_kv_slots:
                request_id = seq_group_metadata.request_id
                raise ValueError(
                    "The worker attempted to run "
                    f"{num_steps} times but found insufficient KV space for "
                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
                    f"{required_num_kv_slots=}).")

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """MultiStepWorker does not yet implement support for cache swap
        operations or beam search.

        Raises:
            NotImplementedError: on cache swap/copy requests or multi-sequence
                (beam search) groups.
        """
        if any([
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy
        ]):
            raise NotImplementedError(
                "MultiStepWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")

    def maybe_load_lm_head_weight(
        self,
        lm_head_weight: torch.Tensor,
    ) -> None:
        """Load ``lm_head_weight`` into the draft model's LM head, using the
        weight's own ``weight_loader`` when it defines one.
        """
        # NOTE(review): the double ``model_runner.model_runner`` hop assumes
        # the wrapped worker's runner wraps an inner runner -- confirm against
        # the worker implementation.
        weight_loader = getattr(
            self.worker.model_runner.model_runner.model.lm_head.weight,
            "weight_loader", default_weight_loader)
        weight_loader(
            self.worker.model_runner.model_runner.model.lm_head.weight,
            lm_head_weight)
|
||||
196
vllm/spec_decode/ngram_worker.py
Normal file
196
vllm/spec_decode/ngram_worker.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import weakref
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
|
||||
|
||||
class _DummyModel(nn.Module):
|
||||
pass
|
||||
|
||||
|
||||
class NGramWorker(NonLLMProposerWorkerBase):
    """NGramWorker provides a light drafter without need for model.

    Current NGramWorker only implements prompt lookup decoding,
    and in future we may also do RAG type drafter and other scenarios
    which don't rely on LLM model to give proposals.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        device_type: str = "cuda",
        **kwargs,
    ):
        """Create an NGramWorker.

        Args:
            vllm_config: Engine configuration forwarded to the base class.
            local_rank: Device index used to build the torch device.
            device_type: Device type string, e.g. "cuda" or "cpu".
        """
        super().__init__(vllm_config)

        # Get local_rank/vocab_size from kwargs attribute
        self.local_rank = local_rank
        self.device_type = device_type

        # Lazy initialization list.
        self._proposer: Top1Proposer

    def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
                              ngram_prompt_lookup_max: int):
        """Set the inclusive [min, max] ngram sizes used for prompt lookup."""
        # Search valid candidate window between
        # ngram_prompt_lookup_min/ngram_prompt_lookup_max
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

    def init_device(self):
        """Resolve the torch device and build the Top1Proposer."""
        self.device = torch.device(f"{self.device_type}:{self.local_rank}")

        # Current NGramWorker only supports Top1Proposer
        self._proposer = Top1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            device=self.device,
            vocab_size=self.vocab_size,
        )

    def load_model(self) -> None:
        pass  # Dummy: there is no model to load.

    def get_model(self) -> nn.Module:
        return _DummyModel()

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter. NGramWorker does not use the KV Cache and
        # therefore does not need this parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
        """NGram match algo to pick proposal candidate. Returns the list of
        sampler output, one per SequenceGroupMetadata.

        For ngram worker, we already done needed transposed internal, so the
        indicator pass to sampler_output_to_torch shall be False.
        """
        self._raise_if_unsupported(execute_model_req)

        has_spec_out = False
        token_id_list: List[Optional[torch.Tensor]] = []
        token_prob_list: List[Optional[torch.Tensor]] = []
        for seq_group_metadata in execute_model_req.seq_group_metadata_list:
            seq_data = next(iter(seq_group_metadata.seq_data.values()))

            seq_len = seq_data.get_len()
            # When seq_len is less than 3072 (3K), we use CPU to perform
            # the ngram match. Otherwise, we use the device specified in
            # the model config (normally GPU). 3072 is a rough threshold
            # based on profiling on H100, and it can be adjusted based
            # on the actual performance on different hardware.
            cur_device = "cpu" if seq_len < 3072 else self.device
            input_ids = torch.as_tensor(seq_data.get_token_ids(),
                                        dtype=torch.long,
                                        device=cur_device)

            # Try the largest ngram first, falling back to smaller sizes.
            for ngram_size in range(
                    min(self.ngram_prompt_lookup_max, seq_len - 1),
                    self.ngram_prompt_lookup_min - 1,
                    -1,
            ):
                ngram_tensor = input_ids[-ngram_size:]
                if ngram_size == 1:
                    # Do not match itself and do not use unfold and all
                    matches = (input_ids[:-1] == ngram_tensor)
                else:
                    windows = input_ids.unfold(dimension=0,
                                               size=ngram_size,
                                               step=1)
                    # Do not match itself
                    matches = (windows[:-1] == ngram_tensor).all(dim=-1)

                # first_match includes "values" (bool), indicating whether
                # the match is found, and "indices", indicating the index
                # of the first match.
                first_match = matches.max(dim=-1)
                if first_match.values.item():
                    proposal_start_idx = first_match.indices.add_(ngram_size)
                    spec_indices = (
                        proposal_start_idx).repeat(sample_len) + torch.arange(
                            sample_len, device=cur_device)
                    # Clamp so we never index past the end of the sequence.
                    spec_indices.clamp_(max=input_ids.shape[-1] - 1)
                    res = input_ids.gather(dim=-1,
                                           index=spec_indices).to(self.device)
                    token_id_list.append(res)
                    token_prob_list.append(
                        torch.nn.functional.one_hot(
                            res,
                            num_classes=self.vocab_size).to(torch.float32))
                    has_spec_out = True
                    break
            else:
                # No ngram size produced a match for this sequence.
                token_id_list.append(None)
                token_prob_list.append(None)

        if not has_spec_out:
            return None, False

        outputs: List[Optional[SamplerOutput]] = []
        for token_ids, token_probs in zip(token_id_list, token_prob_list):
            if token_ids is None:
                outputs.append(None)
            else:
                outputs.append(
                    SamplerOutput(
                        outputs=None,
                        sampled_token_probs=token_probs,
                        logprobs=torch.zeros((sample_len, self.vocab_size),
                                             dtype=torch.float32,
                                             device=self.device),
                        sampled_token_ids=token_ids,
                    ))

        return outputs, False

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        # Unused parameter. NGramWorker does not use the KV Cache and
        # therefore does not need this parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.
        """
        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """NGramWorker does not yet implement support for cache swap
        operations or beam search.

        Raises:
            NotImplementedError: on cache swap/copy requests or multi-sequence
                (beam search) groups.
        """
        if any([
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy
        ]):
            raise NotImplementedError(
                "NGramWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "NGramWorker does not support beam search.")
|
||||
59
vllm/spec_decode/proposer_worker_base.py
Normal file
59
vllm/spec_decode/proposer_worker_base.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposer
|
||||
from vllm.worker.worker_base import LoRANotSupportedWorkerBase
|
||||
|
||||
|
||||
class ProposerWorkerBase(LoRANotSupportedWorkerBase, SpeculativeProposer):
    """Interface for proposer workers"""

    @abstractmethod
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int]
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        """Run the proposer and return per-step sampler outputs.

        ``seq_ids_with_bonus_token_in_last_step`` is the set of sequence IDs
        that were assigned bonus tokens in their last forward pass. It is used
        to backfill the KV cache with the key-value pairs of the penultimate
        token in those sequences, and is therefore only consulted by workers
        that rely on the KV cache for token generation (e.g. MultiStepWorker);
        workers without a KV cache ignore it.
        """
        raise NotImplementedError

    def set_include_gpu_probs_tensor(self) -> None:
        """Implementation optional; the default is a no-op."""

    def set_should_modify_greedy_probs_inplace(self) -> None:
        """Implementation optional; the default is a no-op."""
|
||||
|
||||
|
||||
class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
    """Proposer worker which does not use a model with kvcache"""

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """No-op: proposals are obtained via get_spec_proposals instead."""
        return []

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Unsupported here; block accounting happens on the target model."""
        raise NotImplementedError

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """No-op: this worker holds no KV cache to initialize."""

    def get_cache_block_size_bytes(self) -> int:
        """A cache-less proposer contributes zero bytes per block."""
        return 0
|
||||
196
vllm/spec_decode/smaller_tp_proposer_worker.py
Normal file
196
vllm/spec_decode/smaller_tp_proposer_worker.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.distributed.parallel_state import (get_tp_group,
|
||||
init_model_parallel_group,
|
||||
patch_tensor_parallel_group)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _DummyModel(nn.Module):
|
||||
pass
|
||||
|
||||
|
||||
class SmallerTpProposerWorker(ProposerWorkerBase):
|
||||
"""Class which allows a speculative draft model to run with smaller tensor
|
||||
parallel degree than target model.
|
||||
This reduces the communication overhead of small draft models.
|
||||
|
||||
To implement this feature, this class differs behavior based on is_dummy
|
||||
flag, where dummy means worker that does not participate draft generation.
|
||||
Participating workers use a smaller tp group by patching vLLM's tensor
|
||||
parallel group temporarily during forward passes of draft models.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def maybe_wrap_worker(cls, worker, draft_tensor_parallel_size: int,
|
||||
target_tensor_parallel_size: int):
|
||||
"""Wrap the worker in a SmallerTpProposerWorker if necessary.
|
||||
"""
|
||||
if draft_tensor_parallel_size == target_tensor_parallel_size:
|
||||
return worker
|
||||
|
||||
# gpu ranks that will generate draft tokens together
|
||||
draft_ranks = list(range(draft_tensor_parallel_size))
|
||||
|
||||
logger.info("Wrapping {%s} in {%s}", type(worker), cls)
|
||||
return cls(worker, draft_ranks)
|
||||
|
||||
def __init__(self, worker: MultiStepWorker, draft_ranks: List[int]):
|
||||
"""Create a SmallerTpProposerWorker.
|
||||
|
||||
Args:
|
||||
worker (~vllm.spec_decode.multi_step_worker.MultiStepWorker): an
|
||||
actual worker wrapped with this class
|
||||
draft_ranks (List[int]): if this value is given, only the GPU ranks
|
||||
written in this value participate in draft generation
|
||||
"""
|
||||
self._worker = worker
|
||||
self._draft_ranks = draft_ranks
|
||||
|
||||
# init during init_device
|
||||
self._is_dummy = False
|
||||
self._tp_group = None
|
||||
|
||||
def _patch_tensor_parallel_group(self):
|
||||
"""Temporarily patch the global tp group state with its own tp group
|
||||
state.
|
||||
"""
|
||||
return patch_tensor_parallel_group(self._tp_group)
|
||||
|
||||
def init_device(self) -> None:
|
||||
self._is_dummy = get_tp_group().rank not in self._draft_ranks
|
||||
|
||||
# dummy workers do nothing
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
# creates tp process group containing only a subset of gpu ranks
|
||||
local_rank = get_tp_group().local_rank
|
||||
tp_backend = torch.distributed.get_backend(get_tp_group().device_group)
|
||||
self._tp_group = init_model_parallel_group([self._draft_ranks],
|
||||
local_rank, tp_backend)
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
self._worker.init_device()
|
||||
|
||||
def set_include_gpu_probs_tensor(self) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
# Need include_gpu_probs_tensor for multi_step_worker
|
||||
self._worker.set_include_gpu_probs_tensor()
|
||||
|
||||
def set_should_modify_greedy_probs_inplace(self) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
self._worker.set_should_modify_greedy_probs_inplace()
|
||||
|
||||
def load_model(self) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
self._worker.load_model()
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
if self._is_dummy:
|
||||
# this case is not used now
|
||||
return -1, -1
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
return self._worker.determine_num_available_blocks()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def sampler_output(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
# Do not check _is_dummy, as it's always called by get_spec_proposals
|
||||
return self._worker.sampler_output(
|
||||
execute_model_req, sample_len,
|
||||
seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> SpeculativeProposals:
|
||||
"""Produce speculations given an input batch of sequences. The number of
|
||||
speculative tokens per sequence is determined by max_proposal_len.
|
||||
"""
|
||||
if self._is_dummy:
|
||||
return SpeculativeProposals(None, None, None)
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
return self._worker.get_spec_proposals(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
if self._is_dummy:
|
||||
return _DummyModel()
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
return self._worker.get_model()
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
if self._is_dummy:
|
||||
return []
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
return self._worker.execute_model(execute_model_req)
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
if self._is_dummy:
|
||||
# by returning zero, target worker can use the entire kv cache space
|
||||
return 0
|
||||
|
||||
return self._worker.get_cache_block_size_bytes()
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self._worker.vocab_size
|
||||
|
||||
def maybe_load_lm_head_weight(
|
||||
self,
|
||||
lm_head_weight: torch.Tensor,
|
||||
) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
weight_loader = getattr(
|
||||
self._worker.worker.model_runner.model_runner.model.\
|
||||
lm_head.weight,
|
||||
"weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(
|
||||
self._worker.worker.model_runner.model_runner.model.\
|
||||
lm_head.weight,
|
||||
lm_head_weight)
|
||||
1326
vllm/spec_decode/spec_decode_worker.py
Normal file
1326
vllm/spec_decode/spec_decode_worker.py
Normal file
File diff suppressed because it is too large
Load Diff
45
vllm/spec_decode/target_model_runner.py
Normal file
45
vllm/spec_decode/target_model_runner.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
from vllm.worker.model_runner_base import (ModelRunnerBase,
|
||||
ModelRunnerInputBase,
|
||||
ModelRunnerWrapperBase)
|
||||
|
||||
|
||||
class TargetModelRunner(ModelRunnerWrapperBase):
    """Specialized model runner for the speculative-decoding target model.

    In speculative decoding the log probabilities that are ultimately kept
    may differ from those the target model sampled, because acceptance is
    decided after sampling. Computing target-model logprobs up front is
    therefore wasted work, so this wrapper lets the proposer disable them:
    it configures SamplingMetadata according to whether log probabilities
    were actually requested.
    """

    def __init__(self, model_runner: ModelRunnerBase):
        super().__init__(model_runner)
        # When True, sampler CPU output (including logprobs) is skipped.
        # The owning worker flips this off for requests that need logprobs.
        self.disable_logprobs = True

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelRunnerInputBase:
        """Prepare inputs via the wrapped runner, then configure whether the
        sampler should synchronize its output back to the CPU.
        """
        model_input: ModelRunnerInputBase = \
            self.model_runner.prepare_model_input(
                seq_group_metadata_list, virtual_engine,
                finished_requests_ids)
        # With logprobs disabled we skip generating sampler CPU output and
        # serialize the GPU sampled_token_id tensors directly; otherwise all
        # sampling-related tensors (incl. logprobs) are synchronized.
        model_input.sampling_metadata.skip_sampler_cpu_output = (
            self.disable_logprobs)
        return model_input
|
||||
275
vllm/spec_decode/top1_proposer.py
Normal file
275
vllm/spec_decode/top1_proposer.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
from vllm.spec_decode.util import sampler_output_to_torch
|
||||
|
||||
|
||||
class Top1Proposer(SpeculativeProposer):
    """Helper class which separates out sequences which would exceed the max
    model length when speculated upon.

    This allows combinations of models such as JackFram/llama-68m draft with
    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
    2048 while Llama2-13b has max_position_embeddings of 4096.

    We treat the sequences which exceed the proposal draft model length as
    "non-spec sequences". Essentially they skip the draft model and go through
    normal decoding in the target model.

    Currently, only proposal_lens of 0 and k are supported, where k is a global
    batch proposal length. In the future vLLM should support per-sequence
    proposal lengths.
    """

    def __init__(
        self,
        worker: ProposerWorkerBase,
        device: str,
        vocab_size: int,
        max_proposal_len: Optional[int] = None,
    ):
        # Draft worker that actually runs the proposer model.
        self._worker = worker
        self._device = device
        # Max total length (prompt + proposal) the draft model supports;
        # None disables the length check.
        self.max_proposal_len = max_proposal_len
        self._vocab_size = vocab_size

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """
        proposal_len = execute_model_req.num_lookahead_slots
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # Split speculative- and non-speculative- sequences.
        (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            # If sampler_transposed is true, then maybe_sampler_output's
            # token_ids is like [batch] format in proposal_len size list,
            # while if it is false, the format would be [proposal_len]
            # in batch size list
            hidden_states = execute_model_req.previous_hidden_states
            if hidden_states is not None:
                # Keep only the hidden states of sequences being speculated.
                hidden_states.prune(nonzero_proposal_len_seqs)
            nonzero_execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                num_lookahead_slots=proposal_len,
                previous_hidden_states=hidden_states,
            )
            maybe_sampler_output, transposed = self._worker.sampler_output(
                execute_model_req=nonzero_execute_model_req,
                sample_len=proposal_len,
                seq_ids_with_bonus_token_in_last_step=\
                    seq_ids_with_bonus_token_in_last_step,
            )
            # Drop sequences for which the draft worker returned no
            # proposal, so they are not scored needlessly.
            (
                proposal_lens,
                maybe_sampler_output,
                nonzero_proposal_len_indices,
            ) = self._remove_no_proposal_seqs(proposal_lens,
                                              maybe_sampler_output,
                                              nonzero_proposal_len_indices,
                                              transposed)
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None
            transposed = False

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            proposal_len=proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
            sampler_transposed=transposed,
        )

        proposals = SpeculativeProposals(proposal_token_ids=proposal_tokens,
                                         proposal_probs=proposal_probs,
                                         proposal_lens=proposal_lens,
                                         no_proposals=maybe_sampler_output
                                         is None)
        return proposals

    def _split_by_proposal_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Split sequences by two groups:
        1. Sequences with non-zero proposal length.
        2. Sequences with zero proposal length (due to disabled speculation
        or exceed the maximum model length).

        Returns (per-sequence proposal lens, groups to speculate on, their
        indices in the original batch).
        """

        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            # The speculative decoding for this request has either been disabled
            # (e.g. due to high traffic) or this is a prompt request.
            if (seq_group_metadata.is_prompt
                    or seq_group_metadata.num_speculative_tokens == 0):
                proposal_lens.append(0)
                continue

            # NOTE(review): assumes all sequences in the group share a
            # length, so inspecting one is enough.
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            # Currently only proposal lens of 0 or the global batch proposal len
            # are supported.
            # If max_proposal_len is defined, then we shall not exceed this
            # quota for nonzero_proposal
            new_k = 0
            if (self.max_proposal_len is None
                    or seq_len + proposal_len < self.max_proposal_len):
                new_k = proposal_len
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            proposal_lens.append(new_k)
            seq_group_metadata.num_speculative_tokens = new_k

        return (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        )

    @staticmethod
    def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
                                 nonzero_proposal_len_indices, transposed):
        """Remove sequences from nonzero_proposal_len_indices and reset
        their proposal_len to 0 when the draft worker does not provide a
        proposal (maybe_sampler_output=None). This can avoid scoring
        overheads.
        """

        # If maybe_sampler_output is None, then the draft worker did not
        # provide a proposal for any sequence and thus no action needed.
        # Also we do not support transposed maybe_sampler_output for now
        # because it seems not straightforward for draft workers outputting
        # transposed sampler outputs to handle the case of no proposal.
        if maybe_sampler_output is None or transposed:
            return (proposal_lens, maybe_sampler_output,
                    nonzero_proposal_len_indices)

        new_proposal_lens: List[int] = []
        new_nonzero_proposal_len_indices: List[int] = []
        new_maybe_sampler_output: List[SamplerOutput] = []
        # Two-pointer walk over the full batch (seq_idx) and the nonzero
        # subset (nonzero_proposal_len_idx_ptr).
        nonzero_proposal_len_idx_ptr = 0
        seq_idx = 0
        while seq_idx < len(
                proposal_lens) and nonzero_proposal_len_idx_ptr < len(
                    nonzero_proposal_len_indices):
            if seq_idx < nonzero_proposal_len_indices[
                    nonzero_proposal_len_idx_ptr]:
                # Sequence is not in the original nonzero_proposal_len_indices,
                # meaning that it has a proposal length of 0 before sending to
                # the draft worker.
                assert proposal_lens[seq_idx] == 0
                new_proposal_lens.append(0)
            else:
                # Sequence is in the original nonzero_proposal_len_indices
                if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None:
                    # but does not have a proposal from the draft worker.
                    new_proposal_lens.append(0)
                else:
                    # and has a proposal from the draft worker. Add it to the
                    # new nonzero proposal list and keep the sampler output.
                    new_proposal_lens.append(proposal_lens[seq_idx])
                    new_nonzero_proposal_len_indices.append(seq_idx)
                    new_maybe_sampler_output.append(
                        maybe_sampler_output[nonzero_proposal_len_idx_ptr])
                nonzero_proposal_len_idx_ptr += 1
            seq_idx += 1

        # The remaining sequences should have proposal length of 0.
        new_proposal_lens.extend(proposal_lens[seq_idx:])

        # We assume sampler_output will not be a list of all Nones.
        # In this case this function should not be called.
        assert new_maybe_sampler_output
        return (new_proposal_lens, new_maybe_sampler_output,
                new_nonzero_proposal_len_indices)

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.

        Returns (proposal token ids, proposal probs, proposal lens) tensors
        covering the whole batch; skipped sequences get token id -1, zero
        probs, and proposal len 0.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case we return empty proposals.
            # expand() creates broadcast views, not materialized tensors.
            proposal_tokens = torch.tensor(-1,
                                           dtype=torch.long,
                                           device=self._device).expand(
                                               batch_size, proposal_len)
            proposal_probs = torch.tensor(0,
                                          dtype=torch.float32,
                                          device=self._device).expand(
                                              batch_size, proposal_len,
                                              self._vocab_size)
            proposal_lens_tensor = torch.tensor(0,
                                                dtype=torch.long,
                                                device=self._device).expand(
                                                    len(proposal_lens))
            return proposal_tokens, proposal_probs, proposal_lens_tensor

        sampler_output = maybe_sampler_output
        proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
            sampler_output, sampler_transposed)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. the proposal can be empty, e.g. [-1, -1, -1]

        entire_proposal_tokens = proposal_tokens.new_full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
        )
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
        entire_proposal_probs = proposal_probs.new_zeros(
            batch_size,
            *proposal_probs.shape[1:],
        )
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_tokens, proposal_probs = (
            entire_proposal_tokens,
            entire_proposal_probs,
        )

        proposal_lens_tensor = torch.zeros(batch_size,
                                           dtype=torch.long,
                                           device=self._device)
        proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len

        return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
277
vllm/spec_decode/util.py
Normal file
277
vllm/spec_decode/util.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
PromptLogprobs, SequenceGroupMetadata,
|
||||
SequenceOutput)
|
||||
|
||||
SeqId = int
|
||||
|
||||
|
||||
def get_all_num_logprobs(
        seq_group_metadata_list: List["SequenceGroupMetadata"]) -> List[int]:
    """Given a list of SequenceGroupMetadata, create a list of all
    num_logprobs.

    If the sampling params do not call for any logprobs, return 0 for that
    sequence.
    """
    # None means "no logprobs requested"; coerce it to 0.
    return [
        seq_group_metadata.sampling_params.logprobs or 0
        for seq_group_metadata in seq_group_metadata_list
    ]
|
||||
|
||||
|
||||
def get_sampled_token_logprobs(
    # shape [num_steps, batch_size, vocab_size]
    logprob_tensor: torch.Tensor,
    sampled_token_ids: torch.Tensor,  # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the logprobs for the sampled tokens. Returns the ranks and
    logprobs, each of shape [num_steps, batch_size].
    """
    num_steps, batch_size, vocab_size = logprob_tensor.shape

    # Advanced indexing: pick each step's sampled token logprob per batch
    # element.
    step_index = torch.arange(num_steps).unsqueeze(1)
    batch_index = torch.arange(batch_size)
    selected_logprobs = logprob_tensor[step_index, batch_index,
                                       sampled_token_ids]

    # Rank = 1 + number of vocab entries with a strictly larger logprob.
    broadcast_selected = selected_logprobs.unsqueeze(-1).expand(
        -1, -1, vocab_size)
    sampled_token_ids_ranks = (logprob_tensor
                               > broadcast_selected).sum(-1).add_(1)

    return sampled_token_ids_ranks, selected_logprobs
|
||||
|
||||
|
||||
def create_logprobs_output(
    token_id: int,
    token_id_logprob_rank: int,
    token_id_logprob: float,
    topk_token_ids: List[Optional[int]],
    topk_logprobs: List[Optional[float]],
) -> Dict[int, "Logprob"]:
    """Create a Logprob Dict for a token given the sampling results.

    Args:
        token_id (int): The sampled token for the sequence.
        token_id_logprob_rank (int): The logprob rank of the sampled token.
        token_id_logprob (float): The logprob value of the sampled token.
        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
    """
    # vLLM logprobs always include the sampled token; the user may also
    # request top-k logprobs (top-k varies per user up to max_logprobs).
    logprobs: Dict[int, "Logprob"] = {
        token_id: Logprob(
            logprob=token_id_logprob,
            rank=token_id_logprob_rank,
        ),
    }
    for topk_index, (topk_token_id,
                     topk_logprob) in enumerate(zip(topk_token_ids,
                                                    topk_logprobs)):
        if topk_token_id is None:
            continue
        logprobs[topk_token_id] = Logprob(
            logprob=topk_logprob if topk_logprob is not None else 0.0,
            rank=topk_index + 1,
        )

    return logprobs
|
||||
|
||||
|
||||
def create_sequence_group_output(
    token_id: int,
    token_id_logprob_rank: int,
    token_id_logprob: float,
    seq_id: SeqId,
    topk_token_ids: List[Optional[int]],
    topk_logprobs: List[Optional[float]],
    prompt_logprobs: Optional["PromptLogprobs"] = None,
    step_index: Optional[int] = 0) -> "CompletionSequenceGroupOutput":
    """Create a SequenceGroupOutput given the sampling results.

    Args:
        token_id (int): The sampled token for the sequence.
        token_id_logprob_rank (int): The logprob rank of the sampled token.
        token_id_logprob (float): The logprob value of the sampled token.
        seq_id (int): The sequence id.
        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
        prompt_logprobs: Optional prompt logprobs to attach.
        step_index: (Optional[int]): The index of the speculative token.
    """
    logprobs = create_logprobs_output(
        token_id,
        token_id_logprob_rank,
        token_id_logprob,
        topk_token_ids,
        topk_logprobs,
    )

    sample = SequenceOutput(parent_seq_id=seq_id,
                            output_token=token_id,
                            logprobs=logprobs)
    return CompletionSequenceGroupOutput(samples=[sample],
                                         prompt_logprobs=prompt_logprobs,
                                         step_index=step_index)
|
||||
|
||||
|
||||
def split_batch_by_proposal_len(
    seq_group_metadata_list: List["SequenceGroupMetadata"],
    proposal_lens: List[int],
) -> Tuple[Tuple[List["SequenceGroupMetadata"], List[int]], Tuple[
        List["SequenceGroupMetadata"], List[int]]]:
    """Utility function that splits a batch based on whether the proposal len
    is zero or not. We should remove this once vLLM supports per-sequence
    proposal lens in a batch.

    Returns ((nonzero groups, their indices), (zero groups, their indices)).
    """
    spec_groups: List["SequenceGroupMetadata"] = []
    spec_indices: List[int] = []
    non_spec_groups: List["SequenceGroupMetadata"] = []
    non_spec_indices: List[int] = []
    for batch_index, (seq_group, proposal_len) in enumerate(
            zip(seq_group_metadata_list, proposal_lens)):
        if proposal_len:
            spec_groups.append(seq_group)
            spec_indices.append(batch_index)
        else:
            non_spec_groups.append(seq_group)
            non_spec_indices.append(batch_index)
    return (spec_groups, spec_indices), (non_spec_groups, non_spec_indices)
|
||||
|
||||
|
||||
def sampler_output_to_torch(
    sampler_output_list: Sequence["SamplerOutput"], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Utility function which converts a list of SamplerOutput to tensors.

    sampler_transposed indicates whether the stacked tensors must be
    transposed so that the batch dimension comes first.

    Returns:
        sampled_token_ids: torch.Tensor
            shape: [batch_size, len(sampler_output_list)]

        sampled_token_probs: torch.Tensor
            shape: [batch_size, len(sampler_output_list), vocab_size]
    """
    # Stack along a new leading num_sampler_output dimension.
    sampled_token_probs = torch.stack(
        [out.sampled_token_probs for out in sampler_output_list], dim=0)
    sampled_token_logprobs = torch.stack(
        [out.logprobs for out in sampler_output_list], dim=0)
    sampled_token_ids = torch.stack(
        [out.sampled_token_ids.flatten() for out in sampler_output_list],
        dim=0)

    if sampler_transposed:
        # Move the batch dimension to the front.
        sampled_token_probs = sampled_token_probs.transpose(0, 1)
        sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
        sampled_token_ids = sampled_token_ids.transpose(0, 1)

    sampled_hidden_states: Optional[torch.Tensor] = None
    if sampler_output_list[0].hidden_states is not None:
        # shape: [num_sampler_output, batch_size, hidden_dim]
        sampled_hidden_states = torch.stack(
            [out.hidden_states for out in sampler_output_list], dim=0)
        if sampler_transposed:
            sampled_hidden_states = sampled_hidden_states.transpose(0, 1)

    return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs,
            sampled_hidden_states)
|
||||
|
||||
|
||||
def maybe_mock_device_tensors(sampler_output: "SamplerOutput", batch_size: int,
                              vocab_size: int, device: str) -> None:
    """Helper method which mocks out the GPU tensors in SamplerOutput with
    dummy values. This will be removed in PR 7/9.
    https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
    """
    device_tensors = [
        sampler_output.sampled_token_probs, sampler_output.sampled_token_ids
    ]
    # Invariant: either both tensors are present or both are missing.
    assert all(v is None for v in device_tensors) or not any(
        v is None for v in device_tensors)
    if not any(v is None for v in device_tensors):
        # Do nothing if the tensors are already created (usually in unit
        # tests).
        return

    # Random logits pushed through softmax yield a valid distribution.
    sampler_output.sampled_token_probs = torch.nn.functional.softmax(
        torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device),
        dim=-1)

    sampler_output.sampled_token_ids = torch.randint(low=10,
                                                     high=100,
                                                     size=(batch_size, ),
                                                     dtype=torch.long,
                                                     device=device)
|
||||
|
||||
|
||||
@contextmanager
def nvtx_range(msg, *args, **kwargs):
    """
    Context manager / decorator that pushes an NVTX range at the beginning
    of its scope, and pops it at the end. If extra arguments are given,
    they are passed as arguments to msg.format().

    If running with cuda graphs, you must enable nsys cuda graph profiling.

    Arguments:
        msg (string): message to associate with the range
    """
    if not current_platform.is_cuda_alike():
        # NVTX is unavailable off CUDA-like platforms; act as a no-op.
        yield
        return

    torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
    try:
        yield
    finally:
        # Always pop, even if the body raised.
        torch.cuda.nvtx.range_pop()
|
||||
|
||||
|
||||
class Timer:
    """Basic timer context manager measuring wall-clock CPU time of its
    scope.

    After exit, exposes ``start_time``/``end_time`` (epoch seconds) and the
    elapsed duration in both seconds and milliseconds.
    """

    def __enter__(self):
        # Record entry time; elapsed fields are set on exit.
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time_s = self.end_time - self.start_time
        self.elapsed_time_ms = self.elapsed_time_s * 1000
|
||||
Reference in New Issue
Block a user