# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.sequence import (ExecuteModelRequest, SequenceData,
                           SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)

SeqId = int
TargetSeqId = int


class MQAScorer(SpeculativeScorer):
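    """Scorer that evaluates all speculative proposals in a batch with a
    single forward pass of the target model.

    For every sequence with proposed tokens, the proposals are appended to a
    copy of the sequence and the computed-token count is set so that the last
    generated token plus the proposal tokens are run as query positions. One
    call to the scorer worker then yields target probabilities for the k + 1
    positions of every sequence.
    """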

    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        target_seq_group_metadata_list = []
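        # Fresh ids for the scored copies start above every id already in
        # the request, so they cannot collide with the original sequences.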
        target_seq_id_start = max(
            get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
        all_proposal_tokens = proposals.proposal_token_ids.tolist()
        all_proposal_lengths = proposals.proposal_lens.tolist()
        for i, seq_group_metadata in enumerate(
                execute_model_req.seq_group_metadata_list):
            if all_proposal_lengths[i] == 0:
                # Keep prompt seqs untouched (keep computed_tokens for chunks).
                target_seq_group_metadata_list.append(seq_group_metadata)
                continue

            seq_data_dict = seq_group_metadata.seq_data
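            # A proposal is scored against exactly one running sequence per
            # group; the assert below guards that assumption.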
            assert len(seq_data_dict) == 1
            seq_id = next(iter(seq_data_dict.keys()))

            seq_data: SequenceData = seq_data_dict[seq_id]
            prompt_token_ids = seq_data.get_prompt_token_ids()
            output_token_ids = seq_data.get_output_token_ids()
            proposal_token_ids = all_proposal_tokens[
                i][:all_proposal_lengths[i]]
            new_output_token_ids = [*output_token_ids, *proposal_token_ids]

            target_seq_id = target_seq_id_start + i
            new_seq_data = SequenceData.from_seqs(
                prompt_token_ids=prompt_token_ids,
                output_token_ids=new_output_token_ids,
            )
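            # Marking everything up to (but excluding) the last generated
            # token as computed makes that token plus the appended proposals
            # the k + 1 query positions of the target pass. E.g. a 5-token
            # prompt with 3 generated tokens and 2 proposals gives a 10-token
            # copy with 7 tokens marked computed, so 3 positions are scored.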
            new_seq_data.update_num_computed_tokens(
                len(prompt_token_ids) + len(output_token_ids) - 1)

            # Ensure that the new decode sequence has at least one token.
            assert len(output_token_ids) >= 1
            new_seq_data_dict = {target_seq_id: new_seq_data}

            new_seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group_metadata.request_id,
                is_prompt=seq_group_metadata.is_prompt,
                seq_data=new_seq_data_dict,
                sampling_params=seq_group_metadata.sampling_params,
                block_tables={
                    target_seq_id: seq_group_metadata.block_tables[seq_id],
                },
                lora_request=None,
            )
            target_seq_group_metadata_list.append(new_seq_group_metadata)

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))

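        # execute_model returns a list of SamplerOutput; only one step was
        # run here, so just the first entry is relevant.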
        target_sampler_output = target_sampler_output[0]

        k = execute_model_req.num_lookahead_slots
        bs = len(execute_model_req.seq_group_metadata_list)
        target_token_ids = target_sampler_output.sampled_token_ids
        target_probs = target_sampler_output.sampled_token_probs
        target_logprobs = target_sampler_output.logprobs
        prompt_logprobs = None

        # If all requests have the same number of query tokens, we can avoid
        # the for loop to build output for better performance.
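        # E.g. with bs=4 sequences and k=3 lookahead slots the sampler emits
        # 4 * (3 + 1) = 16 rows, which reshape directly into (4, 4) token ids
        # and (4, 4, vocab_size) probs/logprobs.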
        if min(all_proposal_lengths) == k:
            # Regular decodes only.
            assert all(not sg.is_prompt
                       for sg in target_seq_group_metadata_list)
            bs, _ = proposals.proposal_token_ids.shape
            all_tokens = target_token_ids.reshape(bs, k + 1)
            all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
            all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
        else:
            # We either have decodes with different lens or prefill+decodes.
            all_tokens = target_token_ids.new_full(size=(bs, k + 1),
                                                   fill_value=-1)
            all_probs = target_probs.new_zeros(*all_tokens.shape,
                                               self._vocab_size)
            all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                    fill_value=-float("inf"))
            target_token_ids = target_token_ids.flatten()
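
            # Rows are filled per sequence below; positions past a sequence's
            # proposal length keep the padding values set above (-1 token id,
            # zero prob, -inf logprob).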

            # When prompt logprobs are enabled, the lengths of the returned
            # tensors grow from n_sampled (requests with do_sample=True) to
            # n_prompt + n_prefills. We adjust the stride accordingly to get
            # the generated tokens and their probs, but pass prompt_logprobs
            # on as is, since it may be that n_prompts >> K.
            has_prompt_log = any((sg.sampling_params.prompt_logprobs
                                  and sg.sampling_params.prompt_logprobs > 0)
                                 for sg in target_seq_group_metadata_list)
            # TODO (NickLucche) we should surface `disable_logprobs` so as
            # not to break the abstraction to get its value.
            if (not self._scorer_worker.model_runner.disable_logprobs
                    and has_prompt_log):
                prompt_logprobs = [
                    o.prompt_logprobs for o in target_sampler_output.outputs
                ]

            # Split loop into prefill|decode for readability.
            start_loc, i = 0, 0
            while i < len(target_seq_group_metadata_list
                          ) and target_seq_group_metadata_list[i].is_prompt:
                seq_meta = target_seq_group_metadata_list[i]
                end_loc = start_loc
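                # With prompt logprobs on, the sampler emits one row per
                # prompt token in this chunk; otherwise only the sampled row
                # (if any), so the stride over the flat output differs.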
                if has_prompt_log:
                    end_loc += seq_meta.token_chunk_size
                elif seq_meta.do_sample:
                    end_loc += 1

                # Skip chunks with no output tokens.
                if seq_meta.do_sample:
                    # Get sampled token (last position in chunk) and its prob.
                    all_tokens[i, 0] = target_token_ids[end_loc - 1]
                    all_probs[i, 0] = target_probs[end_loc - 1]
                    all_logprobs[i, 0] = target_logprobs[end_loc - 1]

                i += 1
                start_loc = end_loc
            # Decodes.
            while i < len(target_seq_group_metadata_list):
                proposed_len, seq_meta = all_proposal_lengths[
                    i], target_seq_group_metadata_list[i]
                output_len = proposed_len + 1
                end_loc = start_loc + output_len
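                # Each decode contributes output_len rows: one per proposal
                # token plus the bonus position. E.g. a 2-token proposal
                # fills the first 3 columns of row i.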
                all_tokens[
                    i, :output_len] = target_token_ids[start_loc:end_loc]
                all_probs[i, :output_len] = target_probs[start_loc:end_loc]
                all_logprobs[
                    i, :output_len] = target_logprobs[start_loc:end_loc]
                start_loc = end_loc
                i += 1

        hidden_states = None
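        # If the scorer worker also returned hidden states, expose them in
        # the same (bs, k + 1, hidden_size) layout as the tensors above.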
        if target_sampler_output.hidden_states is not None:
            hidden_states = target_sampler_output.hidden_states.reshape(
                bs, (k + 1), -1)

        return SpeculativeScores(probs=all_probs,
                                 token_ids=all_tokens,
                                 logprobs=all_logprobs,
                                 hidden_states=hidden_states,
                                 prompt_logprobs=prompt_logprobs)