# Source: enginex-hygon-vllm/vllm/spec_decode/medusa_worker.py
# Snapshot: 2026-01-09 15:09:53 +08:00 (358 lines, 14 KiB, Python)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import weakref
from typing import List, Optional, Set, Tuple, Dict
import torch
import torch.nn.functional as F
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import DelegateWorkerBase
from vllm.spec_decode.tree_style_proposer import TreeStyleProposer
from vllm.distributed import broadcast_tensor_dict
from vllm.worker.worker_base import WorkerWrapperBase
# Number of candidate slots reserved per tree depth when flattening the
# sparse Medusa tree into tree indices (10 is a placeholder and it is
# sufficient).
TOPK = 10
class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
    """Worker for Medusa speculative decoding.

    Device and model management are delegated to the wrapped worker
    (``DelegateWorkerBase``); proposals come from the model's
    ``generate_proposals``. Two proposer modes are selected in
    :meth:`load_model`:

    * ``Top1Proposer`` (default), and
    * ``TreeStyleProposer`` when the environment variable
      ``VLLM_TREE_DECODING`` is ``"1"`` and the model exposes a non-empty
      ``medusa_choices`` tree; the static buffers built by
      :meth:`generate_medusa_buffers` are then used.
    """

    def __init__(self, *args, **kwargs):
        # skip lora config in medusa
        DelegateWorkerBase.__init__(self, *args, **kwargs)
        # Lazy initialization list (assigned in load_model).
        self._proposer: SpeculativeProposer
        self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
        # Tree-decoding state. Initialized here so that methods reading it
        # (e.g. _prepare_input_tensors) never hit a missing attribute; it is
        # only populated by load_model when tree decoding is active.
        self.medusa_choices = None
        self.medusa_buffers: Optional[Dict[str, torch.Tensor]] = None

    def init_device(self):
        """Initialize the device of the wrapped worker."""
        self.worker.init_device()

    def load_model(self):
        """Load the model and choose the proposer implementation.

        If tree decoding is enabled and the model has a non-empty
        ``medusa_choices`` attribute, tree buffers are generated and a
        ``TreeStyleProposer`` is used; otherwise a ``Top1Proposer`` is used.
        """
        super().load_model()
        # get medusa choices and generate medusa_buffers
        self.medusa_buffers = None
        if self.tree_decoding and hasattr(self.model_runner.model,
                                          'medusa_choices'):
            self.medusa_choices = self.model_runner.model.medusa_choices
            # An empty tree cannot produce buffers (and has nothing to
            # speculate on), so treat it like "no tree" and use Top1.
            if (self.medusa_choices is not None
                    and len(self.medusa_choices) > 0):
                self.medusa_buffers = self.generate_medusa_buffers(
                    self.medusa_choices, device=self.device)
        if self.medusa_buffers is None:
            self._proposer = Top1Proposer(
                weakref.proxy(self),  # type: ignore[arg-type]
                self.device,
                self.vocab_size,
                max_proposal_len=self.max_model_len,
            )
        else:
            self._proposer = TreeStyleProposer(
                weakref.proxy(self),  # type: ignore[arg-type]
                self.device,
                self.vocab_size,
                self.medusa_buffers,
                max_proposal_len=self.max_model_len,
            )

    def set_include_gpu_probs_tensor(self):
        """Intentionally a no-op for this worker."""
        pass

    def set_should_modify_greedy_probs_inplace(self):
        """Intentionally a no-op for this worker."""
        pass

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Dict[str, torch.Tensor]:
        """Build the proposal inputs on the driver and share them.

        Computes per-sequence lengths and per-group sample indices, takes
        the previous hidden states and (optional) previous logits from
        ``execute_model_req``, and broadcasts the result from rank 0 when
        ``self.do_metadata_broadcast`` is set.

        Returns:
            Dict with keys ``previous_hidden_states``, ``previous_logits``
            (None when the request carries none), ``sample_indices_list``
            (one entry per sequence group) and ``seq_lens`` (list of ints).
            Despite the annotation, not every value is a tensor.
        """
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
        seq_lens, query_lens = self._prepare_input_tensors(
            seq_group_metadata_list)
        generators = self.model_runner.get_generators(
            execute_model_req.finished_requests_ids)
        # Only used here to derive the per-group sample indices.
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory, generators)
        sample_indices_list = [
            seq_group.sample_indices
            for seq_group in sampling_metadata.seq_groups
        ]
        previous_hidden_states = (
            execute_model_req.previous_hidden_states.hidden_states)
        previous_logits = (execute_model_req.previous_logits.logits
                           if execute_model_req.previous_logits is not None
                           else None)
        tensor_dict = {
            "previous_hidden_states": previous_hidden_states,
            "previous_logits": previous_logits,
            "sample_indices_list": sample_indices_list,
            "seq_lens": seq_lens
        }
        if self.do_metadata_broadcast:
            broadcast_tensor_dict(tensor_dict, src=0)
        return tensor_dict

    def _get_worker_input_from_broadcast(
        self
    ) -> Optional[Dict[str, torch.Tensor]]:
        """ Get the worker input from the broadcasted tensor dict. """
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
        broadcast_data = broadcast_tensor_dict(src=0)
        return broadcast_data

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.

        Returns the list of sampler output, one per layer, along with an
        indicator of whether torch tensors in the sampler output need to be
        transposed in later sampler_output_to_torch logic. For the Medusa
        worker this indicator is always False.

        On the driver, when tree decoding is active, each sampler output also
        gets a ``tree_attn_masks`` attribute (see _attach_tree_attn_masks).

        Raises:
            ValueError: if no inputs could be obtained (driver or broadcast).
        """
        self._raise_if_unsupported(execute_model_req)
        if self.is_driver_worker:
            tensor_dict = self._get_driver_input_and_broadcast(
                execute_model_req)
        else:
            tensor_dict = self._get_worker_input_from_broadcast()
        if tensor_dict is None:
            raise ValueError("Can not get inputs of medusa worker!!!")
        model_outputs = self.model_runner.model.generate_proposals(
            previous_hidden_states=tensor_dict["previous_hidden_states"],
            sample_indices_list=tensor_dict["sample_indices_list"],
            previous_logits=tensor_dict["previous_logits"],
            medusa_buffers=self.medusa_buffers)
        # create tree attn masks
        if self.is_driver_worker and self.medusa_buffers is not None:
            self._attach_tree_attn_masks(model_outputs,
                                         tensor_dict["seq_lens"])
        return model_outputs, False

    def _attach_tree_attn_masks(self, model_outputs, seq_lens):
        """Attach a per-sequence tree attention mask to each sampler output.

        Outputs and ``seq_lens`` are paired positionally. For each pair, the
        static tree mask (``tree_len x tree_len``) is prefixed with
        ``seq_len`` all-ones columns for the context tokens, then
        right-padded with zeros so every mask has
        ``max(seq_lens) + tree_len`` columns.
        """
        tree_masks = self.medusa_buffers['tree_attn_masks']
        # default=0: an empty batch must not make max() raise.
        max_context_len = max(seq_lens, default=0)
        for sampler_output, context_len in zip(model_outputs, seq_lens):
            left_mask = torch.ones(tree_masks.shape[0], context_len,
                                   dtype=tree_masks.dtype,
                                   device=tree_masks.device)
            attn_masks = torch.cat([left_mask, tree_masks], dim=-1)
            right_pad = max_context_len - context_len
            if right_pad > 0:
                attn_masks = F.pad(attn_masks, (0, right_pad), "constant", 0)
            sampler_output.tree_attn_masks = attn_masks

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[List[int], List[int]]:
        """Compute per-sequence ``seq_lens`` and ``query_lens``.

        Prompt sequences use ``min(total_len, computed + token_chunk_size)``
        as seq_len and the newly scheduled token count as query_len. Decode
        sequences use their full length and a query_len of 1. In tree
        decoding, the first decode step excludes one token (see below).
        Returns two empty lists when the metadata list is None or empty.
        """
        if not seq_group_metadata_list:
            return [], []
        seq_lens: List[int] = []
        query_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt
            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    seq_lens.append(seq_len)
                    query_lens.append(seq_len - context_len)
                else:
                    # first step of tree decoding need to ignore first token
                    if (self.medusa_buffers is not None
                            and seq_data.get_first_step_flag()):
                        seq_data_len -= 1
                    seq_lens.append(seq_data_len)
                    query_lens.append(1)
        return seq_lens, query_lens

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences.

        The number of speculative tokens per sequence is determined by
        max_proposal_len. Delegates to the proposer chosen in load_model.
        """
        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """Reject requests MedusaWorker cannot handle yet.

        MedusaWorker does not yet implement support for cache swap
        operations or beam search.

        Raises:
            NotImplementedError: if the request carries any block
                swap-in/swap-out/copy operations, or if any sequence group
                holds more than one sequence.
        """
        if execute_model_req is None:
            # Nothing to validate (presumably a non-driver worker call).
            return None
        if any((
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy,
        )):
            raise NotImplementedError(
                "MedusaWorker does not support cache operations")
        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MedusaWorker does not support beam search.")

    def pad_path(self, path, length, pad_value=-2):
        """Return a copy of ``path`` right-padded with ``pad_value``.

        Parameters:
        - path (list or tuple): The original sequence to pad.
        - length (int): The desired length of the padded list.
        - pad_value (optional, default=-2): The value to use for padding.

        Returns:
        - list: A new list holding ``path`` followed by enough ``pad_value``
          entries to reach ``length``.

        Example:
        >>> worker.pad_path([1, 2, 3], 5)
        [1, 2, 3, -2, -2]

        Note:
        If ``path`` is already at least ``length`` long, no padding occurs
        and an unpadded copy of ``path`` is returned.
        """
        # A negative repeat count yields [], so over-long paths are copied
        # unchanged.
        return list(path) + [pad_value] * (length - len(path))

    def generate_medusa_buffers(self, medusa_choices, device="cuda",
                                topk=TOPK):
        """Generate static buffers describing the Medusa candidate tree.

        Each entry of ``medusa_choices`` is a path of top-k ranks, e.g.
        ``[0, 1]`` is the rank-1 child of the rank-0 node at depth 1. Nodes
        are numbered 1..N in ``(len(path), path)`` sorted order; node 0 is
        the root.

        Parameters:
        - medusa_choices (list): Non-empty list of paths (lists of ints).
          The tree must be prefix-closed: every proper prefix of a path must
          also be listed.
        - device (str or torch.device): Where the tensors are placed.
          Default is "cuda".
        - topk (int): Candidate slots per depth used to flatten the tree.
          Default is the module-level TOPK. Every rank must be in
          ``[0, topk)``.

        Returns:
        - dict with, for ``tree_len = len(medusa_choices) + 1``:
          * "tree_attn_masks": int32 ``(tree_len, tree_len)``. Each node
            sees the root, itself and its ancestors.
          * "tree_indices": long ``(tree_len,)``. A node at depth d
            (1-based) with last rank r maps to ``r + topk * (d - 1) + 1``.
          * "tree_position_ids": long ``(tree_len,)``, the node depth.
          * "retrieve_indices": long ``(num_leaves, max_depth + 1)``. One
            row per leaf path: the root (0) followed by node ids, padded
            with -1.

        Raises:
        - ValueError: if ``medusa_choices`` is empty, contains an empty
          path, a rank outside ``[0, topk)``, or a path whose parent is not
          listed.
        """
        # Sort the medusa_choices based on their lengths and then their values
        sorted_medusa_choices = sorted(medusa_choices,
                                       key=lambda x: (len(x), x))
        if not sorted_medusa_choices:
            raise ValueError("medusa_choices must contain at least one path.")
        medusa_len = len(sorted_medusa_choices) + 1

        # Position of each path in the sorted list. The first occurrence is
        # kept (same as list.index). This replaces repeated O(n) scans.
        path_to_pos: Dict[Tuple[int, ...], int] = {}
        for pos, path in enumerate(sorted_medusa_choices):
            path_to_pos.setdefault(tuple(path), pos)

        # Validate the tree. Checking the parent of every path makes the
        # whole tree prefix-closed by induction. That also guarantees every
        # depth 1..max_depth is populated.
        for path in sorted_medusa_choices:
            if len(path) == 0:
                raise ValueError(
                    "medusa_choices must not contain empty paths.")
            for rank in path:
                if not 0 <= rank < topk:
                    raise ValueError(
                        f"medusa_choices path {list(path)} has rank {rank} "
                        f"outside [0, {topk}).")
            if len(path) > 1 and tuple(path[:-1]) not in path_to_pos:
                raise ValueError(
                    f"medusa_choices path {list(path)} is missing its "
                    f"parent {list(path[:-1])}.")

        # Attention mask: every node sees the root (column 0) and itself
        # (diagonal); ancestors are added below.
        medusa_attn_mask = torch.eye(medusa_len, medusa_len)
        medusa_attn_mask[:, 0] = 1
        medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
        medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
        for pos, path in enumerate(sorted_medusa_choices):
            row = pos + 1  # row/slot 0 is the root
            depth = len(path)
            ancestor_cols = [
                path_to_pos[tuple(path[:c + 1])] + 1
                for c in range(depth - 1)
            ]
            if ancestor_cols:
                medusa_attn_mask[row, ancestor_cols] = 1
            # Each depth owns a block of `topk` slots after the root slot.
            medusa_tree_indices[row] = path[-1] + topk * (depth - 1) + 1
            medusa_position_ids[row] = depth

        # Retrieval indices: walk paths longest-first and emit a row for
        # every path that is not a prefix of an already-emitted one (leaves).
        retrieve_indices_nest = []
        covered_paths = set()
        for path in reversed(sorted_medusa_choices):
            if tuple(path) in covered_paths:
                continue
            prefixes = [tuple(path[:c + 1]) for c in range(len(path))]
            retrieve_indices_nest.append(
                [path_to_pos[prefix] for prefix in prefixes])
            covered_paths.update(prefixes)
        max_length = max(len(x) for x in retrieve_indices_nest)
        retrieve_indices = torch.tensor(
            [self.pad_path(p, max_length) for p in retrieve_indices_nest],
            dtype=torch.long)
        # Shift to node ids (padding -2 becomes -1), then prepend the root.
        retrieve_indices = retrieve_indices + 1
        retrieve_indices = torch.cat(
            [torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long),
             retrieve_indices],
            dim=1)

        # Aggregate the generated buffers into a dictionary
        medusa_buffers = {
            "tree_attn_masks": medusa_attn_mask.int(),
            "tree_indices": medusa_tree_indices,
            "tree_position_ids": medusa_position_ids,
            "retrieve_indices": retrieve_indices,
        }
        # All values are freshly built tensors, so moving them is enough.
        return {k: v.to(device) for k, v in medusa_buffers.items()}