# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os import weakref from typing import List, Optional, Set, Tuple, Dict import torch import torch.nn.functional as F from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.worker.worker_base import DelegateWorkerBase from vllm.spec_decode.tree_style_proposer import TreeStyleProposer from vllm.distributed import broadcast_tensor_dict from vllm.worker.worker_base import WorkerWrapperBase TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient) class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase): """Worker for Medusa. """ def __init__(self, *args, **kwargs): # skip lora config in medusa DelegateWorkerBase.__init__(self, *args, **kwargs) # Lazy initialization list. self._proposer: SpeculativeProposer self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1') def init_device(self): self.worker.init_device() def load_model(self): super().load_model() # get medusa choices and generate medusa_buffers self.medusa_buffers = None if self.tree_decoding and hasattr(self.model_runner.model, 'medusa_choices'): self.medusa_choices = self.model_runner.model.medusa_choices if self.medusa_choices is not None: self.medusa_buffers = self.generate_medusa_buffers( self.medusa_choices, device=self.device ) if self.medusa_buffers is None: self._proposer = Top1Proposer( weakref.proxy(self), # type: ignore[arg-type] self.device, self.vocab_size, max_proposal_len=self.max_model_len, ) else: self._proposer = TreeStyleProposer( weakref.proxy(self), # type: ignore[arg-type] self.device, self.vocab_size, self.medusa_buffers, max_proposal_len=self.max_model_len, ) def set_include_gpu_probs_tensor(self): pass def set_should_modify_greedy_probs_inplace(self): pass def _get_driver_input_and_broadcast( self, execute_model_req: ExecuteModelRequest ) -> Dict[str, torch.Tensor]: seq_group_metadata_list = execute_model_req.seq_group_metadata_list seq_lens, query_lens = self._prepare_input_tensors( seq_group_metadata_list) generators = self.model_runner.get_generators( execute_model_req.finished_requests_ids) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, self.model_runner.pin_memory, generators) sample_indices_list = [] for seq_group in sampling_metadata.seq_groups: sample_indices_list.append(seq_group.sample_indices) previous_hidden_states = execute_model_req.previous_hidden_states.hidden_states previous_logits = execute_model_req.previous_logits.logits if \ execute_model_req.previous_logits is not None else None tensor_dict = { "previous_hidden_states": previous_hidden_states, "previous_logits": previous_logits, "sample_indices_list": sample_indices_list, "seq_lens": seq_lens } if self.do_metadata_broadcast: broadcast_tensor_dict(tensor_dict, src=0) return tensor_dict def _get_worker_input_from_broadcast( self ) -> Optional[Dict[str, torch.Tensor]]: """ Get the worker input from the broadcasted tensor dict. """ assert self.do_metadata_broadcast assert not self.is_driver_worker broadcast_data = broadcast_tensor_dict(src=0) return broadcast_data @torch.inference_mode() def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, # Unused parameter. seq_ids_with_bonus_token_in_last_step: Set[int], ) -> Tuple[List[SamplerOutput], bool]: """Run the model forward pass to generate sample_len future tokens. Returns the list of sampler output, one per layer, along with indicator of whether torch tensor in sampler output need to be transposed in latter sampler_output_to_torch logic. For medusa worker, this indicator shall be False. """ self._raise_if_unsupported(execute_model_req) if self.is_driver_worker: tensor_dict = self._get_driver_input_and_broadcast(execute_model_req) else: tensor_dict = self._get_worker_input_from_broadcast() if tensor_dict is None: raise ValueError("Can not get inputs of medusa worker!!!") model_outputs = self.model_runner.model.generate_proposals( previous_hidden_states=tensor_dict["previous_hidden_states"], sample_indices_list=tensor_dict["sample_indices_list"], previous_logits=tensor_dict["previous_logits"], medusa_buffers=self.medusa_buffers) # create tree attn masks if self.is_driver_worker and self.medusa_buffers is not None: seq_lens = tensor_dict["seq_lens"] max_context_len = max(seq_lens) for sampler_output, seq_len in zip(model_outputs, seq_lens): context_len = seq_len attn_masks = self.medusa_buffers['tree_attn_masks'] left_mask = torch.ones(attn_masks.shape[0], context_len, dtype=attn_masks.dtype, device=attn_masks.device) attn_masks = torch.cat([left_mask, attn_masks], dim=-1) right_pad = max_context_len - context_len if right_pad > 0: attn_masks = F.pad(attn_masks, (0, right_pad), "constant", 0) sampler_output.tree_attn_masks = attn_masks return model_outputs, False def _prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[List[int], List[int]]: if not seq_group_metadata_list: return [], [] seq_lens: List[int] = [] query_lens: List[int] = [] for seq_group_metadata in seq_group_metadata_list: is_prompt = seq_group_metadata.is_prompt for seq_data in seq_group_metadata.seq_data.values(): seq_data_len = seq_data.get_len() if is_prompt: context_len = seq_data.get_num_computed_tokens() seq_len = min( seq_data_len, context_len + seq_group_metadata.token_chunk_size) seq_lens.append(seq_len) query_lens.append(seq_len - context_len) else: # first step of tree decoding need to ignore first token if self.medusa_buffers is not None and seq_data.get_first_step_flag(): seq_data_len -= 1 seq_lens.append(seq_data_len) query_lens.append(1) return seq_lens, query_lens def get_spec_proposals( self, execute_model_req: ExecuteModelRequest, seq_ids_with_bonus_token_in_last_step: Set[int], ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. """ return self._proposer.get_spec_proposals( execute_model_req, seq_ids_with_bonus_token_in_last_step) def _raise_if_unsupported( self, execute_model_req: ExecuteModelRequest, ) -> None: """MedusaWorker does not yet implement support for cache swap operations or beam search. """ if execute_model_req is None: return None if any([ execute_model_req.blocks_to_swap_in, execute_model_req.blocks_to_swap_out, execute_model_req.blocks_to_copy ]): raise NotImplementedError( "MedusaWorker does not support cache operations") if any( len(seq_group_metadata.seq_data.keys()) != 1 for seq_group_metadata in execute_model_req.seq_group_metadata_list): raise NotImplementedError( "MedusaWorker does not support beam search.") def pad_path(self, path, length, pad_value=-2): """ Pad the given path list with a specific value up to a specified length. Parameters: - path (list): The original list that needs padding. - length (int): The desired length of the padded list. - pad_value (optional, default=-2): The value to use for padding. Returns: - list: A new list based on the original path but padded to the desired length. Example: >>> pad_path([1,2,3], 5) [1, 2, 3, -2, -2] Note: If the given path is already longer than the specified length, then no padding occurs, and the original path is returned. """ # Calculate the number of padding values needed by subtracting the length # of the path from the desired length. # Append the padding values to the original path and return the new list. return path + [pad_value] * (length - len(path)) def generate_medusa_buffers(self, medusa_choices, device="cuda"): """ Generate buffers for the Medusa structure based on the provided choices. Parameters: - medusa_choices (list): A nested list representing tree in the Medusa structure. - device (str): Device to which the tensors should be moved. Default is "cuda". Returns: - dict: A dictionary containing buffers related to the Medusa structure. """ # Sort the medusa_choices based on their lengths and then their values sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x)) medusa_len = len(sorted_medusa_choices) + 1 # Initialize depth_counts to keep track of how many choices have a particular depth depth_counts = [] prev_depth = 0 for path in sorted_medusa_choices: depth = len(path) if depth != prev_depth: depth_counts.append(0) depth_counts[depth - 1] += 1 prev_depth = depth # Create the attention mask for Medusa medusa_attn_mask = torch.eye(medusa_len, medusa_len) medusa_attn_mask[:, 0] = 1 start = 0 for i in range(len(depth_counts)): for j in range(depth_counts[i]): cur_medusa_choice = sorted_medusa_choices[start + j] # retrieve ancestor position if len(cur_medusa_choice) == 1: continue ancestor_idx = [] for c in range(len(cur_medusa_choice) - 1): ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1) medusa_attn_mask[j + start + 1, ancestor_idx] = 1 start += depth_counts[i] # Generate tree indices for the Medusa structure medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long) medusa_tree_indices[0] = 0 start = 0 for i in range(len(depth_counts)): for j in range(depth_counts[i]): cur_medusa_choice = sorted_medusa_choices[start + j] medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1 start += depth_counts[i] # Generate position IDs for the Medusa structure medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long) start = 0 for i in range(len(depth_counts)): medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1 start += depth_counts[i] # Generate retrieval indices for Medusa structure verification retrieve_indices_nest = [] retrieve_paths = [] for i in range(len(sorted_medusa_choices)): cur_medusa_choice = sorted_medusa_choices[-i-1] retrieve_indice = [] if cur_medusa_choice in retrieve_paths: continue else: for c in range(len(cur_medusa_choice)): retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1])) retrieve_paths.append(cur_medusa_choice[:c+1]) retrieve_indices_nest.append(retrieve_indice) max_length = max([len(x) for x in retrieve_indices_nest]) retrieve_indices = [self.pad_path(path, max_length) for path in retrieve_indices_nest] retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) retrieve_indices = retrieve_indices + 1 retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1) # Aggregate the generated buffers into a dictionary medusa_buffers = { "tree_attn_masks": medusa_attn_mask.int(), "tree_indices": medusa_tree_indices, "tree_position_ids": medusa_position_ids, "retrieve_indices": retrieve_indices, } # Move the tensors in the dictionary to the specified device medusa_buffers = { k: v.clone().to(device) if isinstance(v, torch.Tensor) else torch.tensor(v, device=device) for k, v in medusa_buffers.items() } return medusa_buffers