Files
2026-04-02 04:55:00 +00:00

118 lines
5.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# SPDX-License-Identifier: Apache-2.0
from typing import Set
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
import torch
from .var import *
class Top1Proposer(SpeculativeProposer):
    # NOTE(review): only get_spec_proposals is visible in this chunk. The
    # helpers it calls (_split_by_proposal_len, _remove_no_proposal_seqs,
    # _merge_outputs) and self._worker are defined elsewhere in the class.

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """
        proposal_len = execute_model_req.num_lookahead_slots
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # Split speculative- and non-speculative- sequences.
        (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            # If sampler_transposed is true, then maybe_sampler_output's
            # token_ids is like [batch] format in proposal_len size list,
            # while if it is false, the format would be [proposal_len]
            # in batch size list
            hidden_states = execute_model_req.previous_hidden_states
            if hidden_states is not None:
                # In-place: drop hidden states of sequences that are not
                # being speculated, before building the draft-worker request.
                hidden_states.prune(nonzero_proposal_len_seqs)

            nonzero_execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                num_lookahead_slots=proposal_len,
                previous_hidden_states=hidden_states,
            )

            # Idea: fuse the sampler-output step and the merge-output step
            # into a single op. (Translated from the original Chinese note.)
            # The _remove_no_proposal_seqs pass is skipped on this path: per
            # the original note, every output here is a draft output, so
            # speculative and non-speculative sequences never mix.
            # Fused sampler responsibilities:
            #   1. filter the indices,
            #   2. return the tensors,
            #   3. transpose the tensors.
            if USE_FUSED_MTP_SAMPLER:
                # USE_FUSED_MTP_SAMPLER is provided by the `.var` wildcard
                # import at the top of the file.
                sampler_outputs, token_indices, transposed = self._worker.sampler_output(
                    execute_model_req=nonzero_execute_model_req,
                    sample_len=proposal_len,
                    seq_ids_with_bonus_token_in_last_step=\
                        seq_ids_with_bonus_token_in_last_step,
                )
                # Only the first sampler step is consumed — presumably the
                # fused MTP path produces a single draft step; TODO confirm.
                outputs = sampler_outputs[0].outputs
                token_probs = sampler_outputs[0].sampled_token_probs
                token_ids = sampler_outputs[0].sampled_token_ids
                bs = len(seq_group_metadata_list)
                # Every sequence in the batch is assigned proposal length 1
                # here, including any that _split_by_proposal_len filtered
                # out. NOTE(review): verify this is intended for skipped
                # sequences; the unfused path merges them explicitly instead.
                proposal_lens = torch.ones((bs), dtype=torch.int, device=self._worker.device)
                # Select only the rows the fused sampler flagged as valid.
                proposal_tokens = token_ids[token_indices]
                proposal_probs = token_probs[token_indices]
                # Insert a proposal-len axis of size 1, i.e.
                # [bs, vocab] -> [bs, 1, vocab] (shape inferred from the
                # 2-D unpack above — TODO confirm axis meaning).
                s0, s1 = proposal_probs.shape
                proposal_probs = proposal_probs.view(s0, 1, s1)
                # Build the proposals directly from the filtered indices
                # (translated: "filter indices and build a new SamplerOut").
                proposals = SpeculativeProposals(proposal_token_ids=proposal_tokens,
                                                 proposal_probs=proposal_probs,
                                                 proposal_lens=proposal_lens,
                                                 no_proposals=outputs
                                                 is None)
                return proposals

            maybe_sampler_output, transposed = self._worker.sampler_output(
                execute_model_req=nonzero_execute_model_req,
                sample_len=proposal_len,
                seq_ids_with_bonus_token_in_last_step=\
                    seq_ids_with_bonus_token_in_last_step,
            )
            (
                proposal_lens,
                maybe_sampler_output,
                nonzero_proposal_len_indices,
            ) = self._remove_no_proposal_seqs(proposal_lens,
                                              maybe_sampler_output,
                                              nonzero_proposal_len_indices,
                                              transposed)
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None
            transposed = False

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            proposal_len=proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
            sampler_transposed=transposed,
        )

        proposals = SpeculativeProposals(proposal_token_ids=proposal_tokens,
                                         proposal_probs=proposal_probs,
                                         proposal_lens=proposal_lens,
                                         no_proposals=maybe_sampler_output
                                         is None)
        return proposals