# SPDX-License-Identifier: Apache-2.0
|
||
|
||
from typing import Set
|
||
from vllm.sequence import ExecuteModelRequest
|
||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||
SpeculativeProposer)
|
||
import torch
|
||
|
||
from .var import *
|
||
|
||
class Top1Proposer(SpeculativeProposer):
    """Proposer that asks the draft worker for top-1 speculative tokens.

    Only the ``get_spec_proposals`` entry point is visible in this chunk;
    helper methods (``_split_by_proposal_len``, ``_remove_no_proposal_seqs``,
    ``_merge_outputs``) and the ``_worker`` attribute are defined elsewhere.
    """

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """
        # Number of lookahead slots doubles as the per-step proposal length.
        proposal_len = execute_model_req.num_lookahead_slots
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # Split speculative- and non-speculative- sequences.
        (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            # If sampler_transposed is true, then maybe_sampler_output's
            # token_ids is like [batch] format in proposal_len size list,
            # while if it is false, the format would be [proposal_len]
            # in batch size list
            hidden_states = execute_model_req.previous_hidden_states
            if hidden_states is not None:
                # Keep only the hidden states belonging to sequences that
                # will actually be speculated.
                hidden_states.prune(nonzero_proposal_len_seqs)
            nonzero_execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                num_lookahead_slots=proposal_len,
                previous_hidden_states=hidden_states,
            )

            # Idea: fuse the sampler-output step and the merge-output step
            # into a single op.
            # The remove-no-proposal-seqs pass is unnecessary on this path:
            # every output is a draft output, so speculative and
            # non-speculative results are never mixed.
            # Fused op responsibilities:
            #   1. filter by indices
            #   2. return tensors directly
            #   3. transpose the tensor
            if USE_FUSED_MTP_SAMPLER:
                # Fast path: the worker returns tensors plus the indices of
                # the rows to keep, so the generic merge below is skipped.
                # NOTE(review): the returned `transposed` flag is unused on
                # this path — presumably the fused output is already in the
                # layout SpeculativeProposals expects; confirm with the
                # worker implementation.
                sampler_outputs, token_indices, transposed = self._worker.sampler_output(
                    execute_model_req=nonzero_execute_model_req,
                    sample_len=proposal_len,
                    seq_ids_with_bonus_token_in_last_step=\
                        seq_ids_with_bonus_token_in_last_step,
                )

                outputs = sampler_outputs[0].outputs
                token_probs = sampler_outputs[0].sampled_token_probs
                token_ids = sampler_outputs[0].sampled_token_ids

                bs = len(seq_group_metadata_list)
                # Every sequence in the batch gets a proposal length of 1
                # on this path (top-1 / single-step MTP drafting).
                proposal_lens = torch.ones((bs), dtype=torch.int, device=self._worker.device)

                # Gather only the rows selected by the fused sampler.
                # assumes token_ids/token_probs are indexed by batch row in
                # dim 0 — TODO confirm against the fused sampler's output.
                proposal_tokens = token_ids[token_indices]
                proposal_probs = token_probs[token_indices]

                # Insert a singleton proposal-step dim:
                # [batch, vocab] -> [batch, 1, vocab].
                s0, s1 = proposal_probs.shape
                proposal_probs = proposal_probs.view(s0, 1, s1)

                # Filter by indices and build the resulting proposals.
                proposals = SpeculativeProposals(proposal_token_ids=proposal_tokens,
                                                 proposal_probs=proposal_probs,
                                                 proposal_lens=proposal_lens,
                                                 no_proposals=outputs
                                                 is None)
                return proposals

            # Generic (non-fused) path: sampler output may be None per
            # sequence and must be reconciled with the non-speculative ones.
            maybe_sampler_output, transposed = self._worker.sampler_output(
                execute_model_req=nonzero_execute_model_req,
                sample_len=proposal_len,
                seq_ids_with_bonus_token_in_last_step=\
                    seq_ids_with_bonus_token_in_last_step,
            )
            # Drop sequences for which the worker produced no proposal.
            (
                proposal_lens,
                maybe_sampler_output,
                nonzero_proposal_len_indices,
            ) = self._remove_no_proposal_seqs(proposal_lens,
                                              maybe_sampler_output,
                                              nonzero_proposal_len_indices,
                                              transposed)
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None
            transposed = False

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            proposal_len=proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
            sampler_transposed=transposed,
        )

        proposals = SpeculativeProposals(proposal_token_ids=proposal_tokens,
                                         proposal_probs=proposal_probs,
                                         proposal_lens=proposal_lens,
                                         no_proposals=maybe_sampler_output
                                         is None)
        return proposals