358 lines
14 KiB
Python
358 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import os
|
|
import weakref
|
|
from typing import List, Optional, Set, Tuple, Dict
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from vllm.model_executor import SamplingMetadata
|
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
|
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
|
|
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
|
|
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
|
|
from vllm.spec_decode.top1_proposer import Top1Proposer
|
|
from vllm.worker.worker_base import DelegateWorkerBase
|
|
|
|
from vllm.spec_decode.tree_style_proposer import TreeStyleProposer
|
|
from vllm.distributed import broadcast_tensor_dict
|
|
from vllm.worker.worker_base import WorkerWrapperBase
|
|
|
|
|
|
# Number of candidate slots reserved per Medusa head when laying out the
# sparse tree: generate_medusa_buffers offsets a node at depth d by
# TOPK * d. 10 is a placeholder that is sufficient for the trees in use.
TOPK = 10
|
|
|
|
|
|
class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
    """Worker for Medusa.

    Proposes speculative tokens from the Medusa heads of the wrapped model.
    Tree decoding is used when the ``VLLM_TREE_DECODING`` environment
    variable equals ``"1"`` and the loaded model exposes a non-empty
    ``medusa_choices``. In that case a ``TreeStyleProposer`` is driven by
    the buffers built in :meth:`generate_medusa_buffers`. Otherwise
    proposals come from a ``Top1Proposer``.
    """

    def __init__(self, *args, **kwargs):
        # Skip the LoRA config in Medusa: only the delegate worker is built.
        DelegateWorkerBase.__init__(self, *args, **kwargs)
        # Lazily initialized in load_model().
        self._proposer: SpeculativeProposer
        self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
        # Tree description and derived buffers; both are (re)set by
        # load_model(). ``None`` means Top-1 (non-tree) proposals.
        self.medusa_choices = None
        self.medusa_buffers = None

    def init_device(self):
        """Initialize the device of the wrapped worker."""
        self.worker.init_device()

    def load_model(self):
        """Load the model, build tree buffers if applicable and pick a proposer.

        A ``TreeStyleProposer`` is used only when tree decoding is enabled
        and the model provides a non-empty ``medusa_choices``. In every
        other case a ``Top1Proposer`` is used.
        """
        super().load_model()

        self.medusa_choices = None
        self.medusa_buffers = None
        if self.tree_decoding:
            self.medusa_choices = getattr(self.model_runner.model,
                                          'medusa_choices', None)
        # An empty tree is treated like "no tree" rather than crashing while
        # the buffers are generated.
        if self.medusa_choices is not None and len(self.medusa_choices) > 0:
            self.medusa_buffers = self.generate_medusa_buffers(
                self.medusa_choices, device=self.device)

        if self.medusa_buffers is None:
            self._proposer = Top1Proposer(
                weakref.proxy(self),  # type: ignore[arg-type]
                self.device,
                self.vocab_size,
                max_proposal_len=self.max_model_len,
            )
        else:
            self._proposer = TreeStyleProposer(
                weakref.proxy(self),  # type: ignore[arg-type]
                self.device,
                self.vocab_size,
                self.medusa_buffers,
                max_proposal_len=self.max_model_len,
            )

    def set_include_gpu_probs_tensor(self):
        """No-op for the Medusa worker."""
        pass

    def set_should_modify_greedy_probs_inplace(self):
        """No-op for the Medusa worker."""
        pass

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Dict[str, torch.Tensor]:
        """Build the proposal inputs on the driver and broadcast them.

        The returned dict holds:
        - ``previous_hidden_states``
        - ``previous_logits`` (``None`` when the request carries no logits)
        - ``sample_indices_list`` (one entry per sequence group of the
          sampling metadata)
        - ``seq_lens``

        When ``do_metadata_broadcast`` is set, the same dict is broadcast
        from rank 0 so that non-driver workers receive identical inputs.
        """
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        seq_lens, query_lens = self._prepare_input_tensors(
            seq_group_metadata_list)

        generators = self.model_runner.get_generators(
            execute_model_req.finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory, generators)

        sample_indices_list = [
            seq_group.sample_indices
            for seq_group in sampling_metadata.seq_groups
        ]

        previous_hidden_states = (
            execute_model_req.previous_hidden_states.hidden_states)
        previous_logits = (execute_model_req.previous_logits.logits
                           if execute_model_req.previous_logits is not None
                           else None)

        tensor_dict = {
            "previous_hidden_states": previous_hidden_states,
            "previous_logits": previous_logits,
            "sample_indices_list": sample_indices_list,
            "seq_lens": seq_lens,
        }

        if self.do_metadata_broadcast:
            broadcast_tensor_dict(tensor_dict, src=0)

        return tensor_dict

    def _get_worker_input_from_broadcast(
        self
    ) -> Optional[Dict[str, torch.Tensor]]:
        """Get the worker input from the broadcasted tensor dict.

        Only valid on non-driver workers when metadata broadcast is enabled.
        """
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
        return broadcast_tensor_dict(src=0)

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.

        Returns the list of sampler outputs, one per layer. It also returns
        an indicator of whether the torch tensors in the sampler output need
        to be transposed later in ``sampler_output_to_torch``. For the
        Medusa worker this indicator is always False.

        With tree decoding on the driver, each output additionally gets a
        ``tree_attn_masks`` tensor (see the comment below).

        Raises:
            NotImplementedError: for swap/copy cache ops or beam search.
            ValueError: if a non-driver worker receives no broadcast input.
        """
        self._raise_if_unsupported(execute_model_req)

        if self.is_driver_worker:
            tensor_dict = self._get_driver_input_and_broadcast(
                execute_model_req)
        else:
            tensor_dict = self._get_worker_input_from_broadcast()
            if tensor_dict is None:
                raise ValueError("Can not get inputs of medusa worker!!!")

        model_outputs = self.model_runner.model.generate_proposals(
            previous_hidden_states=tensor_dict["previous_hidden_states"],
            sample_indices_list=tensor_dict["sample_indices_list"],
            previous_logits=tensor_dict["previous_logits"],
            medusa_buffers=self.medusa_buffers)

        # Build the tree attention masks (driver only). Each mask consists of:
        # - ones over the sequence's context, then
        # - the tree mask,
        # right-padded with zeros to the longest context in the batch so
        # that all masks share one width.
        if self.is_driver_worker and self.medusa_buffers is not None:
            seq_lens = tensor_dict["seq_lens"]
            tree_attn_masks = self.medusa_buffers['tree_attn_masks']
            # default=0 keeps an empty batch from crashing max().
            max_context_len = max(seq_lens, default=0)
            for output, context_len in zip(model_outputs, seq_lens):
                left_mask = torch.ones(tree_attn_masks.shape[0], context_len,
                                       dtype=tree_attn_masks.dtype,
                                       device=tree_attn_masks.device)
                attn_masks = torch.cat([left_mask, tree_attn_masks], dim=-1)
                right_pad = max_context_len - context_len
                if right_pad > 0:
                    attn_masks = F.pad(attn_masks, (0, right_pad),
                                       "constant", 0)
                output.tree_attn_masks = attn_masks

        return model_outputs, False

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[List[int], List[int]]:
        """Compute per-sequence ``seq_lens`` and ``query_lens``.

        Prompt (prefill) sequences:
        - ``seq_len`` is capped at the computed tokens plus the chunk size.
        - ``query_len`` is the uncomputed part of that span.

        Decode sequences:
        - ``query_len`` is 1.
        - On the first tree-decoding step the last token is not counted
          in ``seq_len``.

        Returns two empty lists for an empty or missing batch.
        """
        if not seq_group_metadata_list:
            return [], []

        seq_lens: List[int] = []
        query_lens: List[int] = []
        use_tree = self.medusa_buffers is not None

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    seq_lens.append(seq_len)
                    query_lens.append(seq_len - context_len)
                else:
                    # The first step of tree decoding needs to ignore the
                    # first token.
                    if use_tree and seq_data.get_first_step_flag():
                        seq_data_len -= 1
                    seq_lens.append(seq_data_len)
                    query_lens.append(1)

        return seq_lens, query_lens

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences.

        The number of speculative tokens per sequence is determined by
        max_proposal_len. The work is delegated to the proposer chosen in
        load_model().
        """
        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """Reject requests the Medusa worker cannot serve.

        MedusaWorker does not yet implement support for cache swap/copy
        operations or beam search (more than one sequence per group).
        A ``None`` request is accepted without checks.
        """
        if execute_model_req is None:
            return None

        if any((
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy,
        )):
            raise NotImplementedError(
                "MedusaWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MedusaWorker does not support beam search.")

    def pad_path(self, path, length, pad_value=-2):
        """
        Pad the given path list with a specific value up to a specified length.

        Parameters:
        - path (list): The original list that needs padding.
        - length (int): The desired length of the padded list.
        - pad_value (optional, default=-2): The value to use for padding.

        Returns:
        - list: A new list based on the original path but padded to the
          desired length.

        Example:
        >>> worker.pad_path([1, 2, 3], 5)
        [1, 2, 3, -2, -2]

        Note:
        If the given path is already at least ``length`` long, no padding
        occurs and a copy of the original path is returned.
        """
        # A non-positive repeat count yields an empty list, so longer paths
        # are returned unpadded.
        return path + [pad_value] * (length - len(path))

    def generate_medusa_buffers(self, medusa_choices, device="cuda"):
        """
        Generate buffers for the Medusa structure based on the provided choices.

        Each choice is a path in the candidate tree: ``choice[d]`` selects a
        candidate of Medusa head ``d``. Node 0 of every buffer is the root.
        Node ``n >= 1`` is the ``(n - 1)``-th path after sorting the choices
        by ``(length, value)``.

        Parameters:
        - medusa_choices (list): A nested list representing the tree in the
          Medusa structure. It must be non-empty and prefix-closed (every
          proper prefix of a path is itself a path), and every entry must
          lie in ``[0, TOPK)``.
        - device (str): Device to which the tensors should be moved.
          Default is "cuda".

        Returns:
        - dict with, for ``N = len(medusa_choices) + 1``:
          - ``tree_attn_masks``: int ``(N, N)``. A node attends to the root,
            itself and its ancestors.
          - ``tree_indices``: long ``(N,)``. Computed as
            ``choice[-1] + TOPK * (len(choice) - 1) + 1``. This presumes the
            candidates are laid out as ``[root, TOPK slots per head...]``;
            confirm against the model's generate_proposals.
          - ``tree_position_ids``: long ``(N,)``. The depth of each node
            (0 for the root).
          - ``retrieve_indices``: long, one row per leaf path. Each row
            starts with the root (0) and is right-padded with -1.

        Raises:
        - ValueError: if ``medusa_choices`` is empty, contains an empty path,
          is not prefix-closed, or has an entry outside ``[0, TOPK)``.
        """
        # Sort the medusa_choices based on their lengths and then their values.
        sorted_medusa_choices = sorted(medusa_choices,
                                       key=lambda x: (len(x), x))
        if not sorted_medusa_choices:
            raise ValueError("medusa_choices must contain at least one path.")
        medusa_len = len(sorted_medusa_choices) + 1

        # Position of every path in the sorted list. For duplicate paths the
        # first occurrence wins, matching list.index().
        choice_index: Dict[Tuple, int] = {}
        for idx, choice in enumerate(sorted_medusa_choices):
            choice_index.setdefault(tuple(choice), idx)

        # Validate up front so malformed trees fail with a clear message
        # rather than an opaque IndexError or silently wrong indices.
        for choice in sorted_medusa_choices:
            if len(choice) == 0:
                raise ValueError("medusa_choices must not contain empty paths.")
            for depth in range(1, len(choice)):
                if tuple(choice[:depth]) not in choice_index:
                    raise ValueError(
                        f"medusa_choices is not prefix-closed: "
                        f"{list(choice[:depth])} (prefix of {list(choice)}) "
                        f"is missing.")
            if any(not 0 <= k < TOPK for k in choice):
                raise ValueError(
                    f"medusa_choices entry {list(choice)} has a value "
                    f"outside [0, {TOPK}); it would alias another head's "
                    f"candidates.")

        # Create the attention mask for Medusa. Every node sees the root
        # (column 0) and itself (diagonal); non-root nodes also see their
        # ancestors.
        medusa_attn_mask = torch.eye(medusa_len, medusa_len)
        medusa_attn_mask[:, 0] = 1
        for node, choice in enumerate(sorted_medusa_choices, start=1):
            if len(choice) == 1:
                continue
            ancestor_idx = [
                choice_index[tuple(choice[:depth])] + 1
                for depth in range(1, len(choice))
            ]
            medusa_attn_mask[node, ancestor_idx] = 1

        # Tree indices: root at 0. Depth-d nodes (d = len(choice) - 1) are
        # offset by TOPK * d.
        medusa_tree_indices = torch.tensor(
            [0] + [
                choice[-1] + TOPK * (len(choice) - 1) + 1
                for choice in sorted_medusa_choices
            ],
            dtype=torch.long)

        # Position IDs: the depth of each node (root = 0).
        medusa_position_ids = torch.tensor(
            [0] + [len(choice) for choice in sorted_medusa_choices],
            dtype=torch.long)

        # Retrieval indices: walk the paths from longest/largest to
        # shortest. Emit one row per path that is not already covered as a
        # prefix of an emitted row, i.e. one row per leaf.
        retrieve_indices_nest: List[List[int]] = []
        covered_paths: Set[Tuple] = set()
        for choice in reversed(sorted_medusa_choices):
            if tuple(choice) in covered_paths:
                continue
            path_indices = []
            for depth in range(1, len(choice) + 1):
                prefix = tuple(choice[:depth])
                path_indices.append(choice_index[prefix])
                covered_paths.add(prefix)
            retrieve_indices_nest.append(path_indices)

        max_length = max(len(x) for x in retrieve_indices_nest)
        retrieve_indices = torch.tensor(
            [self.pad_path(path, max_length) for path in retrieve_indices_nest],
            dtype=torch.long)
        # Shift to node ids (root = 0); the -2 padding becomes -1.
        retrieve_indices = retrieve_indices + 1
        retrieve_indices = torch.cat(
            [torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long),
             retrieve_indices],
            dim=1)

        medusa_buffers = {
            "tree_attn_masks": medusa_attn_mask.int(),
            "tree_indices": medusa_tree_indices,
            "tree_position_ids": medusa_position_ids,
            "retrieve_indices": retrieve_indices,
        }

        # Move the tensors to the specified device.
        return {k: v.to(device) for k, v in medusa_buffers.items()}
|