# NOTE(review): removed non-Python residue (repository web-viewer header:
# "Files", line/size counts, "Raw Permalink Normal View History", timestamp)
# that was captured by scraping and would be a syntax error in this module.
# SPDX-License-Identifier: Apache-2.0
from typing import Set
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
import torch
from .var import *
class Top1Proposer(SpeculativeProposer):
    """Proposer that obtains draft-token proposals from a draft worker.

    NOTE(review): only this method is visible here; the helpers it calls
    (``_split_by_proposal_len``, ``_remove_no_proposal_seqs``,
    ``_merge_outputs``) and the attributes it reads (``self._worker``) are
    presumably defined elsewhere in this class/file — confirm against the
    full source.
    """

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """
        proposal_len = execute_model_req.num_lookahead_slots
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # Split speculative- and non-speculative- sequences.
        (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            # If sampler_transposed is true, then maybe_sampler_output's
            # token_ids is like [batch] format in proposal_len size list,
            # while if it is false, the format would be [proposal_len]
            # in batch size list
            hidden_states = execute_model_req.previous_hidden_states
            if hidden_states is not None:
                # Keep only the hidden states of the sequences that will
                # actually be speculated on.
                hidden_states.prune(nonzero_proposal_len_seqs)

            nonzero_execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                num_lookahead_slots=proposal_len,
                previous_hidden_states=hidden_states,
            )

            # Idea: fuse the sampler-output and merge-output steps into a
            # single op.  The "remove no-proposal seqs" pass is skipped on
            # this path because (per the original author) every sequence
            # produces a draft output, so there is no mix of proposal /
            # no-proposal sequences.  The fused path:
            #   1. filters by indices,
            #   2. returns the tensors directly,
            #   3. transposes/reshapes the tensors as needed.
            if USE_FUSED_MTP_SAMPLER:
                # NOTE(review): `transposed` is unused on this path; the
                # fused sampler presumably returns tensors already in the
                # final layout — confirm against the worker implementation.
                sampler_outputs, token_indices, transposed = self._worker.sampler_output(
                    execute_model_req=nonzero_execute_model_req,
                    sample_len=proposal_len,
                    seq_ids_with_bonus_token_in_last_step=\
                        seq_ids_with_bonus_token_in_last_step,
                )
                outputs = sampler_outputs[0].outputs
                token_probs = sampler_outputs[0].sampled_token_probs
                token_ids = sampler_outputs[0].sampled_token_ids
                bs = len(seq_group_metadata_list)
                # NOTE(review): proposal length is hard-coded to 1 per
                # sequence here, regardless of `proposal_len` — consistent
                # with a top-1 MTP draft, but verify when sample_len > 1.
                proposal_lens = torch.ones((bs), dtype=torch.int, device=self._worker.device)
                proposal_tokens = token_ids[token_indices]
                proposal_probs = token_probs[token_indices]
                # Reshape probs from (batch, vocab) to (batch, 1, vocab) so a
                # single proposed step is represented explicitly.
                s0, s1 = proposal_probs.shape
                proposal_probs = proposal_probs.view(s0, 1, s1)
                # Filter by `token_indices` to build the new sampler output.
                proposals = SpeculativeProposals(proposal_token_ids=proposal_tokens,
                                                 proposal_probs=proposal_probs,
                                                 proposal_lens=proposal_lens,
                                                 no_proposals=outputs
                                                 is None)
                return proposals

            maybe_sampler_output, transposed = self._worker.sampler_output(
                execute_model_req=nonzero_execute_model_req,
                sample_len=proposal_len,
                seq_ids_with_bonus_token_in_last_step=\
                    seq_ids_with_bonus_token_in_last_step,
            )
            (
                proposal_lens,
                maybe_sampler_output,
                nonzero_proposal_len_indices,
            ) = self._remove_no_proposal_seqs(proposal_lens,
                                              maybe_sampler_output,
                                              nonzero_proposal_len_indices,
                                              transposed)
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None
            transposed = False

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            proposal_len=proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
            sampler_transposed=transposed,
        )

        proposals = SpeculativeProposals(proposal_token_ids=proposal_tokens,
                                         proposal_probs=proposal_probs,
                                         proposal_lens=proposal_lens,
                                         no_proposals=maybe_sampler_output
                                         is None)
        return proposals