init src 0.9.2
This commit is contained in:
165
vllm/spec_decode/mlp_speculator_worker.py
Normal file
165
vllm/spec_decode/mlp_speculator_worker.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import List, Optional, Set, Tuple, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
|
||||
from vllm.distributed import broadcast_tensor_dict
|
||||
|
||||
|
||||
class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
|
||||
"""Worker for MLPSpeculator models.
|
||||
|
||||
Not currently compatible with LoRA or chunked prefill.
|
||||
"""
|
||||
|
||||
def _get_driver_input_and_broadcast(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
index: int,
|
||||
last_tokens: Optional[torch.Tensor]=None,
|
||||
previous_hidden_states: Optional[torch.Tensor]=None,
|
||||
sampling_metadata: Optional[SamplingMetadata]=None
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
|
||||
if sampling_metadata is None and execute_model_req is not None:
|
||||
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
|
||||
|
||||
(input_tokens, seq_lens,
|
||||
query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
|
||||
|
||||
# b x 1
|
||||
last_tokens = input_tokens.unsqueeze(1)
|
||||
|
||||
generators = self.model_runner.get_generators(
|
||||
execute_model_req.finished_requests_ids)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list, seq_lens, query_lens, self.device,
|
||||
self.model_runner.pin_memory, generators)
|
||||
|
||||
previous_hidden_states = execute_model_req.previous_hidden_states.hidden_states
|
||||
|
||||
# b x 1 x d
|
||||
previous_hidden_states = previous_hidden_states.unsqueeze(1)
|
||||
|
||||
tensor_dict = {
|
||||
"input_tokens": last_tokens,
|
||||
"previous_hidden_states": previous_hidden_states,
|
||||
"sample_len": sample_len,
|
||||
"head_index": index
|
||||
}
|
||||
|
||||
if self.do_metadata_broadcast:
|
||||
broadcast_tensor_dict(tensor_dict, src=0)
|
||||
|
||||
return tensor_dict, sampling_metadata
|
||||
|
||||
def _get_worker_input_from_broadcast(
|
||||
self
|
||||
) -> Optional[Dict[str, torch.Tensor]]:
|
||||
""" Get the worker input from the broadcasted tensor dict. """
|
||||
assert self.do_metadata_broadcast
|
||||
assert not self.is_driver_worker
|
||||
broadcast_data = broadcast_tensor_dict(src=0)
|
||||
|
||||
return broadcast_data
|
||||
|
||||
@torch.inference_mode()
|
||||
def sampler_output(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
# Unused parameter. MLPSpeculatorWorker does not use the KV Cache and
|
||||
# therefore does not need this parameter.
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
"""Run the model forward pass to generate sample_len future tokens.
|
||||
Returns the list of sampler output, one per layer, along with indicator
|
||||
of whether torch tensor in sampler output need to be transposed in
|
||||
latter sampler_output_to_torch logic.
|
||||
|
||||
For mlp spec worker, this indicator shall be True.
|
||||
"""
|
||||
self._raise_if_unsupported(execute_model_req)
|
||||
|
||||
model_outputs = []
|
||||
last_tokens = None
|
||||
previous_hidden_states = None
|
||||
sampling_metadata = None
|
||||
|
||||
for index in range(sample_len):
|
||||
if self.is_driver_worker:
|
||||
tensor_dict, sampling_metadata = self._get_driver_input_and_broadcast(execute_model_req,
|
||||
sample_len,
|
||||
index,
|
||||
last_tokens,
|
||||
previous_hidden_states,
|
||||
sampling_metadata)
|
||||
assert sampling_metadata is not None
|
||||
|
||||
output, previous_hidden_states = self.model_runner.model.generate_proposals(
|
||||
input_ids=tensor_dict["input_tokens"],
|
||||
previous_hidden_states=tensor_dict["previous_hidden_states"],
|
||||
num_predict_tokens=tensor_dict["sample_len"],
|
||||
sampling_metadata=sampling_metadata,
|
||||
head_index=index)
|
||||
last_tokens = output.sampled_token_ids
|
||||
|
||||
model_outputs.append(output)
|
||||
else:
|
||||
tensor_dict = self._get_worker_input_from_broadcast()
|
||||
if tensor_dict is None:
|
||||
raise ValueError("Can not get inputs of mlp_speculator worker!!!")
|
||||
self.model_runner.model.generate_proposals(
|
||||
input_ids=tensor_dict["input_tokens"],
|
||||
previous_hidden_states=tensor_dict["previous_hidden_states"],
|
||||
num_predict_tokens=tensor_dict["sample_len"],
|
||||
sampling_metadata=None,
|
||||
head_index=tensor_dict["head_index"])
|
||||
|
||||
if self.is_driver_worker:
|
||||
assert len(model_outputs) == sample_len
|
||||
|
||||
return model_outputs, True
|
||||
|
||||
def _prepare_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
) -> Tuple[torch.Tensor, List[int], List[int]]:
|
||||
if not seq_group_metadata_list:
|
||||
return torch.empty(0, device=self.device), [], []
|
||||
|
||||
input_tokens: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
query_lens: List[int] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
is_prompt = seq_group_metadata.is_prompt
|
||||
|
||||
for seq_data in seq_group_metadata.seq_data.values():
|
||||
seq_data_len = seq_data.get_len()
|
||||
if is_prompt:
|
||||
context_len = seq_data.get_num_computed_tokens()
|
||||
seq_len = min(
|
||||
seq_data_len,
|
||||
context_len + seq_group_metadata.token_chunk_size)
|
||||
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
||||
seq_lens.append(seq_len)
|
||||
input_tokens.extend(tokens)
|
||||
query_lens.append(seq_len - context_len)
|
||||
else:
|
||||
seq_lens.append(seq_data_len)
|
||||
input_tokens.append(seq_data.get_last_token_id())
|
||||
query_lens.append(1)
|
||||
|
||||
input_tokens_tensor = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
return input_tokens_tensor, seq_lens, query_lens
|
||||
Reference in New Issue
Block a user