# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import List, Optional, Set, Tuple, Dict import torch from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.distributed import broadcast_tensor_dict class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker): """Worker for MLPSpeculator models. Not currently compatible with LoRA or chunked prefill. """ def _get_driver_input_and_broadcast( self, execute_model_req: ExecuteModelRequest, sample_len: int, index: int, last_tokens: Optional[torch.Tensor]=None, previous_hidden_states: Optional[torch.Tensor]=None, sampling_metadata: Optional[SamplingMetadata]=None ) -> Dict[str, torch.Tensor]: if sampling_metadata is None and execute_model_req is not None: seq_group_metadata_list = execute_model_req.seq_group_metadata_list (input_tokens, seq_lens, query_lens) = self._prepare_input_tensors(seq_group_metadata_list) # b x 1 last_tokens = input_tokens.unsqueeze(1) generators = self.model_runner.get_generators( execute_model_req.finished_requests_ids) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, self.model_runner.pin_memory, generators) previous_hidden_states = execute_model_req.previous_hidden_states.hidden_states # b x 1 x d previous_hidden_states = previous_hidden_states.unsqueeze(1) tensor_dict = { "input_tokens": last_tokens, "previous_hidden_states": previous_hidden_states, "sample_len": sample_len, "head_index": index } if self.do_metadata_broadcast: broadcast_tensor_dict(tensor_dict, src=0) return tensor_dict, sampling_metadata def _get_worker_input_from_broadcast( self ) -> Optional[Dict[str, torch.Tensor]]: """ Get the worker input from the broadcasted tensor dict. """ assert self.do_metadata_broadcast assert not self.is_driver_worker broadcast_data = broadcast_tensor_dict(src=0) return broadcast_data @torch.inference_mode() def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, # Unused parameter. MLPSpeculatorWorker does not use the KV Cache and # therefore does not need this parameter. seq_ids_with_bonus_token_in_last_step: Set[int], ) -> Tuple[List[SamplerOutput], bool]: """Run the model forward pass to generate sample_len future tokens. Returns the list of sampler output, one per layer, along with indicator of whether torch tensor in sampler output need to be transposed in latter sampler_output_to_torch logic. For mlp spec worker, this indicator shall be True. """ self._raise_if_unsupported(execute_model_req) model_outputs = [] last_tokens = None previous_hidden_states = None sampling_metadata = None for index in range(sample_len): if self.is_driver_worker: tensor_dict, sampling_metadata = self._get_driver_input_and_broadcast(execute_model_req, sample_len, index, last_tokens, previous_hidden_states, sampling_metadata) assert sampling_metadata is not None output, previous_hidden_states = self.model_runner.model.generate_proposals( input_ids=tensor_dict["input_tokens"], previous_hidden_states=tensor_dict["previous_hidden_states"], num_predict_tokens=tensor_dict["sample_len"], sampling_metadata=sampling_metadata, head_index=index) last_tokens = output.sampled_token_ids model_outputs.append(output) else: tensor_dict = self._get_worker_input_from_broadcast() if tensor_dict is None: raise ValueError("Can not get inputs of mlp_speculator worker!!!") self.model_runner.model.generate_proposals( input_ids=tensor_dict["input_tokens"], previous_hidden_states=tensor_dict["previous_hidden_states"], num_predict_tokens=tensor_dict["sample_len"], sampling_metadata=None, head_index=tensor_dict["head_index"]) if self.is_driver_worker: assert len(model_outputs) == sample_len return model_outputs, True def _prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, List[int], List[int]]: if not seq_group_metadata_list: return torch.empty(0, device=self.device), [], [] input_tokens: List[int] = [] seq_lens: List[int] = [] query_lens: List[int] = [] for seq_group_metadata in seq_group_metadata_list: is_prompt = seq_group_metadata.is_prompt for seq_data in seq_group_metadata.seq_data.values(): seq_data_len = seq_data.get_len() if is_prompt: context_len = seq_data.get_num_computed_tokens() seq_len = min( seq_data_len, context_len + seq_group_metadata.token_chunk_size) tokens = seq_data.get_token_ids()[context_len:seq_len] seq_lens.append(seq_len) input_tokens.extend(tokens) query_lens.append(seq_len - context_len) else: seq_lens.append(seq_data_len) input_tokens.append(seq_data.get_last_token_id()) query_lens.append(1) input_tokens_tensor = torch.tensor(input_tokens, dtype=torch.long, device=self.device) return input_tokens_tensor, seq_lens, query_lens