from vllm.sequence import VLLM_INVALID_TOKEN_ID, ExecuteModelRequest
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScores)


class MLUBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
    """Speculative scorer that scores proposals via batch expansion.

    Batch expansion rewrites a batch of sequences, each carrying multiple
    query positions, into a larger batch where every sequence has exactly
    one query position. This emulates MQA-style scoring during speculative
    decoding without an MQA kernel, at strictly lower efficiency than true
    MQA scoring.

    Only top-1 proposals from the proposer are supported (no top-k / tree).
    """

    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score proposed tokens with the scorer model.

        Each input sequence is expanded into k+1 target sequences, one per
        unique continuation to score, each under a fresh sequence ID that
        does not collide with any input sequence ID. Sequences whose
        speculative length would overflow the max model length produce no
        speculation (their proposals are filtered out below).

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.

        Returns:
            SpeculativeScores: Per-token scores for the speculative tokens,
            plus bookkeeping for which sequences were skipped.
        """
        # TODO(cade) perform this on GPU to remove blocking call.
        lens_list = proposals.proposal_lens.tolist()
        raw_token_ids = proposals.proposal_token_ids.tolist()

        # Drop any proposal row containing the invalid-token sentinel;
        # those sequences get no speculation this step.
        valid_token_ids = [
            ids for ids in raw_token_ids if VLLM_INVALID_TOKEN_ID not in ids
        ]

        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
             proposal_token_ids_list=valid_token_ids,
             proposal_lens_list=lens_list,
         )

        # Run the scorer model on the expanded batch.
        expanded_req = execute_model_req.clone(
            seq_group_metadata_list=target_seq_group_metadata_list)
        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=expanded_req)
        assert len(target_sampler_output) == 1, "expected single-step output"
        target_sampler_output = target_sampler_output[0]

        if non_spec_indices:
            # Mixed batch: some sequence groups have spec decode disabled.
            (all_tokens, all_probs, spec_logprobs,
             all_hidden_states) = self._contract_batch(
                 execute_model_req.seq_group_metadata_list,
                 target_sampler_output=target_sampler_output,
                 proposals=proposals,
                 num_scoring_tokens=num_scoring_tokens,
                 non_spec_indices=non_spec_indices,
                 spec_indices=spec_indices,
                 k=execute_model_req.num_lookahead_slots,
             )
        else:
            # Every sequence group in the batch has spec decoding enabled.
            (all_tokens, all_probs, spec_logprobs,
             all_hidden_states) = self._contract_batch_all_spec(
                 target_sampler_output=target_sampler_output,
                 proposals=proposals,
             )

        return SpeculativeScores(
            probs=all_probs,
            token_ids=all_tokens,
            logprobs=spec_logprobs,
            hidden_states=all_hidden_states,
        )