init src 0.9.2
vllm/zero_overhead/spec_decode/batch_expansion.py (new file, 141 lines)
@@ -0,0 +1,141 @@
from array import array
import numpy as np
from itertools import chain, count
from typing import Iterator, List, Optional, Tuple

import torch

from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
                           ExecuteModelRequest, SequenceData,
                           SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.utils import (get_proposal_lens_list,
                                      record_proposal_token_ids)

SeqId = int
TargetSeqId = int
TokenId = int

DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()


class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences.
        The target sequences have the unique continuations to be scored and
        a unique sequence ID that is different from all input sequence ids.

        If a speculative sequence length would exceed the max model length,
        no speculation is produced for that sequence.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.
        Returns:
            SpeculativeScores: The scores of each speculative token, along
                with which sequences were ignored during scoring.
        """
        proposal_lens_list = get_proposal_lens_list()
        record_proposal_token_ids(proposals.proposal_token_ids)
        # Placeholder tokens: the real proposal token ids stay on the device
        # and were recorded above.
        proposal_token_ids_list = np.zeros(
            proposals.proposal_token_ids.shape, dtype=int).tolist()

        # Filter the list to ignore invalid proposals.
        proposal_token_ids_list_without_skips = [
            proposals for proposals in proposal_token_ids_list
            if VLLM_INVALID_TOKEN_ID not in proposals
        ]

        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))
        assert len(target_sampler_output) == 1, "expected single-step output"
        target_sampler_output = target_sampler_output[0]
        if not non_spec_indices:
            # All sequence groups in the batch have spec decoding enabled.
            return self._contract_batch_all_spec(
                target_sampler_output=target_sampler_output,
                proposals=proposals,
            )
        else:
            # The batch has a mix of spec-decode-enabled and disabled
            # sequence groups.
            return self._contract_batch(
                execute_model_req.seq_group_metadata_list,
                target_sampler_output=target_sampler_output,
                proposals=proposals,
                num_scoring_tokens=num_scoring_tokens,
                non_spec_indices=non_spec_indices,
                spec_indices=spec_indices,
                k=execute_model_req.num_lookahead_slots,
            )

    def _contract_non_speculative(
            self, scores: SpeculativeScores,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
            has_prompt_log: bool) -> SpeculativeScores:
        """
        Augment the input `scores` with the outputs of non-speculative
        requests. These include decode requests with speculation turned off,
        as well as prefill requests when `enable_chunked_prefill` is set.
        For the latter, prefills are further separated into terminal and
        non-terminal chunks (from which no token is sampled).
        """
        if not non_spec_indices:
            return scores

        if has_prompt_log:
            # When prompt_logprobs is enabled, prefills yield the output
            # token (and its prob) in the last entry (prompt|out):
            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
            # With chunked prefill, non-terminal chunks have -1 on each
            # position: they're still picked, but they're discarded later.
            seq_meta = seq_group_metadata_list
            nospec_sizes = torch.tensor([
                seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
                for i in non_spec_indices
            ])
            nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
        else:
            # In this case only sampled tokens are returned; select them all.
            nospec_sampled_token_idxs = list(
                range(len(non_spec_outputs.token_ids)))

        nospec_sampled_token_idxs = async_tensor_h2d(
            nospec_sampled_token_idxs, torch.int32, self._device, True)
        non_spec_indices = async_tensor_h2d(non_spec_indices, torch.int32,
                                            self._device, True)

        scores.token_ids[non_spec_indices, :1] = \
            non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
        scores.probs[non_spec_indices, :1, :] = \
            non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
        scores.logprobs[non_spec_indices, :1, :] = \
            non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
        if scores.hidden_states is not None:
            assert non_spec_outputs.hidden_states is not None
            scores.hidden_states[non_spec_indices, :1, :] = \
                non_spec_outputs.hidden_states[
                    nospec_sampled_token_idxs].unsqueeze(1)
        return scores
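
A minimal, self-contained sketch of the k+1 expansion that `score_proposals` describes — `expand_batch` and its toy inputs are invented for illustration and are not vLLM APIs. Each sequence with k proposed tokens becomes k+1 target rows, one per proposal prefix, so a single scorer forward pass yields a probability for every proposed token plus one bonus token:

```python
import torch

def expand_batch(context, proposals: torch.Tensor):
    """context: per-sequence token ids; proposals: [B, k] proposed ids."""
    batch_size, k = proposals.shape
    target_rows = []
    for b in range(batch_size):
        for prefix_len in range(k + 1):  # k+1 target rows per sequence
            target_rows.append(context[b] +
                               proposals[b, :prefix_len].tolist())
    return target_rows

context = [[11, 12], [21, 22]]                    # B=2 input sequences
proposals = torch.tensor([[1, 2, 3], [4, 5, 6]])  # k=3 proposals each
rows = expand_batch(context, proposals)
assert len(rows) == 2 * (3 + 1)        # B * (k+1) scoring rows
assert rows[3] == [11, 12, 1, 2, 3]    # row with the full proposal appended
```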
vllm/zero_overhead/spec_decode/muti_step_worker.py (new file, 137 lines)
@@ -0,0 +1,137 @@
import copy
import weakref
from typing import Dict, List, Set, Tuple

import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                           SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.spec_decode.top1_proproser import (
    ZeroOverheadTop1Proposer)
from vllm.zero_overhead.utils import SpecStepKind, set_spec_step

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.worker.worker_base import DelegateWorkerBase


class ZeroOverheadMultiStepWorker(MultiStepWorker):

    def init_device(self) -> None:
        self.worker.init_device()
        self._proposer = ZeroOverheadTop1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass sample_len times. Returns the list of
        sampler outputs, one per model forward pass, along with an indicator
        of whether the torch tensors in the sampler output need to be
        transposed by the later sampler_output_to_torch logic.

        For the multi-step worker, this indicator shall be True.
        """
        self._raise_if_unsupported(execute_model_req)
        # Expand the batch for sequences with a bonus token.
        # Perform a forward pass on the expanded batch and filter the
        # response to retain only the original sequences' responses.
        expanded_request, indices_of_seq_with_bonus_tokens = \
            self._expand_execute_model_request(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)
        # Run the model sample_len times.
        model_outputs: List[SamplerOutput] = []
        if current_platform.is_cuda_alike() and isinstance(
                self.model_runner, TP1DraftModelRunner
        ) and self.model_runner.supports_gpu_multi_step(expanded_request):
            # Here we run the draft_model_runner with multi-step prepare
            # on the GPU directly.
            expanded_request.num_steps = sample_len
            self.model_runner.set_indices_of_seq_with_bonus_tokens(
                indices_of_seq_with_bonus_tokens)
            model_outputs = self.execute_model(
                execute_model_req=expanded_request)
        else:
            # Here we run multi-step directly, with every step prepared
            # on the CPU.
            # TODO: Remove this branch once DraftModelRunner supports TP>1
            # and the other restrictions that are part of DraftModelRunner's
            # supports_gpu_multi_step(..).

            set_spec_step(SpecStepKind.FIRST_PROPOSAL)
            for _ in range(sample_len):
                model_output: List[SamplerOutput] = self.worker.execute_model(
                    execute_model_req=expanded_request)
                assert (len(model_output) == 1
                        ), "composing multistep workers not supported"
                model_output = model_output[0]
                set_spec_step(SpecStepKind.OTHER_PROPOSAL)
                self._append_new_tokens(
                    model_output, expanded_request.seq_group_metadata_list,
                    indices_of_seq_with_bonus_tokens)
                model_outputs.append(model_output)
            set_spec_step(SpecStepKind.SCORE_DECODE)

        filtered_model_outputs = self._filter_model_output_zero_overhead(
            model_outputs, indices_of_seq_with_bonus_tokens)

        return filtered_model_outputs, True
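
A toy sketch of the CPU-prepared multi-step loop above — `draft_step` stands in for one `worker.execute_model` call and is not a vLLM API. Each iteration samples one token per sequence and appends it to the running context before the next step, which is why `sample_len` sequential forward passes are needed:

```python
import torch

def draft_step(ctx: torch.Tensor) -> torch.Tensor:
    # Stand-in for one draft-model forward pass: "sample" the next token.
    return ctx[:, -1:] + 1

def propose(ctx: torch.Tensor, sample_len: int) -> torch.Tensor:
    steps = []
    for _ in range(sample_len):      # run the model sample_len times
        next_tok = draft_step(ctx)   # [B, 1], one sampled token per sequence
        ctx = torch.cat([ctx, next_tok], dim=1)  # feed back for next step
        steps.append(next_tok)
    return torch.cat(steps, dim=1)   # [B, sample_len] proposed tokens

ctx = torch.tensor([[3], [7]])
assert propose(ctx, sample_len=3).tolist() == [[4, 5, 6], [8, 9, 10]]
```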
    def _filter_model_output_zero_overhead(
            self, expanded_batch_outputs: List[SamplerOutput],
            output_indices_to_retain: List[int]) -> List[SamplerOutput]:
        """
        Filters the model output to include only the specified sequence
        outputs. This method contracts the expanded batch output from the
        model to retain the outputs of only those sequences indicated by the
        provided indices.

        Args:
            expanded_batch_outputs (List[SamplerOutput]): The expanded output
                batches from the model.
            output_indices_to_retain (List[int]): Indices of the model
                outputs to retain.

        Returns:
            List[SamplerOutput]: A list containing the filtered model
                outputs for the specified indices.
        """
        indices_of_seq_with_bonus_tokens = async_tensor_h2d(
            output_indices_to_retain, torch.int32, self.device, True)

        return [
            SamplerOutput(
                outputs=[
                    expanded_batch_output.outputs[i]
                    for i in output_indices_to_retain
                ] if len(expanded_batch_output.outputs) > 0 else [],
                sampled_token_probs=(
                    expanded_batch_output.
                    sampled_token_probs[indices_of_seq_with_bonus_tokens]
                    if expanded_batch_output.sampled_token_probs is not None
                    else None),
                logprobs=(
                    expanded_batch_output.
                    logprobs[indices_of_seq_with_bonus_tokens]
                    if expanded_batch_output.logprobs is not None else None),
                sampled_token_ids=(
                    expanded_batch_output.
                    sampled_token_ids[indices_of_seq_with_bonus_tokens]
                    if expanded_batch_output.sampled_token_ids is not None
                    else None))
            for expanded_batch_output in expanded_batch_outputs
        ]
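
The contraction in `_filter_model_output_zero_overhead` is plain row selection with a device-side index tensor (which `async_tensor_h2d` produces without a blocking copy). A minimal sketch of the indexing involved, with made-up values:

```python
import torch

expanded_probs = torch.arange(12.0).view(4, 3)  # 4 rows after expansion
indices_to_retain = torch.tensor([0, 2, 3])     # rows of original sequences
filtered = expanded_probs[indices_to_retain]    # contract back to 3 rows
assert filtered.shape == (3, 3)
assert filtered[1].tolist() == [6.0, 7.0, 8.0]  # row 2 of the expanded batch
```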
vllm/zero_overhead/spec_decode/spec_decode_worker.py (new file, 565 lines)
@@ -0,0 +1,565 @@
import os
import copy
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import torch
import torch.nn as nn

from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import (broadcast_tensor_dict,
                                               get_tp_group,
                                               tensor_model_parallel_gather)
from vllm.distributed.parallel_state import model_parallel_is_initialized
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, ExecuteModelRequest,
                           HiddenStates, SequenceGroupMetadata,
                           get_all_seq_ids_and_request_ids, Logits)
from vllm.spec_decode.batch_expansion import (BatchExpansionTop1Scorer,
                                              BatchExpansionTreeStyleScorer)
from vllm.spec_decode.spec_decode_worker import (
    SpecDecodeWorker, prepare_prefill_hidden_states)
from vllm.zero_overhead.spec_decode.batch_expansion import (
    ZeroOverheadBatchExpansionTop1Scorer)
from vllm.zero_overhead.utils import (SpecStepKind, record_accepted_token_ids,
                                      set_spec_step)

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import (NonLLMProposerWorkerBase,
                                                   ProposerWorkerBase)
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                   create_sequence_group_output,
                                   get_all_num_logprobs,
                                   get_sampled_token_logprobs, nvtx_range,
                                   split_batch_by_proposal_len)
from vllm.utils import async_tensor_h2d, resolve_obj_by_qualname
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention

logger = init_logger(__name__)


class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):

    def init_device(self) -> None:
        """Initialize both the scorer and proposer models."""
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        if self._enable_lm_head_weight_load:
            # NOTE(Shangming): gather the lm_head weight when TP is enabled.
            target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather(
                self.scorer_worker.model_runner.model_runner.model.lm_head.
                weight.data,
                dim=0,
            )

            self.proposer_worker.maybe_load_lm_head_weight(
                target_lm_head_weight)

        self._metrics.init_tensors(self.rank, device_type=self.device)
        if model_parallel_is_initialized():
            self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
                                                  device_type=self.device)
        else:
            self.spec_decode_sampler.init_tensors(self.rank,
                                                  device_type=self.device)

        scorer_cls: Type[SpeculativeScorer]
        if self.disable_mqa_scorer:
            scorer_cls = ZeroOverheadBatchExpansionTop1Scorer
            logger.info("[Speculative Decoding] Use batch "
                        "expansion for scoring proposals.")
        else:
            scorer_cls = MQAScorer
            logger.info(
                "[Speculative Decoding] Use MQA scorer for scoring proposals.")

        if not self.tree_decoding:
            self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
                                     device=self.device,
                                     vocab_size=self._vocab_size)
        else:
            self.scorer = BatchExpansionTreeStyleScorer(
                scorer_worker=self.scorer_worker,
                device=self.device,
                vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a single generation step without any speculation. The input is
        sent to both the proposer and the scorer model so that the KV cache
        stays consistent between the two. When skip_proposer is True, the
        proposer model is not called, meaning that the proposer's KV cache is
        not updated for these requests, so they cannot enable spec decode for
        the rest of the decoding.
        """
        if self.tree_decoding and self.kvcache_slot_to_be_moved is not None:
            execute_model_req.kvcache_slot_to_be_moved = \
                self.kvcache_slot_to_be_moved
            self.kvcache_slot_to_be_moved = None

        set_spec_step(SpecStepKind.PREFILL)
        sampler_output = self.scorer_worker.execute_model(execute_model_req)

        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Store hidden states from target model execution, BxD.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            # Only decodes and prefill terminal chunks need a hidden state.
            seq_group_meta_with_hidden = [
                sg for sg in execute_model_req.seq_group_metadata_list
                if sg.do_sample
            ]
            if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
                # Drop hidden_states with no prediction (e.g. non-terminal
                # chunks).
                hidden_states = hidden_states[
                    torch.where(sampler_output.sampled_token_ids -
                                VLLM_INVALID_TOKEN_ID)[0]]
            # Store the hidden states even when the proposer is skipped.
            if self.previous_hidden_states is None and len(
                    seq_group_meta_with_hidden):
                self.previous_hidden_states = HiddenStates(
                    hidden_states, seq_group_meta_with_hidden)
            elif self.previous_hidden_states and len(
                    seq_group_meta_with_hidden):
                self.previous_hidden_states.update(hidden_states,
                                                   seq_group_meta_with_hidden)

        # Store logits from target model execution.
        if self.tree_decoding:
            logits = sampler_output.logits
            if logits is not None:
                if self.previous_logits is None:
                    self.previous_logits = Logits(
                        logits, execute_model_req.seq_group_metadata_list)
                else:
                    self.previous_logits.update(
                        logits, execute_model_req.seq_group_metadata_list)

        if not skip_proposer:
            # We prepare the prefill hidden states here so that there is no
            # additional complexity in the worker for the spec_decode vs
            # non_spec_decode flow, and execute_model doesn't need additional
            # modifications.
            execute_model_req.previous_hidden_states = \
                prepare_prefill_hidden_states(
                    sampler_output.prefill_hidden_states)
            for i in range(self._num_spec_prefill_steps):
                execute_model_req.spec_step_idx = i
                self.proposer_worker.execute_model(execute_model_req)

        sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
            execute_model_req=execute_model_req, sampler_output=sampler_output)
                                    if self._disable_logprobs else
                                    [sampler_output])

        # Clear device tensors from the sampler output. This reduces
        # communication overhead when the engine runs in a different process
        # than the workers.
        sampler_output.sampled_token_probs = None
        sampler_output.sampled_token_ids = None
        sampler_output.logprobs = None
        return sampler_output_to_return

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[int]]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer
        models.

        Returns the accepted token ids and the logprobs according to the
        scoring model, along with the per-sequence selected tree indices and
        accept lengths (populated only for tree decoding).
        """
        proposal_lens_list = proposals.proposal_lens

        # vLLM currently only supports proposal lens equal to zero or the
        # batch proposal len. This adds some complexity (splitting the batch
        # into spec and non-spec sequences) and should be removed in the
        # future. It can be done by supporting per-sequence proposal lens.
        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of the target model, including bonus tokens.
        if non_spec_indices:
            proposal_verifier_probs = proposal_scores.probs[spec_indices]
        else:
            proposal_verifier_probs = proposal_scores.probs

        if self.tree_decoding:
            retrieve_indices = proposals.retrieve_indices
            proposal_verifier_probs = \
                proposal_verifier_probs[:, retrieve_indices]

        # Get non-speculative sampled tokens from the target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from the target model.
        bonus_token_ids = proposal_scores.token_ids[:, -1:]
        if non_spec_indices:
            bonus_token_ids = bonus_token_ids[spec_indices, :]

        # Get probabilities according to the proposal method.
        proposal_probs = proposals.proposal_probs
        if proposal_probs is not None and non_spec_indices:
            proposal_probs = proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids
        if non_spec_indices:
            proposal_token_ids = proposal_token_ids[spec_indices]

        # Get tree buffers.
        cart_candidates = proposals.cart_candidates
        if cart_candidates is not None and non_spec_indices:
            cart_candidates = cart_candidates[spec_indices]

        # Sampler arguments.
        sampler_extra_kwargs: Dict[str, Any] = {}
        if self.generators and isinstance(self.spec_decode_sampler,
                                          SpecDecodeStochasticBaseSampler):
            sampler_extra_kwargs["seeded_seqs"] = {
                idx: self.generators[sgm.request_id]
                for idx, sgm in enumerate(seq_group_metadata_list)
                if sgm.sampling_params.seed is not None
            }

        if isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler):
            sampler_extra_kwargs["cart_candidates"] = cart_candidates
            sampler_extra_kwargs["best_candidates"] = []
            sampler_extra_kwargs["accept_lengths"] = []

        first_step_flags = []
        for i, sgm in enumerate(seq_group_metadata_list):
            seq = next(iter(sgm.seq_data.values()))
            first_step_flags.append(bool(seq.get_first_step_flag()))

        sampler_extra_kwargs["first_step_flags"] = first_step_flags

        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )
        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor.
        if not self.tree_decoding:
            non_spec_token_ids = non_spec_token_ids.expand(
                -1, max_proposal_len + 1).clone()
        else:
            non_spec_token_ids = non_spec_token_ids.expand(
                -1, max_proposal_len).clone()

        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs
        # Rearrange so that results are in the order of the original seq
        # group metadata.
        original_indices = async_tensor_h2d(original_indices, torch.int32,
                                            self.device, True)
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        # B x K+1 x D
        hidden_states = proposal_scores.hidden_states

        select_indices = None
        accept_lengths = None

        select_indices_list = []

        if cart_candidates is None:
            if hidden_states is not None:
                # Only get terminal hidden states for the next step.
                terminal_metadata = [
                    sg for sg in seq_group_metadata_list if sg.do_sample
                ]
                # Contract hidden states based on accepted tokens.
                hs_size = hidden_states.shape[-1]
                accepted_index = accepted_token_ids + 1  # Convert -1 to 0
                accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
                # Drop the hidden states of non-terminal prefill chunks.
                hidden_states = hidden_states[accepted_index !=
                                              VLLM_INVALID_TOKEN_ID]
                accepted_index = accepted_index[accepted_index !=
                                                VLLM_INVALID_TOKEN_ID]
                assert len(accepted_index) == hidden_states.shape[0] == len(
                    terminal_metadata)
                index = accepted_index[:, None, None].expand(
                    -1, 1, hs_size)  # b x 1 x d
                second_last_token_hidden_states = hidden_states[:, -2]  # b x d
                hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d

                # Store hidden states from the target model for the
                # subsequent decode step.
                self.previous_hidden_states = HiddenStates(
                    hidden_states, terminal_metadata,
                    second_last_token_hidden_states)
        else:
            retrieve_indices = proposals.retrieve_indices

            batch_size = len(seq_group_metadata_list)

            best_candidates = sampler_extra_kwargs["best_candidates"]
            accept_lengths = sampler_extra_kwargs["accept_lengths"]

            # Contract hidden states based on accepted tokens.
            hs_size = hidden_states.shape[-1]
            hidden_states = hidden_states.view(batch_size, -1, hs_size)

            # Store logits from the target model for the subsequent proposal.
            logits = proposal_scores.logits
            logits = logits.view(batch_size, -1, logits.shape[-1])
            # [batch_size, retrieve_size, max_depth, vocab_size]
            logits = logits[:, retrieve_indices]

            previous_logits_list = []
            previous_hidden_state_list = []

            retrieve_indices = retrieve_indices.cpu()

            for i in range(batch_size):
                logit = logits[i, best_candidates[i],
                               accept_lengths[i]].unsqueeze(0)
                previous_logits_list.append(logit)
                select_indices = retrieve_indices[
                    best_candidates[i], :accept_lengths[i] + 1]
                hidden_state = hidden_states[i,
                                             select_indices[-1]].unsqueeze(0)
                select_indices_list.append(select_indices)
                previous_hidden_state_list.append(hidden_state)

            logits = torch.cat(previous_logits_list, dim=0)
            self.previous_logits = Logits(logits, seq_group_metadata_list)

            hidden_states = torch.cat(previous_hidden_state_list,
                                      dim=0)  # [batch_size, hs_size]
            self.previous_hidden_states = HiddenStates(
                hidden_states, seq_group_metadata_list)

        return (accepted_token_ids, logprobs, select_indices_list,
                accept_lengths)
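
At the end of `_verify_tokens`, results computed in [spec rows | non-spec rows] order are scattered back so that row i again corresponds to sequence i of the original batch. A small sketch of that reordering, with made-up values (the real code does it in place via `accepted_token_ids[original_indices] = accepted_token_ids.clone()`):

```python
import torch

spec_indices, non_spec_indices = [0, 2], [1, 3]
results = torch.tensor([[10], [30],   # spec results (sequences 0 and 2)
                        [20], [40]])  # non-spec results (sequences 1 and 3)
original_indices = torch.tensor(spec_indices + non_spec_indices)
reordered = torch.empty_like(results)
reordered[original_indices] = results  # scatter rows to original slots
assert reordered.tolist() == [[10], [20], [30], [40]]
```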
    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        prompt_logprobs: Optional[
            torch.Tensor],  # shape: [nprompt_tokens, vocab_size]
        k: int,
        stage_times: Tuple[float, float, float],
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.
        """
        batch_size, num_steps = accepted_token_ids.shape
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
        if self._disable_logprobs:
            # We are skipping the logprobs. Hence don't serialize the
            # logprobs-related tensors from the GPU. Instead create
            # empty/dummy lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step,
             topk_logprobs_by_step, topk_indices_by_step) =\
                self._create_dummy_logprob_lists(
                    batch_size, num_steps,
                    self.scorer_worker.model_config.max_logprobs)
        else:
            # Organize input tensors by step instead of by sequence.
            target_logprobs_by_step = target_logprobs.transpose(0, 1)
            # Serialize all tensors into Python lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step,
             topk_logprobs_by_step, topk_indices_by_step) =\
                self._create_logprob_lists_from_tensors(
                    target_logprobs_by_step, accepted_token_ids_by_step,
                    self.scorer_worker.model_config.max_logprobs)

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)

        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Zero-overhead: keep the accepted token ids on the device instead
        # of serializing them to a CPU Python list.
        # accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
        record_accepted_token_ids(accepted_token_ids, seq_ids)

        # Construct the output on a per-step, per-sequence basis.
        # Non-terminal prefill chunks will end up here as rows with just -1s,
        # i.e. mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]],
        # while terminal chunks will only have one generated token at time 0.
        sampler_output_list: List[SamplerOutput] = []

        # Prefills are not multi-step (they return at most 1 token); in
        # order to avoid padding or repetition to fit decodes, we separate
        # them.
        for i, sg in enumerate(seq_group_metadata_list):
            if not sg.is_prompt:
                # Requests are ordered as prefills|decodes => no more
                # prefills.
                break
            num_logprobs = num_logprobs_per_seq[i]
            seq_kwargs = dict(token_id=-1,
                              token_id_logprob_rank=0,
                              token_id_logprob=-float('inf'),
                              topk_token_ids=[-1] * num_logprobs,
                              topk_logprobs=[-float('inf')] * num_logprobs,
                              seq_id=seq_ids[i])
            # Terminal chunk, has a token.
            if sg.do_sample:
                seq_kwargs.update(
                    dict(
                        token_id=accepted_token_ids[i][0].item(),
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            0][i],
                        token_id_logprob=accepted_token_id_logprobs_by_step[0]
                        [i],
                        topk_token_ids=topk_indices_by_step[0][i]
                        [:num_logprobs],
                        # Output only, so the step is 0.
                        topk_logprobs=topk_logprobs_by_step[0][i]
                        [:num_logprobs],
                    ))
            needs_plogs = (sg.sampling_params.prompt_logprobs
                           and sg.sampling_params.prompt_logprobs > 0)
            plogs = None
            if prompt_logprobs is not None:
                # Even non-terminal prompt chunks can have logprobs here.
                plogs = prompt_logprobs[i]
            elif needs_plogs:
                # Prompt logprobs are requested but `_disable_logprobs` is
                # set.
                seq_data = next(iter(sg.seq_data.values()))
                # Get only the tokens in this chunk!
                prompt_token_ids = seq_data.get_prompt_token_ids()
                prompt_token_ids = prompt_token_ids[
                    seq_data.
                    _num_computed_tokens:seq_data._num_computed_tokens +
                    sg.token_chunk_size]

                is_first_chunk = seq_data._num_computed_tokens == 0
                # There's no prob generated for the first token in a
                # sequence.
                if is_first_chunk:
                    prompt_token_ids = prompt_token_ids[1:]
                plogs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for p_token_id in prompt_token_ids
                ]
            seq_kwargs.update(dict(prompt_logprobs=plogs))

            sampler_output_list.append(
                SamplerOutput(
                    outputs=[create_sequence_group_output(
                        **seq_kwargs)]))  # type: ignore

        # Decodes: create one SamplerOutput per step (at most K+1).
        for step_index in range(num_steps):
            # Zero-overhead: the early-exit check below reads token ids on
            # the host, so it is disabled to avoid a device sync.
            # if all(token_id == -1 for sg, token_id in zip(
            #         seq_group_metadata_list,
            #         accepted_token_ids_by_step[step_index])
            #        if not sg.is_prompt):
            #     break
            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
            for sequence_index in range(batch_size):
                seq_meta = seq_group_metadata_list[sequence_index]
                # Prompts were already processed above.
                if seq_meta.is_prompt:
                    continue

                # Each sequence may have a different num_logprobs; retrieve
                # it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
                step_output_token_ids.append(
                    create_sequence_group_output(
                        # Placeholder id: the real token ids stay on the
                        # device and were recorded above.
                        token_id=0,
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        # Populate the data structures needed to keep track of sequences
        # with bonus tokens.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_token_ids_by_step)
        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None and sampler_output_list:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

        # Log time spent in each stage periodically.
        # This is periodic because the rejection sampler emits metrics
        # periodically.
        self._maybe_log_stage_times(*stage_times)
        # The first `n_prefills` entries will contain prefill SamplerOutputs
        # when chunked prefill is enabled; the rest are decodes in multi-step
        # format.
        return sampler_output_list

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: List[List[int]]):
        """
        Updates the internal data structures which keep track of sequences
        that have been assigned bonus tokens in their last forward pass.
        """
        for seq_index, seq_id in enumerate(seq_ids):
            # Zero-overhead: the accepted token ids are not serialized to
            # the host, so the per-sequence -1 check below is skipped and
            # every sequence is tracked as having a bonus token.
            # last_token_id = accepted_token_ids_by_step[-1][seq_index]
            # if last_token_id == -1:
            #     self._seq_with_bonus_token_in_last_step.discard(seq_id)
            # else:
            self._seq_with_bonus_token_in_last_step.add(seq_id)
        for request_id, sequences in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(sequences)
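
A minimal sketch of the -1 padding that `_create_output_sampler_list` documents, with made-up token ids: `accepted_token_ids` holds one row per sequence, padded with -1 so every row has k+1 entries, and transposing groups the entries by emission step (the per-step loop above consumes exactly this layout):

```python
import torch

accepted = torch.tensor([[5, 6, 7],     # both speculative tokens accepted
                         [8, -1, -1]])  # only the first token accepted
by_step = accepted.transpose(0, 1)      # [k+1, B]: tokens grouped by step
assert by_step.tolist() == [[5, 8], [6, -1], [7, -1]]
```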
vllm/zero_overhead/spec_decode/top1_proproser.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import os
from typing import List, Optional, Set, Tuple

import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.utils import record_proposal_lens_list


class ZeroOverheadTop1Proposer(Top1Proposer):

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results
        with the skipped sequences.
        """
        if maybe_sampler_output is None:
            # If there are no speculative tokens, the sampler output will be
            # None. In this case we return empty proposals.
            proposal_tokens = torch.tensor(-1,
                                           dtype=torch.long,
                                           device=self._device).expand(
                                               batch_size, proposal_len)
            proposal_probs = torch.tensor(0,
                                          dtype=torch.float32,
                                          device=self._device).expand(
                                              batch_size, proposal_len,
                                              self._vocab_size)
            proposal_lens_tensor = torch.tensor(0,
                                                dtype=torch.long,
                                                device=self._device).expand(
                                                    len(proposal_lens))
            return proposal_tokens, proposal_probs, proposal_lens_tensor

        sampler_output = maybe_sampler_output
        proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
            sampler_output, sampler_transposed)

        proposal_lens_list = [0 for _ in range(batch_size)]
        for index in nonzero_proposal_len_indices:
            proposal_lens_list[index] = proposal_len
        record_proposal_lens_list(proposal_lens_list)

        nonzero_proposal_len_indices = async_tensor_h2d(
            nonzero_proposal_len_indices, torch.int32, self._device, True)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. The proposal can be empty, e.g. [-1, -1, -1].
        entire_proposal_tokens = proposal_tokens.new_full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
        )
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
        entire_proposal_probs = proposal_probs.new_zeros(
            batch_size,
            *proposal_probs.shape[1:],
        )
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_tokens, proposal_probs = (
            entire_proposal_tokens,
            entire_proposal_probs,
        )

        proposal_lens_tensor = async_tensor_h2d(proposal_lens_list, torch.long,
                                                self._device, True)

        return proposal_tokens, proposal_probs, proposal_lens_tensor
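
A minimal sketch of the merge performed by `_merge_outputs`, with made-up values: sequences that were skipped keep an all -1 "empty" proposal, while the rows listed in `nonzero_proposal_len_indices` receive the real proposal tokens:

```python
import torch

batch_size, proposal_len = 3, 2
proposal_tokens = torch.tensor([[4, 5], [6, 7]])  # proposals for 2 seqs
nonzero_proposal_len_indices = torch.tensor([0, 2])
entire = proposal_tokens.new_full((batch_size, proposal_len), fill_value=-1)
entire[nonzero_proposal_len_indices] = proposal_tokens
assert entire.tolist() == [[4, 5], [-1, -1], [6, 7]]
```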