import copy
import os
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import torch
import torch.nn as nn

from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import (broadcast_tensor_dict,
                                               get_tp_group,
                                               tensor_model_parallel_gather)
from vllm.distributed.parallel_state import model_parallel_is_initialized
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, ExecuteModelRequest,
                           HiddenStates, Logits, SequenceGroupMetadata,
                           get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import (BatchExpansionTop1Scorer,
                                              BatchExpansionTreeStyleScorer)
from vllm.spec_decode.spec_decode_worker import (
    SpecDecodeWorker, prepare_prefill_hidden_states)
from vllm.zero_overhead.spec_decode.batch_expansion import (
    ZeroOverheadBatchExpansionTop1Scorer)
from vllm.zero_overhead.utils import (SpecStepKind, record_accepted_token_ids,
                                      set_spec_step)

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.attention.ops.paged_attn import PagedAttention
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import (NonLLMProposerWorkerBase,
                                                   ProposerWorkerBase)
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                   create_sequence_group_output,
                                   get_all_num_logprobs,
                                   get_sampled_token_logprobs, nvtx_range,
                                   split_batch_by_proposal_len)
from vllm.utils import async_tensor_h2d, resolve_obj_by_qualname
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase

logger = init_logger(__name__)


class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
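    """SpecDecodeWorker variant used by the zero-overhead speculative
    decoding path. It overrides device initialization, the non-speculative
    prefill step, token verification, and output construction so that
    accepted token ids can stay on device and tree-style (cartesian
    candidate) decoding is supported.
    """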

    def init_device(self) -> None:
        """Initialize both scorer and proposer models."""
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        if self._enable_lm_head_weight_load:
            # NOTE(Shangming): gather lm_head weight when tp enabled
            target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather(
                self.scorer_worker.model_runner.model_runner.model.lm_head.
                weight.data,
                dim=0,
            )

            self.proposer_worker.maybe_load_lm_head_weight(
                target_lm_head_weight)

        self._metrics.init_tensors(self.rank, device_type=self.device)
        if model_parallel_is_initialized():
            self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
                                                  device_type=self.device)
        else:
            self.spec_decode_sampler.init_tensors(self.rank,
                                                  device_type=self.device)

        scorer_cls: Type[SpeculativeScorer]
        if self.disable_mqa_scorer:
            scorer_cls = ZeroOverheadBatchExpansionTop1Scorer
            logger.info("[Speculative Decoding] Use batch "
                        "expansion for scoring proposals.")
        else:
            scorer_cls = MQAScorer
            logger.info(
                "[Speculative Decoding] Use MQA scorer for scoring proposals.")

        if not self.tree_decoding:
            self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
                                     device=self.device,
                                     vocab_size=self._vocab_size)
        else:
            self.scorer = BatchExpansionTreeStyleScorer(
                scorer_worker=self.scorer_worker,
                device=self.device,
                vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a single generation step without any speculation. The input is
        sent to both the proposer and the scorer model so that the KV cache
        stays consistent between the two. When skip_proposer is True, the
        proposer model is not called, meaning that its KV cache is not
        updated for these requests, so they cannot use speculative decoding
        for the remaining decode steps.
        """
        if self.tree_decoding and self.kvcache_slot_to_be_moved is not None:
            execute_model_req.kvcache_slot_to_be_moved = \
                self.kvcache_slot_to_be_moved
            self.kvcache_slot_to_be_moved = None

        set_spec_step(SpecStepKind.PREFILL)
        sampler_output = self.scorer_worker.execute_model(execute_model_req)

        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Store hidden states from target model execution, BxD.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            # Only decodes and prefill terminal chunks need a hidden state.
            seq_group_meta_with_hidden = [
                sg for sg in execute_model_req.seq_group_metadata_list
                if sg.do_sample
            ]
            if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
                # Drop hidden_states with no prediction (e.g. non-terminal
                # chunks).
                hidden_states = hidden_states[
                    torch.where(sampler_output.sampled_token_ids -
                                VLLM_INVALID_TOKEN_ID)[0]]
            if self.previous_hidden_states is None and len(
                    seq_group_meta_with_hidden):
                self.previous_hidden_states = HiddenStates(
                    hidden_states, seq_group_meta_with_hidden)
            elif self.previous_hidden_states and len(
                    seq_group_meta_with_hidden):
                self.previous_hidden_states.update(hidden_states,
                                                   seq_group_meta_with_hidden)

        # Store logits from target model execution.
        if self.tree_decoding:
            logits = sampler_output.logits
            if logits is not None:
                if self.previous_logits is None:
                    self.previous_logits = Logits(
                        logits, execute_model_req.seq_group_metadata_list)
                else:
                    self.previous_logits.update(
                        logits, execute_model_req.seq_group_metadata_list)

        if not skip_proposer:
            # We prepare the prefill hidden states here so that there is no
            # additional complexity in the worker for the spec_decode vs
            # non_spec_decode flow, and execute_model doesn't need additional
            # modifications.
            execute_model_req.previous_hidden_states = \
                prepare_prefill_hidden_states(
                    sampler_output.prefill_hidden_states)
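            # Run the proposer prefill once per speculative prefill step;
            # spec_step_idx tells the proposer which step is being run.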
            for i in range(self._num_spec_prefill_steps):
                execute_model_req.spec_step_idx = i
                self.proposer_worker.execute_model(execute_model_req)

        sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
            execute_model_req=execute_model_req, sampler_output=sampler_output)
                                    if self._disable_logprobs else
                                    [sampler_output])

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the
        # workers.
        sampler_output.sampled_token_probs = None
        sampler_output.sampled_token_ids = None
        sampler_output.logprobs = None
        return sampler_output_to_return

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[int]]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer
        models.

        Returns a tuple of the accepted token ids, the logprobs according to
        the scoring model, the selected tree indices per sequence, and the
        accepted length per sequence (the latter two are only populated for
        tree decoding).
        """
        proposal_lens_list = proposals.proposal_lens

        # vLLM currently only supports proposal lens equal to zero or the
        # batch proposal len. This adds some complexity (splitting the batch
        # into spec and non spec sequences) and should be removed in the
        # future. It can be done by supporting per-sequence proposal lens.
        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, including bonus tokens.
        if non_spec_indices:
            proposal_verifier_probs = proposal_scores.probs[spec_indices]
        else:
            proposal_verifier_probs = proposal_scores.probs

        if self.tree_decoding:
            retrieve_indices = proposals.retrieve_indices
            proposal_verifier_probs = \
                proposal_verifier_probs[:, retrieve_indices]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[:, -1:]
        if non_spec_indices:
            bonus_token_ids = bonus_token_ids[spec_indices, :]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs
        if proposal_probs is not None and non_spec_indices:
            proposal_probs = proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids
        if non_spec_indices:
            proposal_token_ids = proposal_token_ids[spec_indices]

        # Get tree buffers.
        cart_candidates = proposals.cart_candidates
        if cart_candidates is not None and non_spec_indices:
            cart_candidates = cart_candidates[spec_indices]

        # Sampler arguments
        sampler_extra_kwargs: Dict[str, Any] = {}
        if self.generators and isinstance(self.spec_decode_sampler,
                                          SpecDecodeStochasticBaseSampler):
            sampler_extra_kwargs["seeded_seqs"] = {
                idx: self.generators[sgm.request_id]
                for idx, sgm in enumerate(seq_group_metadata_list)
                if sgm.sampling_params.seed is not None
            }

        if isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler):
            sampler_extra_kwargs["cart_candidates"] = cart_candidates
            # Output buffers filled in by the sampler: the best candidate
            # path and the accepted length for each sequence.
            sampler_extra_kwargs["best_candidates"] = []
            sampler_extra_kwargs["accept_lengths"] = []

        # Mark sequences that are running their first speculative step.
        first_step_flags = []
        for sgm in seq_group_metadata_list:
            seq = next(iter(sgm.seq_data.values()))
            first_step_flags.append(bool(seq.get_first_step_flag()))

        sampler_extra_kwargs["first_step_flags"] = first_step_flags

        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )
        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor.
        if not self.tree_decoding:
            non_spec_token_ids = non_spec_token_ids.expand(
                -1, max_proposal_len + 1).clone()
        else:
            non_spec_token_ids = non_spec_token_ids.expand(
                -1, max_proposal_len).clone()

        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs
        # Rearrange so that results are in the order of the original seq group
        # metadata.
        original_indices = async_tensor_h2d(original_indices, torch.int32,
                                            self.device, True)
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        # B x K+1 x D
        hidden_states = proposal_scores.hidden_states

        select_indices = None
        accept_lengths = None

        select_indices_list = []

        if cart_candidates is None:
            if hidden_states is not None:
                # Only get terminal hidden states for next step.
                terminal_metadata = [
                    sg for sg in seq_group_metadata_list if sg.do_sample
                ]
                # Contract hidden states based on accepted tokens.
                hs_size = hidden_states.shape[-1]
                accepted_index = accepted_token_ids + 1  # Convert -1 to 0
                accepted_index = accepted_index.count_nonzero(dim=1).add_(
                    -1)  # b
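                # The +1 shift turns the -1 padding into zeros, so
                # count_nonzero gives each sequence's accepted length and
                # subtracting one gives the index of its last accepted token.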
                # Drop non-terminal prefill chunks hidden states.
                hidden_states = hidden_states[accepted_index !=
                                              VLLM_INVALID_TOKEN_ID]
                accepted_index = accepted_index[accepted_index !=
                                                VLLM_INVALID_TOKEN_ID]
                assert len(accepted_index) == hidden_states.shape[0] == len(
                    terminal_metadata)
                index = accepted_index[:, None, None].expand(
                    -1, 1, hs_size)  # b x 1 x d
                second_last_token_hidden_states = hidden_states[:, -2]  # b x d
                hidden_states = hidden_states.gather(
                    1, index).squeeze(1)  # b x d

                # Store hidden states from target model for subsequent decode
                # step.
                self.previous_hidden_states = HiddenStates(
                    hidden_states, terminal_metadata,
                    second_last_token_hidden_states)
        else:
            retrieve_indices = proposals.retrieve_indices

            batch_size = len(seq_group_metadata_list)

            best_candidates = sampler_extra_kwargs["best_candidates"]
            accept_lengths = sampler_extra_kwargs["accept_lengths"]

            # Contract hidden states based on accepted tokens.
            hs_size = hidden_states.shape[-1]
            hidden_states = hidden_states.view(batch_size, -1, hs_size)

            # Store logits from target model for subsequent proposal.
            logits = proposal_scores.logits
            logits = logits.view(batch_size, -1, logits.shape[-1])
            # [batch_size, retrieve_size, max_depth, vocab_size]
            logits = logits[:, retrieve_indices]
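
            # Each row of retrieve_indices is one candidate path through the
            # token tree: row j lists the flattened tree positions of the
            # tokens on path j, so indexing with it regroups the flat scores
            # by (candidate, depth).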
            previous_logits_list = []
            previous_hidden_state_list = []

            retrieve_indices = retrieve_indices.cpu()

            for i in range(batch_size):
                logit = logits[i, best_candidates[i],
                               accept_lengths[i]].unsqueeze(0)
                previous_logits_list.append(logit)
                select_indices = retrieve_indices[
                    best_candidates[i], :accept_lengths[i] + 1]
                hidden_state = hidden_states[i,
                                             select_indices[-1]].unsqueeze(0)
                select_indices_list.append(select_indices)
                previous_hidden_state_list.append(hidden_state)

            logits = torch.cat(previous_logits_list, dim=0)
            self.previous_logits = Logits(logits, seq_group_metadata_list)

            # [batch_size, hidden_size]
            hidden_states = torch.cat(previous_hidden_state_list, dim=0)
            self.previous_hidden_states = HiddenStates(
                hidden_states, seq_group_metadata_list)

        return (accepted_token_ids, logprobs, select_indices_list,
                accept_lengths)

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        prompt_logprobs: Optional[
            torch.Tensor],  # shape: [nprompt_tokens, vocab_size]
        k: int,
        stage_times: Tuple[float, float, float],
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.
        """
        batch_size, num_steps = accepted_token_ids.shape
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
        if self._disable_logprobs:
            # We are skipping the logprobs. Hence don't serialize the
            # logprobs related tensors from the GPU. Instead create
            # empty/dummy lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step,
             topk_logprobs_by_step, topk_indices_by_step) =\
             self._create_dummy_logprob_lists(
                 batch_size, num_steps,
                 self.scorer_worker.model_config.max_logprobs)
        else:
            # Organize input tensors by step instead of by sequence.
            target_logprobs_by_step = target_logprobs.transpose(0, 1)
            # Serialize all tensors into Python lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step,
             topk_logprobs_by_step, topk_indices_by_step) =\
             self._create_logprob_lists_from_tensors(
                 target_logprobs_by_step, accepted_token_ids_by_step,
                 self.scorer_worker.model_config.max_logprobs)

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)

        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Keep the accepted token ids on device and record them for the
        # zero-overhead path instead of serializing them to a CPU list here.
        record_accepted_token_ids(accepted_token_ids, seq_ids)

        # Construct the output on a per-step, per-sequence basis.
        # Non-terminal prefill chunks will end up here as rows with just -1s
        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
        # terminal chunks will only have one generated token at time 0.
        sampler_output_list: List[SamplerOutput] = []

        # Prefills are not multi-step (they return at most 1 token); to avoid
        # padding or repeating them to fit decodes, they are handled
        # separately.
        for i, sg in enumerate(seq_group_metadata_list):
            if not sg.is_prompt:
                # Requests are ordered as prefills|decodes => no more
                # prefills.
                break
            num_logprobs = num_logprobs_per_seq[i]
            seq_kwargs = dict(token_id=-1,
                              token_id_logprob_rank=0,
                              token_id_logprob=-float('inf'),
                              topk_token_ids=[-1] * num_logprobs,
                              topk_logprobs=[-float('inf')] * num_logprobs,
                              seq_id=seq_ids[i])
            # Terminal chunk, has token.
            if sg.do_sample:
                seq_kwargs.update(
                    dict(
                        token_id=accepted_token_ids[i][0].item(),
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            0][i],
                        token_id_logprob=accepted_token_id_logprobs_by_step[0]
                        [i],
                        topk_token_ids=topk_indices_by_step[0][i]
                        [:num_logprobs],
                        # output only so step is 0
                        topk_logprobs=topk_logprobs_by_step[0][i]
                        [:num_logprobs],
                    ))
            needs_plogs = (sg.sampling_params.prompt_logprobs
                           and sg.sampling_params.prompt_logprobs > 0)
            plogs = None
            if prompt_logprobs is not None:
                # Even non-terminal prompt chunks can have logprobs here.
                plogs = prompt_logprobs[i]
            elif needs_plogs:
                # Prompt logprobs are requested but `_disable_logprobs` is
                # set.
                seq_data = next(iter(sg.seq_data.values()))
                # Get only the tokens in this chunk!
                prompt_token_ids = seq_data.get_prompt_token_ids()
                prompt_token_ids = prompt_token_ids[
                    seq_data._num_computed_tokens:
                    seq_data._num_computed_tokens + sg.token_chunk_size]

                is_first_chunk = seq_data._num_computed_tokens == 0
                # There's no prob generated for the first token in a
                # sequence.
                if is_first_chunk:
                    prompt_token_ids = prompt_token_ids[1:]
                plogs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for p_token_id in prompt_token_ids
                ]
            seq_kwargs.update(dict(prompt_logprobs=plogs))

            sampler_output_list.append(
                SamplerOutput(
                    outputs=[create_sequence_group_output(
                        **seq_kwargs)]))  # type: ignore

        # Decodes: create one SamplerOutput per step (at most K+1).
        for step_index in range(num_steps):
            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
            for sequence_index in range(batch_size):
                seq_meta = seq_group_metadata_list[sequence_index]
                # Prompts already processed above.
                if seq_meta.is_prompt:
                    continue

                # Each sequence may have a different num_logprobs; retrieve
                # it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
                step_output_token_ids.append(
                    create_sequence_group_output(
                        # Accepted ids stay on device (recorded above via
                        # record_accepted_token_ids); 0 is a placeholder.
                        token_id=0,
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        # Populate the data structures needed to keep track of sequences with
        # bonus tokens.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_token_ids_by_step)
        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None and sampler_output_list:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

        # Log time spent in each stage periodically.
        # This is periodic because the rejection sampler emits metrics
        # periodically.
        self._maybe_log_stage_times(*stage_times)
        # First `n_prefills` entries will contain prefills SamplerOutput when
        # chunked prefill is enabled; the rest are decodes in multi-step
        # format.
        return sampler_output_list

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: torch.Tensor):
        """
        Updates the internal data structures which keep track of sequences
        which have been assigned bonus tokens in their last forward pass.
        """
        for seq_id in seq_ids:
            self._seq_with_bonus_token_in_last_step.add(seq_id)
        for request_id, sequences in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(sequences)