init
This commit is contained in:
118
vllm_vacc/vllm/spec_decode/top1_proposer.py
Normal file
118
vllm_vacc/vllm/spec_decode/top1_proposer.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Set
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
import torch
|
||||
|
||||
from .var import *
|
||||
|
||||
class Top1Proposer(SpeculativeProposer):
    """Proposer that asks the draft worker for speculative tokens.

    Helper methods used here (``_split_by_proposal_len``,
    ``_remove_no_proposal_seqs``, ``_merge_outputs``) and the ``_worker``
    attribute are not defined in this class body; they are expected to be
    provided elsewhere.
    """

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        The batch is first split by ``_split_by_proposal_len`` into
        sequences that will be speculated and sequences that will not
        (per the original documentation, sequences which would exceed the
        max model length are skipped during speculation).

        Args:
            execute_model_req: The request carrying the sequence group
                metadata, the number of lookahead slots and (optionally)
                the previous hidden states.
            seq_ids_with_bonus_token_in_last_step: Forwarded unchanged to
                the draft worker's ``sampler_output``.

        Returns:
            A ``SpeculativeProposals`` holding the proposal token ids,
            proposal probabilities and per-sequence proposal lengths.
        """
        lookahead = execute_model_req.num_lookahead_slots
        all_seq_groups = execute_model_req.seq_group_metadata_list

        # Split speculative- and non-speculative- sequences.
        split_result = self._split_by_proposal_len(all_seq_groups, lookahead)
        proposal_lens, speculated_seqs, speculated_indices = split_result

        if not speculated_seqs:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None
            transposed = False
        else:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            # If sampler_transposed is true, then maybe_sampler_output's
            # token_ids is like [batch] format in proposal_len size list,
            # while if it is false, the format would be [proposal_len]
            # in batch size list
            prev_hidden = execute_model_req.previous_hidden_states
            if prev_hidden is not None:
                # NOTE: prune() is called on the object owned by the incoming
                # request, so the request's hidden states are affected too.
                prev_hidden.prune(speculated_seqs)

            draft_request = ExecuteModelRequest(
                seq_group_metadata_list=speculated_seqs,
                num_lookahead_slots=lookahead,
                previous_hidden_states=prev_hidden,
            )
            # Both sampler entry points below take the same arguments.
            sampler_kwargs = {
                "execute_model_req": draft_request,
                "sample_len": lookahead,
                "seq_ids_with_bonus_token_in_last_step":
                seq_ids_with_bonus_token_in_last_step,
            }

            # Idea: fuse the sampler output and the merge-output step into a
            # single op. The "remove no-proposal seqs" step is unnecessary
            # here: everything is draft output, there is no mixing.
            #   Function 1: select by indices
            #   Function 2: return tensors
            #   Function 3: transpose the tensors
            if USE_FUSED_MTP_SAMPLER:
                # In fused mode the worker returns a 3-tuple; the third
                # element (the transposed flag) is not used on this path.
                (fused_outputs, keep_indices,
                 _transposed) = self._worker.sampler_output(**sampler_kwargs)

                raw_outputs = fused_outputs[0].outputs
                step_probs = fused_outputs[0].sampled_token_probs
                step_token_ids = fused_outputs[0].sampled_token_ids

                # NOTE(review): every sequence in the *full* batch is given a
                # proposal length of 1, regardless of `lookahead` and of the
                # split computed above -- presumably this path is only used
                # with a single lookahead slot; confirm against callers.
                batch_size = len(all_seq_groups)
                proposal_lens = torch.ones(batch_size,
                                           dtype=torch.int,
                                           device=self._worker.device)

                # Select by indices to build the proposal tensors.
                picked_tokens = step_token_ids[keep_indices]
                picked_probs = step_probs[keep_indices]

                # The selected probabilities must be 2-D here; insert a
                # singleton middle axis.
                n_rows, n_cols = picked_probs.shape
                picked_probs = picked_probs.view(n_rows, 1, n_cols)

                return SpeculativeProposals(
                    proposal_token_ids=picked_tokens,
                    proposal_probs=picked_probs,
                    proposal_lens=proposal_lens,
                    no_proposals=raw_outputs is None,
                )

            maybe_sampler_output, transposed = self._worker.sampler_output(
                **sampler_kwargs)
            (
                proposal_lens,
                maybe_sampler_output,
                speculated_indices,
            ) = self._remove_no_proposal_seqs(proposal_lens,
                                              maybe_sampler_output,
                                              speculated_indices,
                                              transposed)

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        merged_tokens, merged_probs, proposal_lens = self._merge_outputs(
            batch_size=len(all_seq_groups),
            proposal_len=lookahead,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=speculated_indices,
            sampler_transposed=transposed,
        )

        return SpeculativeProposals(
            proposal_token_ids=merged_tokens,
            proposal_probs=merged_probs,
            proposal_lens=proposal_lens,
            no_proposals=maybe_sampler_output is None,
        )
|
||||
Reference in New Issue
Block a user