# SPDX-License-Identifier: Apache-2.0

from typing import Set

import torch

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)

from .var import USE_FUSED_MTP_SAMPLER


class Top1Proposer(SpeculativeProposer):

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """
        proposal_len = execute_model_req.num_lookahead_slots
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # Split speculative and non-speculative sequences.
        (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            # If sampler_transposed is true, maybe_sampler_output's token_ids
            # is a proposal_len-sized list of [batch]-shaped entries;
            # otherwise it is a batch-sized list of [proposal_len]-shaped
            # entries.
            hidden_states = execute_model_req.previous_hidden_states
            if hidden_states is not None:
                hidden_states.prune(nonzero_proposal_len_seqs)
            nonzero_execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                num_lookahead_slots=proposal_len,
                previous_hidden_states=hidden_states,
            )

            # Idea: fuse the sampler-output and merge-output steps into a
            # single op. The _remove_no_proposal_seqs pass is unnecessary on
            # this path: every sequence produces draft output, so there is no
            # mix of proposal and no-proposal sequences. The fused sampler:
            #   1. selects the relevant indices,
            #   2. returns the tensors directly,
            #   3. transposes the tensors.
            if USE_FUSED_MTP_SAMPLER:
                sampler_outputs, token_indices, transposed = \
                    self._worker.sampler_output(
                        execute_model_req=nonzero_execute_model_req,
                        sample_len=proposal_len,
                        seq_ids_with_bonus_token_in_last_step=\
                            seq_ids_with_bonus_token_in_last_step,
                    )
                outputs = sampler_outputs[0].outputs
                token_probs = sampler_outputs[0].sampled_token_probs
                token_ids = sampler_outputs[0].sampled_token_ids
                bs = len(seq_group_metadata_list)
                # MTP proposes exactly one draft token per sequence.
                proposal_lens = torch.ones(bs,
                                           dtype=torch.int,
                                           device=self._worker.device)
                proposal_tokens = token_ids[token_indices]
                proposal_probs = token_probs[token_indices]
                s0, s1 = proposal_probs.shape
                proposal_probs = proposal_probs.view(s0, 1, s1)
                # Select by indices and build the proposals directly, without
                # going through _merge_outputs.
                proposals = SpeculativeProposals(
                    proposal_token_ids=proposal_tokens,
                    proposal_probs=proposal_probs,
                    proposal_lens=proposal_lens,
                    no_proposals=outputs is None)
                return proposals

            maybe_sampler_output, transposed = self._worker.sampler_output(
                execute_model_req=nonzero_execute_model_req,
                sample_len=proposal_len,
                seq_ids_with_bonus_token_in_last_step=\
                    seq_ids_with_bonus_token_in_last_step,
            )
            (
                proposal_lens,
                maybe_sampler_output,
                nonzero_proposal_len_indices,
            ) = self._remove_no_proposal_seqs(proposal_lens,
                                              maybe_sampler_output,
                                              nonzero_proposal_len_indices,
                                              transposed)
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None
            transposed = False

        # Combine speculative and non-speculative sequences into the same
        # representation.
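        # Judging from its use here, _merge_outputs scatters the draft
        # results back into full-batch tensors; sequences skipped during
        # speculation keep a proposal length of zero.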
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            proposal_len=proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
            sampler_transposed=transposed,
        )

        proposals = SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens,
            no_proposals=maybe_sampler_output is None)

        return proposals
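
# ---------------------------------------------------------------------------
# Minimal shape sketch for the fused-MTP fast path above. Illustrative only:
# the concrete tensor shapes (one draft token per sequence, probs of
# [N, vocab]) are assumptions inferred from this file, not a definitive
# description of the worker's output. Run with
# `python -m vllm.spec_decode.top1_proposer` to check the reshaping logic.
if __name__ == "__main__":
    vocab, batch = 8, 3
    # Pretend the worker returned 2 * batch candidate rows, of which
    # token_indices picks one draft token per sequence.
    token_ids = torch.randint(0, vocab, (2 * batch, 1))
    token_probs = torch.rand(2 * batch, vocab)
    token_indices = torch.arange(batch)

    proposal_tokens = token_ids[token_indices]   # [batch, 1]
    proposal_probs = token_probs[token_indices]  # [batch, vocab]
    s0, s1 = proposal_probs.shape
    # The .view(s0, 1, s1) adds the proposal_len dimension expected by
    # SpeculativeProposals: [batch, proposal_len=1, vocab].
    proposal_probs = proposal_probs.view(s0, 1, s1)
    assert proposal_tokens.shape == (batch, 1)
    assert proposal_probs.shape == (batch, 1, vocab)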