import os
import copy
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import torch
import torch.nn as nn

from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import (broadcast_tensor_dict,
                                               get_tp_group,
                                               tensor_model_parallel_gather)
from vllm.distributed.parallel_state import model_parallel_is_initialized
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, ExecuteModelRequest,
                           HiddenStates, Logits, SequenceGroupMetadata,
                           get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import (BatchExpansionTop1Scorer,
                                              BatchExpansionTreeStyleScorer)
from vllm.spec_decode.spec_decode_worker import (
    SpecDecodeWorker, prepare_prefill_hidden_states)
from vllm.zero_overhead.spec_decode.batch_expansion import (
    ZeroOverheadBatchExpansionTop1Scorer)
from vllm.zero_overhead.utils import (SpecStepKind, record_accepted_token_ids,
                                      set_spec_step)

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import (NonLLMProposerWorkerBase,
                                                   ProposerWorkerBase)
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                   create_sequence_group_output,
                                   get_all_num_logprobs,
                                   get_sampled_token_logprobs, nvtx_range,
                                   split_batch_by_proposal_len)
from vllm.utils import async_tensor_h2d, resolve_obj_by_qualname
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention

logger = init_logger(__name__)


class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):

    def init_device(self) -> None:
        """Initialize both scorer and proposer models."""
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        if self._enable_lm_head_weight_load:
            # NOTE(Shangming): gather lm_head weight when tp enabled
            target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather(
                self.scorer_worker.model_runner.model_runner.model.lm_head.\
                    weight.data,
                dim=0,
            )
            self.proposer_worker.maybe_load_lm_head_weight(
                target_lm_head_weight)

        self._metrics.init_tensors(self.rank, device_type=self.device)
        if model_parallel_is_initialized():
            self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
                                                  device_type=self.device)
        else:
            self.spec_decode_sampler.init_tensors(self.rank,
                                                  device_type=self.device)

        scorer_cls: Type[SpeculativeScorer]
        if self.disable_mqa_scorer:
            scorer_cls = ZeroOverheadBatchExpansionTop1Scorer
            logger.info("[Speculative Decoding] Use batch "
                        "expansion for scoring proposals.")
        else:
            scorer_cls = MQAScorer
            logger.info(
                "[Speculative Decoding] Use MQA scorer for scoring proposals.")

        if not self.tree_decoding:
            self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
                                     device=self.device,
                                     vocab_size=self._vocab_size)
        else:
            self.scorer = BatchExpansionTreeStyleScorer(
                scorer_worker=self.scorer_worker,
                device=self.device,
                vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a single generation step without any speculation. The input is
        sent to the proposer and scorer model so that the KV cache is
        consistent between the two. When skip_proposer is True, the proposer
        model is not called, so its KV cache for these requests is not
        updated and speculative decoding cannot be enabled for them in the
        remaining decode steps.
        """
        if self.tree_decoding and self.kvcache_slot_to_be_moved is not None:
            execute_model_req.kvcache_slot_to_be_moved = \
                self.kvcache_slot_to_be_moved
            self.kvcache_slot_to_be_moved = None

        set_spec_step(SpecStepKind.PREFILL)
        sampler_output = self.scorer_worker.execute_model(execute_model_req)
        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Store hidden states from target model execution, BxD.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            # Only decodes and prefill terminal chunks need a hidden state.
            seq_group_meta_with_hidden = [
                sg for sg in execute_model_req.seq_group_metadata_list
                if sg.do_sample
            ]
            if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
                # Drop hidden_states with no prediction (eg non-terminal
                # chunks).
                hidden_states = hidden_states[
                    torch.where(sampler_output.sampled_token_ids -
                                VLLM_INVALID_TOKEN_ID)[0]]
            # NOTE: unlike the base SpecDecodeWorker, the hidden states are
            # stored even when skip_proposer is True (the upstream
            # `if not skip_proposer:` guard was removed here).
            if self.previous_hidden_states is None and len(
                    seq_group_meta_with_hidden):
                self.previous_hidden_states = HiddenStates(
                    hidden_states, seq_group_meta_with_hidden)
            elif self.previous_hidden_states and len(
                    seq_group_meta_with_hidden):
                self.previous_hidden_states.update(hidden_states,
                                                   seq_group_meta_with_hidden)
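        # The stored hidden states are fed back to the proposer on the next
        # step (draft models such as EAGLE-style heads or the MLP speculator
        # condition on the target model's last hidden state).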
        # Store logits from target model execution.
        if self.tree_decoding:
            logits = sampler_output.logits
            if logits is not None:
                if self.previous_logits is None:
                    self.previous_logits = Logits(
                        logits, execute_model_req.seq_group_metadata_list)
                else:
                    self.previous_logits.update(
                        logits, execute_model_req.seq_group_metadata_list)

        if not skip_proposer:
            # We prepare the prefill hidden states here so that there is no
            # additional complexity in the worker for the spec_decode vs
            # non_spec_decode flow and execute_model doesn't need additional
            # modifications.
            execute_model_req.previous_hidden_states = \
                prepare_prefill_hidden_states(
                    sampler_output.prefill_hidden_states)
            for i in range(self._num_spec_prefill_steps):
                execute_model_req.spec_step_idx = i
                self.proposer_worker.execute_model(execute_model_req)

        sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
            execute_model_req=execute_model_req, sampler_output=sampler_output)
                                    if self._disable_logprobs else
                                    [sampler_output])

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the
        # workers.
        sampler_output.sampled_token_probs = None
        sampler_output.sampled_token_ids = None
        sampler_output.logprobs = None
        return sampler_output_to_return

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[int]]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer
        models.

        Returns the accepted token ids, the logprobs according to the scoring
        model, and, for tree decoding, the per-sequence selected tree indices
        and accept lengths.
        """
        proposal_lens_list = proposals.proposal_lens

        # vLLM currently only supports proposal lens equal to zero or the
        # batch proposal len. This adds some complexity (splitting the batch
        # into spec and non spec sequences) and should be removed in the
        # future. It can be done by supporting per-sequence proposal lens.
        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, including bonus tokens.
        if non_spec_indices:
            proposal_verifier_probs = proposal_scores.probs[spec_indices]
        else:
            proposal_verifier_probs = proposal_scores.probs
        if self.tree_decoding:
            retrieve_indices = proposals.retrieve_indices
            proposal_verifier_probs = \
                proposal_verifier_probs[:, retrieve_indices]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[:, -1:]
        if non_spec_indices:
            bonus_token_ids = bonus_token_ids[spec_indices, :]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs
        if proposal_probs is not None and non_spec_indices:
            proposal_probs = proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids
        if non_spec_indices:
            proposal_token_ids = proposal_token_ids[spec_indices]
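        # At this point every per-sequence tensor above is ordered as
        # [speculative sequences, then non-speculative sequences].
        # Illustrative example (hypothetical values): with proposal_lens
        # [3, 0, 3], split_batch_by_proposal_len yields spec_indices = [0, 2]
        # and non_spec_indices = [1], so original_indices = [0, 2, 1]. After
        # sampling, the rows are concatenated in that same spec-first order
        # and scattered back into the caller's order via
        # accepted_token_ids[original_indices] below.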
        # Get tree buffers.
        cart_candidates = proposals.cart_candidates
        if cart_candidates is not None and non_spec_indices:
            cart_candidates = cart_candidates[spec_indices]

        # Sampler arguments
        sampler_extra_kwargs: Dict[str, Any] = {}
        if self.generators and isinstance(self.spec_decode_sampler,
                                          SpecDecodeStochasticBaseSampler):
            sampler_extra_kwargs["seeded_seqs"] = {
                idx: self.generators[sgm.request_id]
                for idx, sgm in enumerate(seq_group_metadata_list)
                if sgm.sampling_params.seed is not None
            }
        if isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler):
            sampler_extra_kwargs["cart_candidates"] = cart_candidates
            sampler_extra_kwargs["best_candidates"] = []
            sampler_extra_kwargs["accept_lengths"] = []
            first_step_flags = []
            for sgm in seq_group_metadata_list:
                seq = next(iter(sgm.seq_data.values()))
                first_step_flags.append(bool(seq.get_first_step_flag()))
            sampler_extra_kwargs["first_step_flags"] = first_step_flags

        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )

        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor.
        if not self.tree_decoding:
            non_spec_token_ids = non_spec_token_ids.expand(
                -1, max_proposal_len + 1).clone()
        else:
            non_spec_token_ids = non_spec_token_ids.expand(
                -1, max_proposal_len).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs

        # Rearrange so that results are in the order of the original seq
        # group metadata.
        original_indices = async_tensor_h2d(original_indices, torch.int32,
                                            self.device, True)
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        # B x K+1 x D
        hidden_states = proposal_scores.hidden_states
        select_indices = None
        accept_lengths = None
        select_indices_list = []
        if cart_candidates is None:
            if hidden_states is not None:
                # Only get terminal hidden states for next step.
                terminal_metadata = [
                    sg for sg in seq_group_metadata_list if sg.do_sample
                ]

                # Contract hidden states based on accepted tokens.
                hs_size = hidden_states.shape[-1]

                accepted_index = accepted_token_ids + 1  # Convert -1 to 0
                accepted_index = accepted_index.count_nonzero(dim=1).add_(
                    -1)  # b
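                # Worked example (hypothetical values): for
                #   accepted_token_ids = [[ 5,  9, -1, -1],
                #                         [ 7, -1, -1, -1],
                #                         [-1, -1, -1, -1]]
                # adding 1 turns the -1 padding into zeros, count_nonzero
                # gives [2, 1, 0] accepted tokens per row, and subtracting 1
                # yields accepted_index = [1, 0, -1]: the position of the
                # last accepted token, or -1 for rows with no sampled token
                # (non-terminal prefill chunks).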
                # Drop non-terminal prefill chunks hidden states.
                hidden_states = hidden_states[
                    accepted_index != VLLM_INVALID_TOKEN_ID]
                accepted_index = accepted_index[
                    accepted_index != VLLM_INVALID_TOKEN_ID]
                assert len(accepted_index) == hidden_states.shape[0] == len(
                    terminal_metadata)
                index = accepted_index[:, None, None].expand(
                    -1, 1, hs_size)  # b x 1 x d
                second_last_token_hidden_states = hidden_states[:, -2]  # b x d
                hidden_states = hidden_states.gather(
                    1, index).squeeze(1)  # b x d
                # Store hidden states from target model for subsequent decode
                # step.
                self.previous_hidden_states = HiddenStates(
                    hidden_states, terminal_metadata,
                    second_last_token_hidden_states)
        else:
            retrieve_indices = proposals.retrieve_indices
            batch_size = len(seq_group_metadata_list)
            best_candidates = sampler_extra_kwargs["best_candidates"]
            accept_lengths = sampler_extra_kwargs["accept_lengths"]

            # Contract hidden states based on accepted tokens.
            hs_size = hidden_states.shape[-1]
            hidden_states = hidden_states.view(batch_size, -1, hs_size)

            # Store logits from target model for subsequent proposal.
            logits = proposal_scores.logits
            logits = logits.view(batch_size, -1, logits.shape[-1])
            # [batch_size, retrieve_size, max_depth, vocab_size]
            logits = logits[:, retrieve_indices]

            previous_logits_list = []
            previous_hidden_state_list = []
            retrieve_indices = retrieve_indices.cpu()
            for i in range(batch_size):
                logit = logits[i, best_candidates[i],
                               accept_lengths[i]].unsqueeze(0)
                previous_logits_list.append(logit)
                select_indices = retrieve_indices[
                    best_candidates[i], :accept_lengths[i] + 1]
                hidden_state = hidden_states[
                    i, select_indices[-1]].unsqueeze(0)
                select_indices_list.append(select_indices)
                previous_hidden_state_list.append(hidden_state)

            logits = torch.cat(previous_logits_list, dim=0)
            self.previous_logits = Logits(logits, seq_group_metadata_list)
            # [batch_size, hidden_size]
            hidden_states = torch.cat(previous_hidden_state_list, dim=0)
            self.previous_hidden_states = HiddenStates(
                hidden_states, seq_group_metadata_list)

        return (accepted_token_ids, logprobs, select_indices_list,
                accept_lengths)

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        prompt_logprobs: Optional[
            torch.Tensor],  # shape: [nprompt_tokens, vocab_size]
        k: int,
        stage_times: Tuple[float, float, float],
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.
        """
        batch_size, num_steps = accepted_token_ids.shape
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
        if self._disable_logprobs:
            # We are skipping the logprobs. Hence don't serialize the
            # logprobs related tensors from the GPU. Instead create
            # empty/dummy lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
             topk_indices_by_step) = self._create_dummy_logprob_lists(
                 batch_size, num_steps,
                 self.scorer_worker.model_config.max_logprobs)
        else:
            # Organize input tensors by step instead of by sequence.
            target_logprobs_by_step = target_logprobs.transpose(0, 1)
            # Serialize all tensors into Python lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
             topk_indices_by_step) = self._create_logprob_lists_from_tensors(
                 target_logprobs_by_step, accepted_token_ids_by_step,
                 self.scorer_worker.model_config.max_logprobs)
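        # accepted_token_ids has shape [batch_size, k+1]; the transposed
        # accepted_token_ids_by_step view is [k+1, batch_size], so the decode
        # loop below can emit all sequences' tokens for a given speculative
        # step together.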
        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Zero-overhead path: the accepted token ids are not serialized to a
        # CPU list here (the upstream `.tolist()` call is disabled); instead
        # they stay on device and are recorded via record_accepted_token_ids.
        # accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
        record_accepted_token_ids(accepted_token_ids, seq_ids)

        # Construct the output on a per-step, per-sequence basis.
        # Non-terminal prefill chunks will end up here as rows with just -1s,
        # i.e. mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]],
        # while terminal chunks will only have one generated token at time 0.
        sampler_output_list: List[SamplerOutput] = []

        # Prefills are not multi-step (return at most 1 token); in order to
        # avoid padding or repetition to fit decodes, we separate them.
        for i, sg in enumerate(seq_group_metadata_list):
            if not sg.is_prompt:
                # Requests are ordered as prefills|decodes => no more
                # prefills.
                break
            num_logprobs = num_logprobs_per_seq[i]
            seq_kwargs = dict(token_id=-1,
                              token_id_logprob_rank=0,
                              token_id_logprob=-float('inf'),
                              topk_token_ids=[-1] * num_logprobs,
                              topk_logprobs=[-float('inf')] * num_logprobs,
                              seq_id=seq_ids[i])
            # Terminal chunk, has token.
            if sg.do_sample:
                seq_kwargs.update(
                    dict(
                        token_id=accepted_token_ids[i][0].item(),
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            0][i],
                        token_id_logprob=accepted_token_id_logprobs_by_step[0]
                        [i],
                        # Output only, so step is 0.
                        topk_token_ids=topk_indices_by_step[0][i]
                        [:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[0][i]
                        [:num_logprobs],
                    ))

            needs_plogs = (sg.sampling_params.prompt_logprobs
                           and sg.sampling_params.prompt_logprobs > 0)
            plogs = None
            if prompt_logprobs is not None:
                # Even non-terminal prompt chunks can have logprobs here.
                plogs = prompt_logprobs[i]
            elif needs_plogs:
                # Prompt logprobs are requested but `_disable_logprobs` is
                # set.
                seq_data = next(iter(sg.seq_data.values()))
                # Get only the tokens in this chunk!
                prompt_token_ids = seq_data.get_prompt_token_ids()
                prompt_token_ids = prompt_token_ids[
                    seq_data._num_computed_tokens:seq_data.
                    _num_computed_tokens + sg.token_chunk_size]

                is_first_chunk = seq_data._num_computed_tokens == 0
                # There's no prob generated for the first token in a
                # sequence.
                if is_first_chunk:
                    prompt_token_ids = prompt_token_ids[1:]
                plogs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for p_token_id in prompt_token_ids
                ]
            seq_kwargs.update(dict(prompt_logprobs=plogs))

            sampler_output_list.append(
                SamplerOutput(outputs=[
                    create_sequence_group_output(**seq_kwargs)  # type: ignore
                ]))

        # Decodes, create one SamplerOutput per-step (at most K+1).
        for step_index in range(num_steps):
            # NOTE: the upstream early exit when every decode row is -1 at
            # this step is disabled here, since checking it would require
            # reading the accepted token ids back from the device.
            # if all(token_id == -1 for sg, token_id in zip(
            #         seq_group_metadata_list,
            #         accepted_token_ids_by_step[step_index])
            #        if not sg.is_prompt):
            #     break

            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
            for sequence_index in range(batch_size):
                seq_meta = seq_group_metadata_list[sequence_index]
                # Prompts already processed above.
                if seq_meta.is_prompt:
                    continue

                # Each sequence may have a different num_logprobs; retrieve
                # it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
                step_output_token_ids.append(
                    create_sequence_group_output(
                        # Placeholder token id: the accepted token ids are
                        # recorded on device by record_accepted_token_ids
                        # above instead of being serialized here.
                        token_id=0,
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        # Populate the data structures needed to keep track of sequences with
        # bonus tokens.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_token_ids_by_step)

        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None and sampler_output_list:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

            # Log time spent in each stage periodically.
            # This is periodic because the rejection sampler emits metrics
            # periodically.
            self._maybe_log_stage_times(*stage_times)

        # First `n_prefills` entries will contain prefills SamplerOutput when
        # chunked prefill is enabled, the rest is decodes in multi-step
        # format.
        return sampler_output_list

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: torch.Tensor):
        """
        Updates the internal data structures which keep track of sequences
        which have been assigned bonus tokens in their last forward pass.
        """
        for seq_index, seq_id in enumerate(seq_ids):
            # NOTE: the upstream check on the last accepted token id is
            # skipped here; every sequence is marked as having received a
            # bonus token.
            # last_token_id = accepted_token_ids_by_step[-1][seq_index]
            # if last_token_id == -1:
            #     self._seq_with_bonus_token_in_last_step.discard(seq_id)
            # else:
            self._seq_with_bonus_token_in_last_step.add(seq_id)
        for request_id, sequences in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(sequences)
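
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the worker): how `_verify_tokens` pads the
# non-speculative rows before concatenating them with the accepted
# speculative rows. The values are hypothetical; only the first column (the
# token sampled by the target model) is kept, the rest is -1 padding.
#
#     >>> non_spec_token_ids = torch.tensor([[17], [42]])  # [num_non_spec, 1]
#     >>> k = 3                                            # max_proposal_len
#     >>> padded = non_spec_token_ids.expand(-1, k + 1).clone()
#     >>> padded[:, 1:] = -1
#     >>> padded
#     tensor([[17, -1, -1, -1],
#             [42, -1, -1, -1]])
# ---------------------------------------------------------------------------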