from array import array
from itertools import chain, count
from typing import Iterator, List, Optional, Tuple

import numpy as np
import torch

from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
                           ExecuteModelRequest, SequenceData,
                           SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.utils import (get_proposal_lens_list,
                                      record_proposal_token_ids)

SeqId = int
TargetSeqId = int
TokenId = int

DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()


class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences.
        The target sequences have the unique continuations to be scored and a
        unique sequence ID that is different from all input sequence ids.

        If a speculative sequence length would exceed the max model length,
        then no speculation is produced for that sequence.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.
        Returns:
            SpeculativeScores: The scores of each speculative token, along
                with which sequences were ignored during scoring.
        """
        proposal_lens_list = get_proposal_lens_list()
        record_proposal_token_ids(proposals.proposal_token_ids)
        # Placeholder tokens; the real proposal ids were recorded above.
        proposal_token_ids_list = np.zeros(proposals.proposal_token_ids.shape,
                                           dtype=int).tolist()

        # Filter the list to ignore invalid proposals.
        proposal_token_ids_list_without_skips = [
            token_ids for token_ids in proposal_token_ids_list
            if VLLM_INVALID_TOKEN_ID not in token_ids
        ]

        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))
        assert len(target_sampler_output) == 1, "expected single-step output"
        target_sampler_output = target_sampler_output[0]

        if not non_spec_indices:
            # All sequence groups in batch have spec decoding enabled.
            return self._contract_batch_all_spec(
                target_sampler_output=target_sampler_output,
                proposals=proposals,
            )
        else:
            # Batch has a mix of spec decode enabled and disabled seq groups.
            return self._contract_batch(
                execute_model_req.seq_group_metadata_list,
                target_sampler_output=target_sampler_output,
                proposals=proposals,
                num_scoring_tokens=num_scoring_tokens,
                non_spec_indices=non_spec_indices,
                spec_indices=spec_indices,
                k=execute_model_req.num_lookahead_slots,
            )

    def _contract_non_speculative(
            self, scores: SpeculativeScores,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
            has_prompt_log: bool) -> SpeculativeScores:
        """Augment input `scores` with the outputs of non-speculative requests.

        This includes decode requests with speculation turned off, as well as
        prefill requests when `enable_chunked_prefill` is set. For the latter,
        prefills are further separated into terminal and non-terminal chunks
        (from which no token is sampled).
        """
        if not non_spec_indices:
            return scores

        if has_prompt_log:
            # When prompt_logprobs is enabled, prefills yield the output token
            # (and its prob) in the last entry (prompt|out):
            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
            # With chunked prefill, non-terminal chunks have -1 on each
            # position: they're still picked, but they're discarded later.
            seq_meta = seq_group_metadata_list
            nospec_sizes = torch.tensor([
                seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
                for i in non_spec_indices
            ])
            nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
        else:
            # In this case only sampled tokens are returned; select all.
            nospec_sampled_token_idxs = list(
                range(len(non_spec_outputs.token_ids)))

        nospec_sampled_token_idxs = async_tensor_h2d(
            nospec_sampled_token_idxs, torch.int32, self._device, True)
        non_spec_indices = async_tensor_h2d(non_spec_indices, torch.int32,
                                            self._device, True)
        scores.token_ids[non_spec_indices, :1] = \
            non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
        scores.probs[non_spec_indices, :1, :] = \
            non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
        scores.logprobs[non_spec_indices, :1, :] = \
            non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
        if scores.hidden_states is not None:
            assert non_spec_outputs.hidden_states is not None
            scores.hidden_states[non_spec_indices, :1, :] = \
                non_spec_outputs.hidden_states[
                    nospec_sampled_token_idxs].unsqueeze(1)
        return scores
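

# ---------------------------------------------------------------------------
# Illustrative sketch (exposition only, not used by the scorer): what the
# `score_proposals` docstring means by converting each input sequence into a
# set of k+1 target sequences. The helper name and token values below are
# hypothetical.
# ---------------------------------------------------------------------------
def _demo_expand_one_sequence(
        prompt_token_ids: List[TokenId],
        proposal_token_ids: List[TokenId]) -> List[List[TokenId]]:
    """Append every prefix of the proposal to the prompt (k+1 sequences).

    Scoring these continuations with the target model yields next-token
    probabilities for each of the k proposal tokens, plus one bonus token.

    >>> _demo_expand_one_sequence([1, 2], [7, 8])
    [[1, 2], [1, 2, 7], [1, 2, 7, 8]]
    """
    return [
        prompt_token_ids + proposal_token_ids[:i]
        for i in range(len(proposal_token_ids) + 1)
    ]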
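

# ---------------------------------------------------------------------------
# Illustrative sketch (exposition only, not used by the scorer): how
# `_contract_non_speculative` locates each request's sampled token when
# `prompt_logprobs` is enabled. The flat sampler output holds
# `token_chunk_size` entries per prefill and one per decode, with the sampled
# token last, so cumsum(sizes) - 1 indexes it. The sizes are hypothetical.
# ---------------------------------------------------------------------------
def _demo_nospec_sampled_token_idxs(nospec_sizes: List[int]) -> List[int]:
    """Index of the last (sampled) token of each request in the flat output.

    >>> # Two prefills (chunk sizes 3 and 2) followed by one decode:
    >>> # flat layout [p0|p0|p0_out|p1|p1_out|d0_out].
    >>> _demo_nospec_sampled_token_idxs([3, 2, 1])
    [2, 4, 5]
    """
    return (torch.cumsum(torch.tensor(nospec_sizes), 0) - 1).tolist()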