init
This commit is contained in:
0
vllm/spec_decode/__init__.py
Normal file
0
vllm/spec_decode/__init__.py
Normal file
397
vllm/spec_decode/batch_expansion.py
Normal file
397
vllm/spec_decode/batch_expansion.py
Normal file
@@ -0,0 +1,397 @@
|
||||
from itertools import chain, count
|
||||
from typing import Iterator, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
|
||||
sampler_output_to_torch,
|
||||
split_batch_by_proposal_len)
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
SeqId = int
|
||||
TargetSeqId = int
|
||||
TokenId = int
|
||||
|
||||
|
||||
class BatchExpansionTop1Scorer(SpeculativeScorer):
    """Implements a speculative scorer that uses batch expansion to get
    probabilities of speculative tokens according to the scoring model.

    Batch expansion converts a list of sequences and multiple query positions
    to a new batch of sequences, each with a single query position. This allows
    for MQA-like scoring in speculative decoding without requiring an MQA
    kernel.

    It is strictly less efficient than MQA scoring.

    It only supports scoring the top1 proposal tokens of the proposer, instead
    of topk/tree.
    """

    def __init__(self, scorer_worker: WorkerBase, device: str,
                 vocab_size: int):
        self._scorer_worker = scorer_worker
        self._device = device
        self._vocab_size = vocab_size

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences. The
        target sequences have the unique continuations to be scored and a
        unique sequence ID that is different from all input sequence ids.

        If a speculative sequence length would exceed the max model length, then
        no speculation is produced for that sequence.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.
        Returns:
            SpeculativeScores: The scores of each speculative token, along with
                which sequences were ignored during scoring.
        """

        # TODO(cade) perform this on GPU to remove blocking call.
        proposal_lens_list = proposals.proposal_lens.tolist()
        proposal_token_ids_list = proposals.proposal_token_ids.tolist()

        # Filter the list to ignore -1 proposals (sequences skipped during
        # speculation). The loop variable is named `token_ids` so it does not
        # shadow the `proposals` argument (it previously did).
        proposal_token_ids_list_without_skips = [
            token_ids for token_ids in proposal_token_ids_list
            if -1 not in token_ids
        ]

        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list, ))
        assert len(target_sampler_output) == 1, "expected single-step output"
        target_sampler_output = target_sampler_output[0]

        all_tokens, all_probs, spec_logprobs = self._contract_batch(
            contracted_bs=len(execute_model_req.seq_group_metadata_list),
            target_sampler_output=target_sampler_output,
            proposals=proposals,
            num_scoring_tokens=num_scoring_tokens,
            non_spec_indices=non_spec_indices,
            spec_indices=spec_indices,
            k=execute_model_req.num_lookahead_slots,
        )

        return SpeculativeScores(
            probs=all_probs,
            token_ids=all_tokens,
            logprobs=spec_logprobs,
        )

    def _expand_batch(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids_list: List[List[TokenId]],
        proposal_lens_list: List[int],
    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
        """Given the input sequences and potentially multiple corresponding
        proposal tokens, create a new batch where each sequence has a single
        query token.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        spec_seqs, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)

        target_seq_group_metadata_list = self._create_scoring_model_input(
            seq_group_metadata_list=spec_seqs,
            proposal_token_ids=proposal_token_ids_list,
            # NOTE: We determine the seq ids in the expanded batch using the
            # full seq_group_metadata_list, instead of only spec_seqs.
            target_seq_ids_iter=self._create_target_seq_id_iterator(
                seq_ids=get_all_seq_ids(seq_group_metadata_list)),
        )

        num_scoring_tokens = len(target_seq_group_metadata_list)
        target_seq_group_metadata_list.extend(non_spec_seqs)

        return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
                num_scoring_tokens)

    def _contract_batch(
            self, contracted_bs: int,
            target_sampler_output: SamplerOutput,
            proposals: SpeculativeProposals, num_scoring_tokens: int,
            non_spec_indices: List[int], spec_indices: List[int],
            k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

        contracted_bs is the original batch size, and the batch size that the
        target_sampler_output will be contracted to.

        NOTE: the annotation of target_sampler_output was corrected from
        List[SamplerOutput] to SamplerOutput; score_proposals passes a single
        output (it indexes [0] before calling).
        """
        (target_token_ids, target_probs, target_logprobs,
         non_spec_target_token_ids, non_spec_target_probs,
         non_spec_target_logprobs) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # Map distinct sequences used to score each token
        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
        # NOTE: this rebinds the `k` parameter to the proposal length taken
        # from the proposal tensor shape.
        expanded_batch_size, k = proposals.proposal_token_ids.shape

        # The number of tokens in the expanded batch used for speculation is
        # equal to the total expanded batch size minus the number of samples for
        # non-speculative sequences.
        non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
        spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

        target_token_ids = target_token_ids.squeeze().reshape(
            spec_expanded_bs, k + 1)
        target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
                                                      self._vocab_size)
        target_logprobs = target_logprobs.squeeze().reshape(
            spec_expanded_bs, k + 1, self._vocab_size)

        # Sequences without speculation are left as -1 tokens / zero probs /
        # -inf logprobs, which marks them as ignored during scoring.
        all_tokens = torch.full(size=(contracted_bs, k + 1),
                                fill_value=-1,
                                device=self._device,
                                dtype=torch.long)
        all_probs = torch.zeros(contracted_bs,
                                k + 1,
                                self._vocab_size,
                                device=self._device,
                                dtype=torch.float32)
        all_logprobs = torch.full(size=(
            contracted_bs,
            k + 1,
            self._vocab_size,
        ),
                                  fill_value=-float("inf"),
                                  device=self._device,
                                  dtype=torch.float32)

        if non_spec_indices:
            all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
            all_probs[non_spec_indices, :1, :] = non_spec_target_probs
            all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs

        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
            all_probs[spec_indices] = target_probs
            all_logprobs[spec_indices] = target_logprobs

        return all_tokens, all_probs, all_logprobs

    def _create_scoring_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Given the original input sequences and proposed tokens from the draft
        model, create a list of target sequences that can be used for scoring.

        target_seq_ids_iter provides sequence ids for the expanded batch,
        fulfilling the requirement that no seq id in the expanded batch is equal
        to the seq id in the original batch.
        """

        if not seq_group_metadata_list:
            return []

        target_seq_group_metadata = list(
            chain.from_iterable(
                self._create_target_seq_group_metadata(
                    seq_group_metadata,
                    proposal_token_ids,
                    i,
                    target_seq_ids_iter,
                ) for i, seq_group_metadata in enumerate(
                    seq_group_metadata_list)))

        return target_seq_group_metadata

    def _create_target_seq_group_metadata(
        self,
        input_seq_group_metadata: SequenceGroupMetadata,
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        batch_index: int,
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Given an input sequence group metadata and a list of draft tokens,
        create a list of target SequenceGroupMetadata, one for each
        token id that needs to be scored.

        Naive speculative decoding requires K target model scores, one for each
        draft model token. However one can add a bonus token such that if each
        token is accepted, then a final token may be sampled from the model.
        This function creates K+1 target SequenceGroupMetadata to take
        advantage of the bonus token.
        """
        assert not input_seq_group_metadata.is_prompt, (
            "Speculating on "
            "prompts not yet supported")
        assert len(input_seq_group_metadata.seq_data) == 1, (
            "Beam search "
            "not supported in speculative decoding")
        input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))

        token_ids_to_score = self._get_token_ids_to_score(
            proposal_token_ids[batch_index])

        target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for token_ids in token_ids_to_score:
            target_seq_group_metadata_list.append(
                self._create_single_target_seq_group_metadata(
                    input_seq_group_metadata,
                    input_seq_id,
                    next(target_seq_ids_iter),
                    token_ids,
                ))

        return target_seq_group_metadata_list

    def _create_single_target_seq_group_metadata(
        self,
        seq_group_metadata: SequenceGroupMetadata,
        seq_id: SeqId,
        target_seq_id: TargetSeqId,
        token_ids: List[TokenId],
    ) -> SequenceGroupMetadata:
        """Create a single target SequenceGroupMetadata.

        Args:
            seq_group_metadata: The metadata for the input sequence.
            seq_id: The input sequence ID.
            target_seq_id: The corresponding target sequence ID.
            token_ids: The list of token ids that are to be appended to the
                input sequence.
        """
        seq_data = seq_group_metadata.seq_data[seq_id]
        prompt_token_ids = seq_data.get_prompt_token_ids()
        new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]

        return SequenceGroupMetadata(
            request_id=seq_group_metadata.request_id,
            is_prompt=seq_group_metadata.is_prompt,
            seq_data={
                target_seq_id:
                SequenceData(
                    prompt_token_ids=prompt_token_ids,
                    output_token_ids=new_output_token_ids,
                ),
            },
            sampling_params=seq_group_metadata.sampling_params,
            block_tables={
                # The target sequence reuses the input sequence's blocks.
                target_seq_id: seq_group_metadata.block_tables[seq_id],
            },
            lora_request=None,
        )

    def _split_scoring_output(
        self, sampler_output: SamplerOutput, num_scoring_tokens: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor, torch.Tensor]:
        """Split the target model output into speculative and non-speculative
        output.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        #
        # First samples are from speculative scoring, latter samples are non-
        # speculative samples.
        split_sizes = [
            num_scoring_tokens,
            sampler_output.sampled_token_ids.numel() - num_scoring_tokens
        ]
        (spec_probs, non_spec_probs
         ) = sampler_output.sampled_token_probs.split(split_sizes)
        (spec_sampled_tokens, non_spec_sampled_tokens
         ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
        (
            spec_logprobs,
            non_spec_logprobs,
        ) = sampler_output.logprobs.split(split_sizes)

        # Convert scores to tensors. The sampler_output is mutated in place to
        # reuse sampler_output_to_torch for each half of the split.
        sampler_output.sampled_token_probs = spec_probs
        sampler_output.sampled_token_ids = spec_sampled_tokens
        sampler_output.logprobs = spec_logprobs
        (target_token_ids, target_probs,
         target_logprobs) = sampler_output_to_torch([sampler_output], True)

        # Convert non-speculative output tokens to tensors.
        sampler_output.sampled_token_probs = non_spec_probs
        sampler_output.sampled_token_ids = non_spec_sampled_tokens
        sampler_output.logprobs = non_spec_logprobs
        (non_spec_target_token_ids, non_spec_target_probs,
         non_spec_target_logprobs) = sampler_output_to_torch([sampler_output],
                                                             True)

        return (target_token_ids, target_probs, target_logprobs,
                non_spec_target_token_ids, non_spec_target_probs,
                non_spec_target_logprobs)

    def _create_target_seq_id_iterator(
            self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
        """Create an iterator for creating target sequence ids.
        Target sequence ids are distinct from sequence ids because we create a
        distinct target sequence id for each proposal token to be scored.

        This implementation increments a counter starting at 1 + max of all
        provided input sequence ids.
        """
        return count(start=max(seq_ids) + 1)

    def _get_token_ids_to_score(
            self,
            full_spec_token_ids: List[TokenId]  # shape: [k]
    ) -> List[List[TokenId]]:
        """Given an int tensor of proposal token ids, return a list of
        token ids that should be scored.

        Returns k+1 output lists. The additional one is used for generating the
        bonus token.

        Example:
            Input: [0, 1, 2, 3] (k=4)
            Output: (k+1 lists)
                []
                [0]
                [0, 1]
                [0, 1, 2]
                [0, 1, 2, 3]
        """
        empty_token_ids: List[TokenId] = []

        token_ids_to_score = [empty_token_ids]
        token_ids_to_score.extend([
            full_spec_token_ids[:i + 1]
            for i in range(len(full_spec_token_ids))
        ])
        return token_ids_to_score
|
||||
73
vllm/spec_decode/interfaces.py
Normal file
73
vllm/spec_decode/interfaces.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
|
||||
|
||||
@dataclass
class SpeculativeProposals:
    """Proposal tokens produced by a proposer (draft model), together with
    their probabilities and the number of valid speculative tokens per
    sequence.
    """

    # Speculative proposal tokens.
    proposal_token_ids: torch.Tensor

    # Probabilities of the proposal tokens according to the proposer.
    proposal_probs: torch.Tensor

    # The valid length of each proposal; can be zero.
    proposal_lens: torch.Tensor

    def __repr__(self):
        # Probabilities are summarized by shape only; the tensor itself can
        # be very large.
        return ("SpeculativeProposals("
                f"proposal_token_ids={self.proposal_token_ids}, "
                f"proposal_probs={self.proposal_probs.shape}, "
                f"proposal_lens={self.proposal_lens})")
|
||||
|
||||
|
||||
@dataclass
class SpeculativeScores:
    """Scores assigned to speculative tokens by the scoring model."""

    # Probabilities of the speculative tokens according to the scoring model.
    probs: torch.Tensor

    # Log-probabilities of the speculative tokens according to the scoring
    # model. These values can be used to generate Logprob objects that are
    # returned to the user.
    logprobs: torch.Tensor

    # Token ids sampled from the scoring model. Used for speculative bonus
    # tokens and also non-speculative normal decoding.
    token_ids: torch.Tensor

    def __repr__(self):
        # Tensors are summarized by shape; logprobs are intentionally omitted
        # (matches the original format).
        return ("SpeculativeScores("
                f"probs={self.probs.shape}, "
                f"token_ids={self.token_ids.shape})")
|
||||
|
||||
|
||||
class SpeculativeProposer(ABC):
    """Abstract interface for components that produce speculative proposals
    for a batch of sequences.
    """

    @abstractmethod
    def get_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        """Produce proposal tokens for the given execution request."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class SpeculativeScorer(ABC):
    """Abstract interface for components that score speculative proposals
    with the target (scoring) model.
    """

    @abstractmethod
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the given proposals against the scoring model."""
        raise NotImplementedError
|
||||
191
vllm/spec_decode/metrics.py
Normal file
191
vllm/spec_decode/metrics.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
@dataclass
class SpecDecodeWorkerMetrics:
    """Dataclass holding metrics emitted from the spec decode worker."""

    # The empirical acceptance rate of the proposal method on a per-token
    # basis. This is useful for evaluating how well the proposal method
    # aligns with the scoring method.
    draft_acceptance_rate: float

    # The empirical efficiency, measured as the number of tokens emitted by
    # the system divided by the number of tokens that could be emitted by the
    # system if the proposal method were perfect.
    system_efficiency: float

    # The number of speculative tokens produced by the proposal method.
    draft_tokens: int

    # The number of tokens emitted by the entire system.
    emitted_tokens: int

    # The number of tokens accepted by the scoring model and verification
    # routine, e.g. Llama2-70B and lossless rejection sampling.
    #
    # NOTE: Any token accepted by the verification routine is considered
    # accepted (regardless of if the speculative prefix is also accepted). The
    # user will usually see less accepted tokens. This metric is helpful when
    # evaluating alignment of the proposal method with the scoring model.
    accepted_tokens: int

    # The number of speculative tokens per sequence.
    num_spec_tokens: int
|
||||
|
||||
|
||||
Timer = Callable[[], float]
|
||||
|
||||
|
||||
class AsyncMetricsCollector:
    """Class which copies rejection sampler metrics from the device to CPU on a
    non-default Torch stream.

    NOTE(review): this class mixes torch.musa.* runtime calls with
    torch.cuda.* type annotations — confirm which accelerator backend is
    intended; the annotations and calls should agree.
    """

    def __init__(self,
                 rejection_sampler: RejectionSampler,
                 timer: Optional[Timer] = None,
                 collect_interval_s: float = 5.0):
        self._rejection_sampler = rejection_sampler
        self._timer = time.time if timer is None else timer

        # Rank is unknown until init_gpu_tensors is called.
        self._rank: Optional[int] = None

        # We don't have a device set yet, so the copy stream is created
        # lazily in init_gpu_tensors.
        self._copy_stream: Optional[torch.cuda.Stream] = None

        # Event recording the completion of an in-flight device->CPU copy,
        # if any.
        self._in_flight_copy: Optional[torch.cuda.Event] = None

        # Pinned host tensors allow async (non_blocking) device->CPU copies.
        pin_memory = is_pin_memory_available()
        self._aggregate_num_accepted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_emitted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_draft_tokens = 0

        self._rejsample_metrics_collect_interval_s = collect_interval_s
        self._last_metrics_collect_time = self._timer()

    def init_gpu_tensors(self, rank: int) -> None:
        """Record the worker rank and create the dedicated copy stream."""
        self._rank = rank
        # NOTE(review): creates a MUSA stream although _copy_stream is
        # annotated as torch.cuda.Stream — verify against target hardware.
        self._copy_stream = torch.musa.Stream()

    def maybe_collect_rejsample_metrics(
            self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
        """Return metrics if a previously started async copy has completed;
        otherwise, possibly start a new async copy and return None.
        """

        # If a copy was initiated in the previous call, collect and return.
        if self._in_flight_copy is not None:
            ready_event = self._in_flight_copy
            self._in_flight_copy = None
            return self._collect_rejsample_metrics(k, ready_event)

        # Otherwise, check if we should start a new copy.
        if self._should_collect_rejsample_metrics(self._timer()):
            assert self._in_flight_copy is None
            self._in_flight_copy = self._copy_rejsample_metrics_async()

        return None

    def _should_collect_rejsample_metrics(self, now: float) -> bool:
        """Return whether or not this iteration should print rejection sampling
        metrics.
        """
        # Only rank 0 reports metrics.
        if self._rank != 0:
            return False

        elapsed = now - self._last_metrics_collect_time
        return elapsed >= self._rejsample_metrics_collect_interval_s

    def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
        """Copy rejection sampling metrics (number of accepted tokens, etc) to
        CPU asynchronously.

        Returns a CUDA event recording when the copy is complete.
        """
        assert self._copy_stream is not None
        # Ensure the copy stream observes all work queued on the current
        # stream before copying.
        self._copy_stream.wait_stream(torch.musa.current_stream())

        with torch.musa.stream(self._copy_stream):
            self._aggregate_num_accepted_tokens.copy_(
                self._rejection_sampler.num_accepted_tokens, non_blocking=True)
            self._aggregate_num_emitted_tokens.copy_(
                self._rejection_sampler.num_emitted_tokens, non_blocking=True)
            # Number of draft tokens is calculated on CPU, so no copy is
            # required.
            self._aggregate_num_draft_tokens = (
                self._rejection_sampler.num_draft_tokens)

        aggregate_metrics_ready = torch.musa.Event()
        aggregate_metrics_ready.record(self._copy_stream)

        return aggregate_metrics_ready

    def _collect_rejsample_metrics(
            self, k: int,
            ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
        """Create metrics object from statistics copied asynchronously.

        Args:
            k: int. The number of speculative tokens; used to determine system
                efficiency.
            ready_event: torch.cuda.Event. The CUDA event recording when the
                async GPU->CPU copy is complete.
        """

        # Block until the async copy has landed in the pinned host tensors.
        ready_event.synchronize()
        accepted_tokens = self._aggregate_num_accepted_tokens.item()
        emitted_tokens = self._aggregate_num_emitted_tokens.item()
        draft_tokens = self._aggregate_num_draft_tokens

        max_num_emitted_tokens = self.get_max_num_emitted_tokens(
            draft_tokens, k)

        # Guard against division by zero; NaN signals "no data yet".
        draft_acceptance_rate = (accepted_tokens / draft_tokens
                                 if draft_tokens > 0 else float("nan"))
        system_efficiency = (emitted_tokens / max_num_emitted_tokens
                             if max_num_emitted_tokens > 0 else float("nan"))

        return SpecDecodeWorkerMetrics(
            num_spec_tokens=k,
            draft_acceptance_rate=draft_acceptance_rate,
            system_efficiency=system_efficiency,
            accepted_tokens=accepted_tokens,
            draft_tokens=draft_tokens,
            emitted_tokens=emitted_tokens,
        )

    @staticmethod
    def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
        """Calculate the number of emitted tokens, assuming all tokens are
        accepted.

        This is equal to the number of sequences that have been speculated on,
        times (speculation len + 1). The +1 comes from the bonus token.
        """
        # Determine the number of sequences that have been speculated on. Since
        # the batch size can be variable, we divide by k.
        assert draft_tokens % k == 0
        total_num_spec_seqs = draft_tokens // k

        # A single sequence may emit k accepted tokens and one bonus token in
        # the best case.
        num_emitted_per_seq_if_all_accepted = k + 1

        # The max num of emitted tokens is the number of speculated sequences
        # times the max emitted per seq.
        return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
|
||||
203
vllm/spec_decode/multi_step_worker.py
Normal file
203
vllm/spec_decode/multi_step_worker.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import copy
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
class MultiStepWorker(Worker):
|
||||
"""The MultiStepWorker is equivalent to a Worker except that it allows
|
||||
multiple forward passes in a single call, assuming the scheduler has
|
||||
allocated enough space to store the additional KV. This reduces overhead
|
||||
by invoking the scheduler less.
|
||||
|
||||
The MultiStepWorker does not support cache swap operations, or beam search.
|
||||
Cache swap operations do not require large modifications. On the other hand,
|
||||
beam search requires memory allocations during sequence forks and thus
|
||||
requires more thought for MultiStepWorker support.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Lazy initialization list.
|
||||
self._proposer: Top1Proposer
|
||||
|
||||
def init_device(self):
|
||||
super().init_device()
|
||||
|
||||
self._proposer = Top1Proposer(
|
||||
self,
|
||||
self.device,
|
||||
self.vocab_size,
|
||||
max_proposal_len=self.max_model_len,
|
||||
)
|
||||
|
||||
def set_include_gpu_probs_tensor(self):
|
||||
# Need include_gpu_probs_tensor for multi_step_worker
|
||||
self.model_runner.model.sampler.include_gpu_probs_tensor = True
|
||||
|
||||
@torch.inference_mode()
|
||||
def sampler_output(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
"""Run the model forward pass sample_len times. Returns the list of
|
||||
sampler output, one per model forward pass, along with indicator of
|
||||
whether torch tensor in sampler output need to be transposed in latter
|
||||
sampler_output_to_torch logic.
|
||||
|
||||
For multi step worker, this indicator shall be True.
|
||||
"""
|
||||
self._raise_if_unsupported(execute_model_req)
|
||||
|
||||
# Shallow copy input data so modifications (such as appending tokens)
|
||||
# do not cause side-effects.
|
||||
copied_seq_group_metadata_list = self._shallow_copy_inputs(
|
||||
execute_model_req.seq_group_metadata_list)
|
||||
copied_execute_model_req = execute_model_req.clone(
|
||||
copied_seq_group_metadata_list)
|
||||
|
||||
# Assert enough KV space for sample_len tokens per sequence.
|
||||
self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
|
||||
sample_len)
|
||||
|
||||
# Run model sample_len times.
|
||||
model_outputs = []
|
||||
for _ in range(sample_len):
|
||||
model_output = super().execute_model(
|
||||
execute_model_req=copied_execute_model_req)
|
||||
assert (len(model_output) == 1
|
||||
), "composing multistep workers not supported"
|
||||
model_output = model_output[0]
|
||||
|
||||
self._append_new_tokens(model_output,
|
||||
copied_seq_group_metadata_list)
|
||||
model_outputs.append(model_output)
|
||||
|
||||
return model_outputs, True
|
||||
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> SpeculativeProposals:
|
||||
"""Produce speculations given an input batch of sequences. The number of
|
||||
speculative tokens per sequence is determined by max_proposal_len.
|
||||
"""
|
||||
|
||||
return self._proposer.get_proposals(execute_model_req)
|
||||
|
||||
def _append_new_tokens(
|
||||
self, model_output: SamplerOutput,
|
||||
seq_group_metadata_list: SequenceGroupMetadata) -> None:
|
||||
"""Given model output from a single run, append the tokens to the
|
||||
sequences. This is normally done outside of the worker, but it is
|
||||
required if the worker is to perform multiple forward passes.
|
||||
"""
|
||||
for seq_group_metadata, sequence_group_outputs in zip(
|
||||
seq_group_metadata_list, model_output):
|
||||
seq_group_metadata.is_prompt = False
|
||||
|
||||
for seq_output in sequence_group_outputs.samples:
|
||||
# NOTE: Beam search is not supported, so we can assume that
|
||||
# parent_seq_id == seq_id.
|
||||
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
|
||||
|
||||
token_id = seq_output.output_token
|
||||
token_logprob = seq_output.logprobs[token_id]
|
||||
|
||||
seq.append_token_id(token_id, token_logprob.logprob)
|
||||
|
||||
def _shallow_copy_inputs(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
"""Copy input data structures to remove side-effects when input data
|
||||
structures are shared with other modules.
|
||||
|
||||
Helpful when the vLLM scheduler runs in the same process as the worker.
|
||||
The alternative is deep-copying (or other form of deep copy); this has
|
||||
performance downsides.
|
||||
"""
|
||||
|
||||
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
|
||||
# append tokens and change is_prompt without external side-effects.
|
||||
new_seq_group_metadata_list = []
|
||||
|
||||
for old_seq_group_metadata in seq_group_metadata_list:
|
||||
# We must shallow-copy seq_group_metadata as is_prompt could change.
|
||||
seq_group_metadata = copy.copy(old_seq_group_metadata)
|
||||
new_seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
# We must shallow-copy seq_data as we will append token ids
|
||||
new_seq_data = {}
|
||||
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
|
||||
new_seq_data[seq_id] = copy.copy(old_seq_data)
|
||||
new_seq_data[
|
||||
seq_id].output_token_ids = old_seq_data.output_token_ids[:]
|
||||
|
||||
seq_group_metadata.seq_data = new_seq_data
|
||||
|
||||
return new_seq_group_metadata_list
|
||||
|
||||
def _assert_enough_kv_space(
        self, seq_group_metadata_list: List[SequenceGroupMetadata],
        num_steps: int) -> None:
    """Assert there are enough physical blocks per sequence to store the
    current KV plus additional KV from num_steps tokens.

    Raises:
        ValueError: if any sequence lacks sufficient allocated KV slots.
    """
    block_size = self.model_runner.block_size
    assert block_size is not None

    for seq_group_metadata in seq_group_metadata_list:
        # Only one seq_id is guaranteed because there is no beam search.
        seq_id, seq = next(iter(seq_group_metadata.seq_data.items()))

        # After num_steps the sequence holds one extra token per step,
        # but vLLM writes KV for a token in the iteration *after* the
        # token was generated, hence the -1.
        required_num_kv_slots = seq.get_len() + num_steps - 1

        # Allocated slots = allocated blocks * slots per block.
        number_physical_blocks = len(seq_group_metadata.block_tables[seq_id])
        allocated_kv_slots = number_physical_blocks * block_size

        if allocated_kv_slots < required_num_kv_slots:
            request_id = seq_group_metadata.request_id
            raise ValueError(
                "The worker attempted to run "
                f"{num_steps} times but found insufficient KV space for "
                f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
                f"{required_num_kv_slots=}).")
||||
def _raise_if_unsupported(
    self,
    execute_model_req: ExecuteModelRequest,
) -> None:
    """MultiStepWorker does not yet implement support for cache swap
    operations or beam search.

    Raises:
        NotImplementedError: on any cache swap/copy request, or when a
            sequence group carries more than one sequence (beam search).
    """
    cache_operations = (
        execute_model_req.blocks_to_swap_in,
        execute_model_req.blocks_to_swap_out,
        execute_model_req.blocks_to_copy,
    )
    if any(cache_operations):
        raise NotImplementedError(
            "MultiStepWorker does not support cache operations")

    for seq_group_metadata in execute_model_req.seq_group_metadata_list:
        if len(seq_group_metadata.seq_data.keys()) != 1:
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")
176
vllm/spec_decode/ngram_worker.py
Normal file
176
vllm/spec_decode/ngram_worker.py
Normal file
@@ -0,0 +1,176 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
|
||||
|
||||
|
||||
class NGramWorker(LoraNotSupportedWorkerBase):
    """NGramWorker provides a light drafter without need for a model.

    Currently NGramWorker only implements prompt lookup decoding; in the
    future we may also do RAG-type drafting and other scenarios which
    don't rely on an LLM model to give proposals.
    """

    def __init__(self, *args, **kwargs):
        # Get local_rank/vocab_size from kwargs attribute
        self.local_rank = kwargs["local_rank"]
        self.vocab_size = kwargs["model_config"].get_vocab_size()

        # Lazy initialization; created in init_device().
        self._proposer: Top1Proposer

    def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
                              ngram_prompt_lookup_max: int):
        """Configure the candidate ngram sizes searched by prompt lookup.

        Sizes from ngram_prompt_lookup_max down to (but not including)
        ngram_prompt_lookup_min are tried in order.
        """
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

    def init_device(self):
        """Bind the worker to its GPU and create the Top1Proposer."""
        self.device = torch.device(f"cuda:{self.local_rank}")
        # There is no model to load; satisfy callers expecting load_model().
        self.load_model = lambda *args, **kwargs: None

        # Currently only Top1Proposer is supported.
        self._proposer = Top1Proposer(
            self,
            device=self.device,
            vocab_size=self.vocab_size,
        )

    def set_include_gpu_probs_tensor(self):
        # NGram doesn't need a gpu sampler.
        pass

    def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
        """NGram doesn't depend on model execution, just pass this function"""
        pass

    def determine_num_available_blocks(self) -> None:
        """NGram doesn't depend on model execution, no need to check blocks"""
        pass

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """As there is no cache need to handle, just pass this function"""
        pass

    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes."""
        return 0

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        """NGram match algo to pick proposal candidate. Returns the list of
        sampler output, one per SequenceGroupMetadata.

        For the ngram worker, the needed transpose is already done
        internally, so the indicator passed to sampler_output_to_torch
        shall be False (the second element of the returned tuple).
        """
        self._raise_if_unsupported(execute_model_req)

        arr = []
        has_spec_out = False
        for seq_group_metadata in execute_model_req.seq_group_metadata_list:
            seq_data = next(iter(seq_group_metadata.seq_data.values()))

            input_ids = torch.as_tensor(seq_data.get_token_ids(),
                                        dtype=torch.long,
                                        device=self.device)
            input_length = seq_data.get_len()

            # Try the largest ngram first, down to (exclusive) the minimum.
            for ngram_size in range(
                    min(self.ngram_prompt_lookup_max, input_length - 1),
                    self.ngram_prompt_lookup_min,
                    -1,
            ):
                ngram_tensor = input_ids[-1 * ngram_size:]
                # All sliding windows of length ngram_size. The final
                # window is the query ngram itself and always matches, so
                # require more than one match to find an earlier occurrence.
                windows = input_ids.unfold(dimension=0,
                                           size=ngram_size,
                                           step=1)
                matches = (windows == ngram_tensor).all(dim=1)
                match_indices = matches.nonzero(as_tuple=True)[0]
                if match_indices.size()[0] > 1:
                    has_spec_out = True
                    res = seq_data.get_token_ids()
                    # Propose the tokens that followed the earliest match.
                    res = res[match_indices[0] + ngram_size:match_indices[0] +
                              ngram_size + sample_len]
                    res_len = len(res)
                    # pad 0 towards output as sample_len tokens required
                    res += [0] * (sample_len - res_len)

                    break
            else:
                # if no candidate found, fill with 0
                res = [0] * sample_len

            arr.append(res)

        if not has_spec_out:
            return None, False

        outputs = []
        token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device)
        indices = token_ids.unsqueeze(2)

        # One-hot "probabilities" for the proposed tokens.
        token_probs = torch.zeros(
            (len(execute_model_req.seq_group_metadata_list), sample_len,
             self.vocab_size),
            dtype=torch.float32,
            device=self.device,
        )
        token_probs.scatter_(2, indices, 1)
        token_logprobs = torch.zeros(
            (len(execute_model_req.seq_group_metadata_list), sample_len,
             self.vocab_size),
            dtype=torch.float32,
            device=self.device,
        )
        for i in range(len(execute_model_req.seq_group_metadata_list)):
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=token_probs[i],
                    # Fix: slice per sequence. Previously the whole
                    # [batch, sample_len, vocab] tensor was attached to
                    # every SamplerOutput, while sampled_token_probs and
                    # sampled_token_ids were per-sequence slices.
                    logprobs=token_logprobs[i],
                    sampled_token_ids=token_ids[i],
                ))
        return outputs, False

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number of
        speculative tokens per sequence is determined by max_proposal_len.
        """

        return self._proposer.get_proposals(execute_model_req)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """NGramWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy
        ]):
            raise NotImplementedError(
                "NGramWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "NGramWorker does not support beam search.")
|
||||
472
vllm/spec_decode/spec_decode_worker.py
Normal file
472
vllm/spec_decode/spec_decode_worker.py
Normal file
@@ -0,0 +1,472 @@
|
||||
from functools import cached_property
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
from vllm.spec_decode.util import (create_sequence_group_output,
|
||||
get_all_num_logprobs, get_all_seq_ids,
|
||||
get_sampled_token_logprobs, nvtx_range,
|
||||
split_batch_by_proposal_len)
|
||||
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

    See https://github.com/vllm-project/vllm/pull/2188 and
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Only draft-model proposal is implemented (contributions for more forms are
        welcome!).
    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
        future work.
    * Only lossless rejection sampling is supported. Contributions adding lossy
        verification routines are welcome (e.g. Medusa's typical acceptance).
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The scoring forward pass is done without an MQA kernel, which is
        suboptimal especially as the batch size, proposal length, and sequence
        lengths grow. Contributions to add a MQA scoring are welcome once
        correctness tests pass.
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

    @classmethod
    def create_worker(
        cls,
        scorer_worker: WorkerBase,
        draft_worker_kwargs,
    ) -> "SpecDecodeWorker":
        """Construct a SpecDecodeWorker, selecting an NGramWorker proposer
        when ngram prompt lookup is enabled, else a MultiStepWorker.
        """

        if "ngram_prompt_lookup_max" in draft_worker_kwargs:
            ngram_prompt_lookup_max = (
                draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
            ngram_prompt_lookup_min = (
                draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
        else:
            ngram_prompt_lookup_max = 0

        if ngram_prompt_lookup_max > 0:
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            proposer_worker = MultiStepWorker(**draft_worker_kwargs)

        return SpecDecodeWorker(
            proposer_worker,
            scorer_worker,
            # TODO(cade) disable strict mode for speedup.
            rejection_sampler=RejectionSampler(strict_mode=True),
        )

    def __init__(
        self,
        proposer_worker: WorkerBase,
        scorer_worker: WorkerBase,
        rejection_sampler: RejectionSampler,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            rejection_sampler: A Torch module used to perform modified rejection
                sampling for speculative decoding.
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes.
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker
        self.rejection_sampler = rejection_sampler

        self._metrics = AsyncMetricsCollector(
            rejection_sampler
        ) if metrics_collector is None else metrics_collector

        self.probs_dtype = self.rejection_sampler.probs_dtype
        self.token_id_dtype = self.rejection_sampler.token_id_dtype

        # Lazy initialization; created in init_device().
        self.scorer: SpeculativeScorer

    def init_device(self) -> None:
        """Initialize both scorer and proposer models.
        """
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        self._metrics.init_gpu_tensors(self.rank)
        self.rejection_sampler.init_gpu_tensors(self.rank)
        self.scorer = BatchExpansionTop1Scorer(
            scorer_worker=self.scorer_worker,
            device=self.device,
            vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

    def _configure_model_sampler_for_spec_decode(self):
        """Configure model sampler to emit GPU tensors. This allows spec decode
        to keep data on device without transferring to CPU and serializing,
        which significantly reduces overhead of rejection sampling.

        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" sampling decision be
        done outside of the model/sampler; this way the "last-mile" worker
        object which interfaces with the scheduler can serialize and incur the
        performance hit as necessary. This allows us to run the worker several
        iterations in a row without incurring the "move to CPU and serialize"
        performance penalty.

        Since this requires a large change to vLLM, we defer it to later and
        temporarily accept this broken abstraction boundary.

        NOTE(cade): This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
         ) = True
        self.proposer_worker.set_include_gpu_probs_tensor()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        This is done by profiling the scorer model (which is typically the
        larger of the two). Then the total memory which would be used by the
        scorer cache is divided evenly between the proposer and scorer model KV,
        such that the number of blocks is equal in both KV caches.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        scorer_cache_block_size_bytes = (
            self.scorer_worker.get_cache_block_size_bytes())
        proposer_cache_block_size_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        new_num_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
            num_gpu_blocks)
        return new_num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache engine of the scorer and proposer workers.
        """
        self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                            num_cpu_blocks=num_cpu_blocks)
        self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                              num_cpu_blocks=num_cpu_blocks)

    @torch.inference_mode()
    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.
        """

        assert execute_model_req.seq_group_metadata_list is not None, (
            "speculative decoding "
            "requires non-None seq_group_metadata_list")

        # If no spec tokens, call the proposer and scorer workers normally.
        # Used for prefill.
        if execute_model_req.num_lookahead_slots == 0 or len(
                execute_model_req.seq_group_metadata_list) == 0:
            return self._run_no_spec(execute_model_req)

        return self._run_speculative_decoding_step(execute_model_req)

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Run a prefill step, without any speculation. The input is sent to the
        proposer and scorer model so that the KV cache is consistent between the
        two.
        """
        self.proposer_worker.execute_model(execute_model_req)

        sampler_output = self.scorer_worker.execute_model(execute_model_req)
        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        # Fix: use the actual SamplerOutput field names (sampled_token_probs /
        # sampled_token_ids, as constructed elsewhere in spec decode); the
        # previous assignments to .probs / .sampled_tokens set nonexistent
        # attributes and left the device tensors attached.
        sampler_output.sampled_token_probs = None
        sampler_output.sampled_token_ids = None
        sampler_output.logprobs = None
        return [sampler_output]

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
        sequence, then scores each speculative token using the scoring worker.

        Returns a list of SamplerOutput, each containing a single token per
        sequence.
        """
        # Generate proposals using draft worker.
        proposals = self.proposer_worker.get_spec_proposals(execute_model_req)

        # Score the proposed tokens with the (larger) scorer model.
        proposal_scores = self.scorer.score_proposals(
            execute_model_req,
            proposals,
        )

        # Accept/reject speculative tokens via rejection sampling.
        accepted_token_ids, target_logprobs = self._verify_tokens(
            execute_model_req.seq_group_metadata_list, proposal_scores,
            proposals, execute_model_req.num_lookahead_slots)

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=execute_model_req.num_lookahead_slots)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        Returns a tuple of Tensors, one for the accepted token ids and one for
        the logprobs according to the scoring model.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        _, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        _, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, excluding bonus token.
        proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

        accepted_token_ids = self.rejection_sampler(
            target_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
        )

        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor.
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs

        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        return accepted_token_ids, logprobs

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.
        """
        batch_size, num_steps = accepted_token_ids.shape

        # Organize input tensors by step instead of by sequence.
        target_logprobs_by_step = target_logprobs.transpose(0, 1)
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)

        # Get the logprobs/rank of the accepted tokens.
        (accepted_token_id_ranks_by_step,
         accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
             logprob_tensor=target_logprobs_by_step,
             sampled_token_ids=accepted_token_ids_by_step,
         )

        # Get the top-k logprobs (which may or may not include the logprob of
        # the accepted token).
        (topk_logprobs_by_step,
         topk_indices_by_step) = target_logprobs_by_step.topk(
             k=self.scorer_worker.model_config.max_logprobs,
             dim=-1,
         )

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize all tensors to CPU Python lists.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
        accepted_token_id_ranks_by_step = (
            accepted_token_id_ranks_by_step.tolist())
        accepted_token_id_logprobs_by_step = (
            accepted_token_id_logprobs_by_step.tolist())
        topk_logprobs_by_step = topk_logprobs_by_step.tolist()
        topk_indices_by_step = topk_indices_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        sampler_output_list = []
        for step_index in range(num_steps):
            # A step where every sequence is padding (-1) means no sequence
            # produced a token this step; stop emitting outputs.
            if all(token_id == -1
                   for token_id in accepted_token_ids_by_step[step_index]):
                break

            step_output_token_ids = []
            for sequence_index in range(batch_size):
                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]

                step_output_token_ids.append(
                    create_sequence_group_output(
                        token_id=accepted_token_ids_by_step[step_index]
                        [sequence_index],
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))

            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

        return sampler_output_list

    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        return self.scorer_worker.rank

    @property
    def device(self):
        return self.scorer_worker.device

    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes.

        This function is only used to compose workers within a SpecDecodeWorker.
        We leave composing a SpecDecodeWorker within a SpecDecodeWorker
        undefined for now, although it could be implemented in the future.
        See https://arxiv.org/abs/2308.04623.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocated to the target model, this function calculates how many blocks
    should be given to the draft and target model.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of number of KV/layer, number of heads, and hidden
    dimension size.

    Since the target and draft models allocate the same number of blocks, we
    simply calculate the number of blocks where if allocated by both models,
    the total memory usage from KV cache is no larger than the number of
    blocks allocatable by the target model alone.
    """
    combined_block_size_bytes = (proposer_cache_block_size_bytes +
                                 scorer_cache_block_size_bytes)
    # Truncate toward zero so the combined usage never exceeds the budget.
    return int(total_num_gpu_blocks * scorer_cache_block_size_bytes /
               combined_block_size_bytes)
|
||||
200
vllm/spec_decode/top1_proposer.py
Normal file
200
vllm/spec_decode/top1_proposer.py
Normal file
@@ -0,0 +1,200 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
from vllm.spec_decode.util import sampler_output_to_torch
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
|
||||
class Top1Proposer(SpeculativeProposer):
|
||||
"""Helper class which separates out sequences which would exceed the max
|
||||
model length when speculated upon.
|
||||
|
||||
This allows combinations of models such as JackFram/llama-68m draft with
|
||||
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
|
||||
2048 while Llama2-13b has max_position_embeddings of 4096.
|
||||
|
||||
We treat the sequences which exceed the proposal draft model length as
|
||||
"non-spec sequences". Essentially they skip the draft model and go through
|
||||
normal decoding in the target model.
|
||||
|
||||
Currently, only proposal_lens of 0 and k are supported, where k is a global
|
||||
batch proposal length. In the future vLLM should support per-sequence
|
||||
proposal lengths.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    worker: WorkerBase,
    device: str,
    vocab_size: int,
    max_proposal_len: Optional[int] = None,
):
    """Store the draft worker and proposal configuration.

    Args:
        worker: the draft worker used to generate proposals.
        device: device the proposals live on.
        vocab_size: vocabulary size shared with the target model.
        max_proposal_len: if set, sequences whose length plus proposal
            would reach this limit are excluded from speculation.
    """
    self._worker = worker
    self._device = device
    self._vocab_size = vocab_size
    self.max_proposal_len = max_proposal_len
||||
|
||||
def get_proposals(
    self,
    execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
    """Get speculative proposals given the input batch.

    Sequences which would exceed the max model length are skipped during
    speculation.
    """
    proposal_len = execute_model_req.num_lookahead_slots
    seq_group_metadata_list = execute_model_req.seq_group_metadata_list

    # Partition the batch: sequences that fit within the draft model's
    # length limit get proposal_len; the rest get zero and skip the draft.
    (proposal_lens, nonzero_proposal_len_seqs,
     nonzero_proposal_len_indices) = self._split_by_max_model_len(
         seq_group_metadata_list, proposal_len)

    maybe_sampler_output = None
    transposed = False
    if nonzero_proposal_len_seqs:
        # Run the draft worker on the speculative subset only. Depending
        # on the worker, the returned sampler output may already be laid
        # out [batch, proposal_len] (transposed=True) or per-step
        # (transposed=False).
        maybe_sampler_output, transposed = self._worker.sampler_output(
            execute_model_req=ExecuteModelRequest(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                num_lookahead_slots=proposal_len,
            ),
            sample_len=proposal_len,
        )

    # Re-assemble a full-batch view that also covers the skipped
    # (non-speculative) sequences.
    proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
        batch_size=len(seq_group_metadata_list),
        proposal_len=proposal_len,
        maybe_sampler_output=maybe_sampler_output,
        proposal_lens=proposal_lens,
        nonzero_proposal_len_indices=nonzero_proposal_len_indices,
        sampler_transposed=transposed,
    )

    return SpeculativeProposals(
        proposal_token_ids=proposal_tokens,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens,
    )
|
||||
|
||||
def _split_by_max_model_len(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_len: int,
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
    """Partition the batch by whether speculating would exceed the max
    model length.

    Returns a tuple of:
        - per-sequence proposal lengths (0 for skipped sequences),
        - the sequence groups that will be speculated,
        - the batch indices of those sequence groups.
    """

    per_seq_proposal_lens: List[int] = []
    speculated_seq_groups: List[SequenceGroupMetadata] = []
    speculated_indices: List[int] = []

    for index, seq_group_metadata in enumerate(seq_group_metadata_list):
        # Measure the first sequence in the group; presumably all
        # sequences in a group share a length here — TODO confirm.
        first_seq_data = next(iter(seq_group_metadata.seq_data.values()))
        current_len = first_seq_data.get_len()

        # Currently only proposal lens of 0 or the global batch proposal
        # len are supported. When max_proposal_len is set, a sequence is
        # only speculated if its length after speculation stays below
        # that quota; otherwise its proposal len is forced to 0.
        within_quota = (self.max_proposal_len is None
                        or current_len + proposal_len < self.max_proposal_len)
        if within_quota:
            per_seq_proposal_lens.append(proposal_len)
            speculated_seq_groups.append(seq_group_metadata)
            speculated_indices.append(index)
        else:
            per_seq_proposal_lens.append(0)

    return (
        per_seq_proposal_lens,
        speculated_seq_groups,
        speculated_indices,
    )
|
||||
|
||||
def _merge_outputs(
    self,
    batch_size: int,
    proposal_len: int,
    # String forward reference (PEP 484): `Optional` is not imported at
    # the top of this file, so an evaluated annotation would raise
    # NameError at class-definition time.
    maybe_sampler_output: "Optional[SamplerOutput]",
    proposal_lens: List[int],
    nonzero_proposal_len_indices: List[int],
    sampler_transposed: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """After speculations are produced, merge the speculation results with
    the skipped sequences.

    Args:
        batch_size: Total number of sequence groups in the original
            batch, including the skipped (non-speculated) ones.
        proposal_len: Number of speculative tokens per speculated
            sequence.
        maybe_sampler_output: Draft-worker output, or None when no
            sequence could be speculated.
        proposal_lens: Per-sequence proposal lengths (0 for skipped).
        nonzero_proposal_len_indices: Batch indices of the sequences that
            were actually speculated.
        sampler_transposed: Layout indicator passed through to
            sampler_output_to_torch.

    Returns:
        Tuple of (proposal token ids, proposal probs, proposal lens)
        tensors covering the full batch. Skipped sequences receive -1
        token ids, all-zero probs, and proposal length 0.
    """
    if maybe_sampler_output is None:
        # No sequence could be speculated; return empty proposals for the
        # whole batch: -1 token ids, zero probabilities, zero lens.
        proposal_tokens = torch.full(
            size=(
                batch_size,
                proposal_len,
            ),
            fill_value=-1,
            dtype=torch.long,
            device=self._device,
        )
        proposal_probs = torch.zeros(
            batch_size,
            proposal_len,
            self._vocab_size,
            dtype=torch.float32,
            device=self._device,
        )
        proposal_lens_tensor = torch.zeros(len(proposal_lens),
                                           dtype=torch.long,
                                           device=self._device)
        return proposal_tokens, proposal_probs, proposal_lens_tensor

    sampler_output = maybe_sampler_output
    proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
        sampler_output, sampler_transposed)

    # Scatter the speculated rows back into full-batch tensors so each
    # sequence has a proposal; skipped rows keep the -1 / zero filler,
    # e.g. [-1, -1, -1].
    entire_proposal_tokens = torch.full(
        size=(batch_size, *proposal_tokens.shape[1:]),
        fill_value=-1,
        dtype=torch.long,
        device=self._device,
    )
    entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
    entire_proposal_probs = torch.zeros(
        batch_size,
        *proposal_probs.shape[1:],
        dtype=torch.float32,
        device=self._device,
    )
    entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

    proposal_tokens, proposal_probs = (
        entire_proposal_tokens,
        entire_proposal_probs,
    )

    # Speculated sequences all share the global proposal_len; the rest
    # stay at 0.
    proposal_lens_tensor = torch.zeros(batch_size,
                                       dtype=torch.long,
                                       device=self._device)
    proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len

    return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
228
vllm/spec_decode/util.py
Normal file
228
vllm/spec_decode/util.py
Normal file
@@ -0,0 +1,228 @@
|
||||
from contextlib import contextmanager
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
|
||||
SequenceGroupOutput, SequenceOutput)
|
||||
|
||||
SeqId = int
|
||||
|
||||
|
||||
def get_all_seq_ids(
    seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]:
    """Collect every sequence id across all sequence groups.

    Group order is preserved; within a group, ids come out in the
    insertion order of its seq_data mapping.
    """
    all_seq_ids: List[SeqId] = []
    for seq_group_metadata in seq_group_metadata_list:
        all_seq_ids.extend(seq_group_metadata.seq_data.keys())
    return all_seq_ids
|
||||
|
||||
|
||||
def get_all_num_logprobs(
    seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Return the requested number of logprobs for each sequence group.

    Groups whose sampling params request no logprobs (None) contribute 0.
    """

    return [
        0 if metadata.sampling_params.logprobs is None else
        metadata.sampling_params.logprobs
        for metadata in seq_group_metadata_list
    ]
|
||||
|
||||
|
||||
def get_sampled_token_logprobs(
    # shape [num_steps, batch_size, vocab_size]
    logprob_tensor: torch.Tensor,
    sampled_token_ids: torch.Tensor,  # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the (ranks, logprobs) of the sampled tokens.

    A token's rank is 1-based: rank 1 means no other vocab entry had a
    strictly larger logprob (ties share the same rank count).
    """
    # Unpacking doubles as a rank-3 sanity check on the input tensor.
    num_steps, batch_size, vocab_size = logprob_tensor.shape

    # Pull out each sampled token's logprob along the vocab dimension.
    selected_logprobs = logprob_tensor.gather(
        -1, sampled_token_ids.unsqueeze(-1)).squeeze(-1)

    # Rank = how many vocab entries have a logprob at least as large as
    # the sampled token's own.
    sampled_token_ids_ranks = (logprob_tensor
                               >= selected_logprobs.unsqueeze(-1)).sum(-1)

    return sampled_token_ids_ranks, selected_logprobs
|
||||
|
||||
|
||||
def create_sequence_group_output(
    token_id: int,
    token_id_logprob_rank: int,
    token_id_logprob: float,
    seq_id: SeqId,
    topk_token_ids: List[int],
    topk_logprobs: List[float],
) -> SequenceGroupOutput:
    """Build a SequenceGroupOutput for one sampled token.

    Args:
        token_id (int): The sampled token for the sequence.
        token_id_logprob_rank (int): The logprob rank of the sampled token.
        token_id_logprob (float): The logprob value of the sampled token.
        seq_id (int): The sequence id.
        topk_token_ids (List[int]): The list of top-k token ids.
        topk_logprobs (List[float]): The list of top-k logprobs.
    """
    # vLLM logprobs always include the sampled token; on top of that the
    # user may request top-k logprobs (k varies per user up to
    # max_logprobs). Top-k entries get 1-based ranks and, via last-wins
    # dict assignment, may overwrite the sampled-token entry.
    logprobs: Dict[int, Logprob] = {
        token_id: Logprob(
            logprob=token_id_logprob,
            rank=token_id_logprob_rank,
        ),
    }
    for topk_index in range(len(topk_token_ids)):
        logprobs[topk_token_ids[topk_index]] = Logprob(
            logprob=topk_logprobs[topk_index],
            rank=topk_index + 1,
        )

    sample = SequenceOutput(parent_seq_id=seq_id,
                            output_token=token_id,
                            logprobs=logprobs)
    return SequenceGroupOutput(
        samples=[sample],
        # TODO add prompt logprobs support.
        prompt_logprobs=None,
    )
|
||||
|
||||
|
||||
def split_batch_by_proposal_len(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_lens: List[int], select_proposal_len_zero: bool
) -> Tuple[List[SequenceGroupMetadata], List[int]]:
    """Select the sequence groups whose proposal len is zero (or nonzero).

    Returns the selected groups and their original batch indices. This is
    a stopgap until vLLM supports per-sequence proposal lens in a batch.
    """

    if select_proposal_len_zero:
        keep = lambda plen: plen == 0
    else:
        keep = lambda plen: plen != 0

    selected_groups: List[SequenceGroupMetadata] = []
    selected_indices: List[int] = []
    for index, (seq_group, plen) in enumerate(
            zip(seq_group_metadata_list, proposal_lens)):
        if keep(plen):
            selected_groups.append(seq_group)
            selected_indices.append(index)

    return selected_groups, selected_indices
|
||||
|
||||
|
||||
def sampler_output_to_torch(
    sampler_output_list: List[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Stack a list of SamplerOutput into (ids, probs, logprobs) tensors.

    sampler_transposed indicates whether each stacked tensor should be
    transposed so its leading dimension is the batch rather than the
    number of sampler outputs.

    Returns:
        sampled_token_ids: torch.Tensor
            shape: [batch_size, len(sampler_output_list)]

        sampled_token_probs: torch.Tensor
            shape: [batch_size, len(sampler_output_list), vocab_size]

        sampled_token_logprobs: torch.Tensor
            same shape as sampled_token_probs
    """

    def _stack(tensors: List[torch.Tensor]) -> torch.Tensor:
        # Stack along a new leading num-outputs dim, then optionally swap
        # it with the batch dim.
        stacked = torch.stack(tensors, dim=0)
        return stacked.transpose(0, 1) if sampler_transposed else stacked

    sampled_token_probs = _stack([
        sampler_output.sampled_token_probs
        for sampler_output in sampler_output_list
    ])

    sampled_token_logprobs = _stack([
        sampler_output.logprobs for sampler_output in sampler_output_list
    ])

    sampled_token_ids = _stack([
        sampler_output.sampled_token_ids.flatten()
        for sampler_output in sampler_output_list
    ])

    return sampled_token_ids, sampled_token_probs, sampled_token_logprobs
|
||||
|
||||
|
||||
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
                              vocab_size: int, device: str) -> None:
    """Fill in dummy device tensors on a SamplerOutput that lacks them.

    Temporary scaffolding; this will be removed in PR 7/9.
    https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
    """
    probs = sampler_output.sampled_token_probs
    token_ids = sampler_output.sampled_token_ids
    # The two tensors must be either both present (usually in unit tests)
    # or both absent.
    assert (probs is None) == (token_ids is None)
    if probs is not None:
        # Already populated; nothing to mock.
        return

    # Softmax over random logits guarantees a valid probability
    # distribution per row.
    sampler_output.sampled_token_probs = torch.nn.functional.softmax(
        torch.rand(batch_size, vocab_size, dtype=torch.float32,
                   device=device),
        dim=-1)

    sampler_output.sampled_token_ids = torch.randint(low=10,
                                                     high=100,
                                                     size=(batch_size, ),
                                                     dtype=torch.long,
                                                     device=device)
|
||||
|
||||
|
||||
@contextmanager
def nvtx_range(msg, *args, **kwargs):
    """Push an NVTX range on entry to the scope and pop it on exit.

    Usable as a context manager or decorator. Any extra positional or
    keyword arguments are applied to the message via msg.format().

    If running with cuda graphs, you must enable nsys cuda graph
    profiling.

    Arguments:
        msg (string): message to associate with the range
    """
    formatted_msg = msg.format(*args, **kwargs)
    torch.cuda.nvtx.range_push(formatted_msg)
    try:
        yield
    finally:
        # Always pop, even if the wrapped scope raises.
        torch.cuda.nvtx.range_pop()
|
||||
Reference in New Issue
Block a user