forked from EngineX-Hygon/enginex-hygon-vllm
init src 0.9.2

vllm/spec_decode/medusa_worker.py · 357 lines · new file
@@ -0,0 +1,357 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import weakref
from typing import Dict, List, Optional, Set, Tuple

import torch
import torch.nn.functional as F

from vllm.distributed import broadcast_tensor_dict
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.spec_decode.tree_style_proposer import TreeStyleProposer
from vllm.worker.worker_base import DelegateWorkerBase

# Top-k candidates considered per Medusa head when flattening the sparse tree
# (10 is a placeholder upper bound and is sufficient in practice).
TOPK = 10
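
# Illustrative note (an assumption about the expected format, not stated in
# this commit): `medusa_choices` is a list of paths through the candidate
# tree, each path being the per-head top-k indices along it, e.g.
#   medusa_choices = [[0], [1], [0, 0], [0, 1]]
# where [0, 1] means "the best candidate of head 0 followed by the
# second-best candidate of head 1".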


class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
    """Worker for Medusa."""

    def __init__(self, *args, **kwargs):
        # Skip the LoRA config in Medusa.
        DelegateWorkerBase.__init__(self, *args, **kwargs)
        # Lazily initialized by load_model().
        self._proposer: SpeculativeProposer
        self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')

    def init_device(self):
        self.worker.init_device()

    def load_model(self):
        super().load_model()

        # Get the Medusa choices and generate the Medusa buffers.
        self.medusa_buffers = None
        if self.tree_decoding and hasattr(self.model_runner.model,
                                          'medusa_choices'):
            self.medusa_choices = self.model_runner.model.medusa_choices
            if self.medusa_choices is not None:
                self.medusa_buffers = self.generate_medusa_buffers(
                    self.medusa_choices, device=self.device)

        if self.medusa_buffers is None:
            self._proposer = Top1Proposer(
                weakref.proxy(self),  # type: ignore[arg-type]
                self.device,
                self.vocab_size,
                max_proposal_len=self.max_model_len,
            )
        else:
            self._proposer = TreeStyleProposer(
                weakref.proxy(self),  # type: ignore[arg-type]
                self.device,
                self.vocab_size,
                self.medusa_buffers,
                max_proposal_len=self.max_model_len,
            )

    def set_include_gpu_probs_tensor(self):
        pass

    def set_should_modify_greedy_probs_inplace(self):
        pass

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Dict[str, torch.Tensor]:
        """Prepare the driver-side inputs and broadcast them to the other
        workers."""
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        seq_lens, query_lens = self._prepare_input_tensors(
            seq_group_metadata_list)

        generators = self.model_runner.get_generators(
            execute_model_req.finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory, generators)

        sample_indices_list = [
            seq_group.sample_indices
            for seq_group in sampling_metadata.seq_groups
        ]

        previous_hidden_states = (
            execute_model_req.previous_hidden_states.hidden_states)
        previous_logits = (execute_model_req.previous_logits.logits
                           if execute_model_req.previous_logits is not None
                           else None)

        tensor_dict = {
            "previous_hidden_states": previous_hidden_states,
            "previous_logits": previous_logits,
            "sample_indices_list": sample_indices_list,
            "seq_lens": seq_lens,
        }

        if self.do_metadata_broadcast:
            broadcast_tensor_dict(tensor_dict, src=0)

        return tensor_dict

    def _get_worker_input_from_broadcast(
            self) -> Optional[Dict[str, torch.Tensor]]:
        """Get the worker input from the broadcast tensor dict."""
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
        broadcast_data = broadcast_tensor_dict(src=0)

        return broadcast_data

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.

        Returns the list of sampler outputs, one per layer, along with an
        indicator of whether the torch tensors in the sampler output need to
        be transposed in the later sampler_output_to_torch logic.

        For the Medusa worker, this indicator is always False.
        """
        self._raise_if_unsupported(execute_model_req)

        if self.is_driver_worker:
            tensor_dict = self._get_driver_input_and_broadcast(
                execute_model_req)
        else:
            tensor_dict = self._get_worker_input_from_broadcast()
            if tensor_dict is None:
                raise ValueError("Cannot get the inputs of the Medusa worker.")

        model_outputs = self.model_runner.model.generate_proposals(
            previous_hidden_states=tensor_dict["previous_hidden_states"],
            sample_indices_list=tensor_dict["sample_indices_list"],
            previous_logits=tensor_dict["previous_logits"],
            medusa_buffers=self.medusa_buffers)

        # Create the tree attention masks.
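        # Illustrative worked example (an assumption added for clarity, not
        # part of the original commit): with a 4-node sparse tree,
        # context_len = 3 and max_context_len = 5, each mask row becomes
        #   [1 1 1 | <4 tree-mask columns> | 0 0]
        # i.e. ones over the real context, the sparse-tree mask over the
        # speculative positions, and zero padding up to the batch maximum.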
        if self.is_driver_worker and self.medusa_buffers is not None:
            seq_lens = tensor_dict["seq_lens"]
            max_context_len = max(seq_lens)
            for sampler_output, seq_len in zip(model_outputs, seq_lens):
                context_len = seq_len
                attn_masks = self.medusa_buffers['tree_attn_masks']
                left_mask = torch.ones(attn_masks.shape[0],
                                       context_len,
                                       dtype=attn_masks.dtype,
                                       device=attn_masks.device)
                attn_masks = torch.cat([left_mask, attn_masks], dim=-1)
                right_pad = max_context_len - context_len
                if right_pad > 0:
                    attn_masks = F.pad(attn_masks, (0, right_pad), "constant",
                                       0)
                sampler_output.tree_attn_masks = attn_masks

        return model_outputs, False

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[List[int], List[int]]:
        if not seq_group_metadata_list:
            return [], []

        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    # Chunked prefill: only the not-yet-computed slice of the
                    # prompt is processed in this step.
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    seq_lens.append(seq_len)
                    query_lens.append(seq_len - context_len)
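                    # e.g. (illustrative numbers, not from the original
                    # commit): a 10-token prompt with 4 tokens already
                    # computed and token_chunk_size = 16 gives
                    # seq_len = min(10, 4 + 16) = 10 and query_len = 6.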
                else:
                    # The first step of tree decoding needs to ignore the
                    # first token.
                    if (self.medusa_buffers is not None
                            and seq_data.get_first_step_flag()):
                        seq_data_len -= 1
                    seq_lens.append(seq_data_len)
                    query_lens.append(1)

        return seq_lens, query_lens

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number
        of speculative tokens per sequence is determined by max_proposal_len.
        """
        return self._proposer.get_spec_proposals(
            execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """MedusaWorker does not yet support cache swap operations or beam
        search.
        """
        if execute_model_req is None:
            return None

        if any([
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy
        ]):
            raise NotImplementedError(
                "MedusaWorker does not support cache operations.")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MedusaWorker does not support beam search.")

    def pad_path(self, path, length, pad_value=-2):
        """
        Pad the given path list with a specific value up to a specified
        length.

        Parameters:
        - path (list): The original list that needs padding.
        - length (int): The desired length of the padded list.
        - pad_value (optional, default=-2): The value to use for padding.

        Returns:
        - list: A new list based on the original path but padded to the
          desired length.

        Example:
        >>> pad_path([1, 2, 3], 5)
        [1, 2, 3, -2, -2]

        Note:
        If the given path is already longer than the specified length,
        no padding occurs and the original path is returned unchanged.
        """
        # Compute how many padding values are needed and append them to the
        # original path, returning a new list.
        return path + [pad_value] * (length - len(path))

    def generate_medusa_buffers(self, medusa_choices, device="cuda"):
        """
        Generate buffers for the Medusa structure based on the provided
        choices.

        Parameters:
        - medusa_choices (list): A nested list representing the candidate
          tree of the Medusa structure.
        - device (str): Device to which the tensors should be moved.
          Default is "cuda".

        Returns:
        - dict: A dictionary containing buffers related to the Medusa
          structure.
        """
        # Sort the medusa_choices by length and then by value, so every node
        # appears after its ancestors.
        sorted_medusa_choices = sorted(medusa_choices,
                                       key=lambda x: (len(x), x))
        medusa_len = len(sorted_medusa_choices) + 1

        # depth_counts[i] tracks how many choices have depth i + 1.
        depth_counts = []
        prev_depth = 0
        for path in sorted_medusa_choices:
            depth = len(path)
            if depth != prev_depth:
                depth_counts.append(0)
            depth_counts[depth - 1] += 1
            prev_depth = depth
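
        # Example (illustrative, not from the original commit): for sorted
        # choices [[0], [1], [0, 0]], depth_counts == [2, 1]: two paths of
        # depth 1 and one path of depth 2.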

        # Create the attention mask for Medusa: every node attends to itself
        # (the identity part), to the root (column 0), and to its ancestors.
        medusa_attn_mask = torch.eye(medusa_len, medusa_len)
        medusa_attn_mask[:, 0] = 1
        start = 0
        for i in range(len(depth_counts)):
            for j in range(depth_counts[i]):
                cur_medusa_choice = sorted_medusa_choices[start + j]
                # Retrieve the ancestor positions.
                if len(cur_medusa_choice) == 1:
                    continue
                ancestor_idx = []
                for c in range(len(cur_medusa_choice) - 1):
                    ancestor_idx.append(
                        sorted_medusa_choices.index(cur_medusa_choice[:c + 1])
                        + 1)
                medusa_attn_mask[j + start + 1, ancestor_idx] = 1
            start += depth_counts[i]

        # Generate tree indices for the Medusa structure: node start + j + 1
        # picks candidate cur_medusa_choice[-1] from the TOPK candidates of
        # head i (offset by 1 for the root).
        medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
        medusa_tree_indices[0] = 0
        start = 0
        for i in range(len(depth_counts)):
            for j in range(depth_counts[i]):
                cur_medusa_choice = sorted_medusa_choices[start + j]
                medusa_tree_indices[start + j +
                                    1] = cur_medusa_choice[-1] + TOPK * i + 1
            start += depth_counts[i]
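
        # Example (illustrative, continuing the toy tree above): with
        # TOPK == 10 and sorted choices [[0], [1], [0, 0]], the tree indices
        # are [0, 1, 2, 11]: the root, candidates 0 and 1 of head 0, and
        # candidate 0 of head 1 at offset TOPK * 1 + 1.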

        # Generate position IDs for the Medusa structure: the root sits at
        # position 0 and every node of depth i + 1 shares position i + 1.
        medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
        start = 0
        for i in range(len(depth_counts)):
            medusa_position_ids[start + 1:start + depth_counts[i] + 1] = i + 1
            start += depth_counts[i]

        # Generate retrieval indices for Medusa structure verification: one
        # row per unique root-to-leaf path, padded to the longest path.
        retrieve_indices_nest = []
        retrieve_paths = []
        for i in range(len(sorted_medusa_choices)):
            cur_medusa_choice = sorted_medusa_choices[-i - 1]
            retrieve_indice = []
            if cur_medusa_choice in retrieve_paths:
                continue
            else:
                for c in range(len(cur_medusa_choice)):
                    retrieve_indice.append(
                        sorted_medusa_choices.index(cur_medusa_choice[:c + 1]))
                    retrieve_paths.append(cur_medusa_choice[:c + 1])
            retrieve_indices_nest.append(retrieve_indice)
        max_length = max(len(x) for x in retrieve_indices_nest)
        retrieve_indices = [
            self.pad_path(path, max_length) for path in retrieve_indices_nest
        ]
        retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
        # Shift by one to account for the root, then prepend the root
        # (index 0) to every path.
        retrieve_indices = retrieve_indices + 1
        retrieve_indices = torch.cat([
            torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long),
            retrieve_indices
        ],
                                     dim=1)
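
        # Example (illustrative, continuing the toy tree above): the nested
        # paths [[0, 2], [1]] are padded to [[0, 2], [1, -2]], shifted to
        # [[1, 3], [2, -1]], and prepended with the root to give
        # [[0, 1, 3], [0, 2, -1]].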

        # Aggregate the generated buffers into a dictionary.
        medusa_buffers = {
            "tree_attn_masks": medusa_attn_mask.int(),
            "tree_indices": medusa_tree_indices,
            "tree_position_ids": medusa_position_ids,
            "retrieve_indices": retrieve_indices,
        }

        # Move the tensors in the dictionary to the specified device.
        medusa_buffers = {
            k: v.clone().to(device)
            if isinstance(v, torch.Tensor) else torch.tensor(v, device=device)
            for k, v in medusa_buffers.items()
        }
        return medusa_buffers
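

# Minimal usage sketch (an illustrative assumption, not part of the original
# commit; in practice the worker is constructed and driven by vLLM's
# speculative-decoding framework rather than called directly):
#
#   buffers = worker.generate_medusa_buffers([[0], [1], [0, 0]], device="cpu")
#   buffers["tree_position_ids"]  # -> tensor([0, 1, 1, 2])
#   buffers["tree_indices"]       # -> tensor([0, 1, 2, 11])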