add qwen3
This commit is contained in:
92
vllm-v0.6.2/vllm/spec_decode/mlu_batch_expansion.py
Normal file
92
vllm-v0.6.2/vllm/spec_decode/mlu_batch_expansion.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from vllm.sequence import VLLM_INVALID_TOKEN_ID, ExecuteModelRequest
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScores)
|
||||
|
||||
|
||||
class MLUBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
|
||||
"""Implements a speculative scorer that uses batch expansion to get
|
||||
probabilities of speculative tokens according to the scoring model.
|
||||
|
||||
Batch expansion converts a list of sequences and multiple query positions
|
||||
to a new batch of sequences, each with a single query position. This allows
|
||||
for MQA-like scoring in speculative decoding without requiring an MQA
|
||||
kernel.
|
||||
|
||||
It is strictly less efficient than MQA scoring.
|
||||
|
||||
It only supports scoring the top1 proposal tokens of the proposer, instead
|
||||
of topk/tree.
|
||||
"""
|
||||
|
||||
def score_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
proposals: SpeculativeProposals,
|
||||
) -> SpeculativeScores:
|
||||
"""Score the proposed tokens via the scorer model.
|
||||
|
||||
This converts each input sequence to a set of k+1 target sequences. The
|
||||
target sequences have the unique continuations to be scored and a
|
||||
unique sequence ID that is different from all input sequence ids.
|
||||
|
||||
If a speculative sequence length would exceed the max model length, then
|
||||
no speculation is produced for that sequence.
|
||||
|
||||
Args:
|
||||
execute_model_req: The execution request.
|
||||
proposals: The speculative proposals to score.
|
||||
Returns:
|
||||
SpeculativeScores: The scores of each speculative token, along with
|
||||
which sequences were ignored during scoring.
|
||||
"""
|
||||
|
||||
# TODO(cade) perform this on GPU to remove blocking call.
|
||||
proposal_lens_list = proposals.proposal_lens.tolist()
|
||||
proposal_token_ids_list = proposals.proposal_token_ids.tolist()
|
||||
|
||||
# Filter the list to ignore invalid proposals.
|
||||
proposal_token_ids_list_without_skips = [
|
||||
proposals for proposals in proposal_token_ids_list
|
||||
if VLLM_INVALID_TOKEN_ID not in proposals
|
||||
]
|
||||
|
||||
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
|
||||
num_scoring_tokens) = self._expand_batch(
|
||||
seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
|
||||
proposal_token_ids_list=proposal_token_ids_list_without_skips,
|
||||
proposal_lens_list=proposal_lens_list,
|
||||
)
|
||||
|
||||
target_sampler_output = self._scorer_worker.execute_model(
|
||||
execute_model_req=execute_model_req.clone(
|
||||
seq_group_metadata_list=target_seq_group_metadata_list))
|
||||
assert len(target_sampler_output) == 1, "expected single-step output"
|
||||
target_sampler_output = target_sampler_output[0]
|
||||
|
||||
if not non_spec_indices:
|
||||
# All sequence groups in batch have spec decoding enabled
|
||||
contracted = self._contract_batch_all_spec(
|
||||
target_sampler_output=target_sampler_output,
|
||||
proposals=proposals,
|
||||
)
|
||||
else:
|
||||
# Batch has a mix of spec decode enabled and disabled seq groups
|
||||
contracted = self._contract_batch(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
target_sampler_output=target_sampler_output,
|
||||
proposals=proposals,
|
||||
num_scoring_tokens=num_scoring_tokens,
|
||||
non_spec_indices=non_spec_indices,
|
||||
spec_indices=spec_indices,
|
||||
k=execute_model_req.num_lookahead_slots,
|
||||
)
|
||||
|
||||
all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
|
||||
return SpeculativeScores(
|
||||
probs=all_probs,
|
||||
token_ids=all_tokens,
|
||||
logprobs=spec_logprobs,
|
||||
hidden_states=all_hidden_states,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user