################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
|
|
from functools import wraps
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
from fastcore.basics import patch_to
|
|
|
|
import vllm_br.envs as biren_envs
|
|
from vllm.logger import init_logger
|
|
from vllm.utils import is_pin_memory_available
|
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
|
from vllm.v1.sample.metadata import SamplingMetadata
|
|
from vllm.v1.spec_decode.eagle import EagleProposer
|
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
|
from vllm_br.v1.worker.model_runner import SUPACommonAttentionMetadata
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
PADDING_SLOT_ID = -1
|
|
|
|
|
|
def wrapper_EagleProposer_init(fn):
    """Decorate ``EagleProposer.__init__`` so that, after the original
    constructor finishes, Biren-specific MLA flags are stamped onto the
    draft model config."""
    # FIXME: temporary fix for enabling MLA in EagleProposer

    @wraps(fn)
    def patched_init(self, *args, **kwargs):
        # Run the stock constructor first; then post-configure.
        fn(self, *args, **kwargs)
        cfg = self.draft_model_config
        cfg.weight_type = biren_envs.VLLM_BR_WEIGHT_TYPE
        cfg.use_ds_mla = True
        # Sparse MLA is enabled only when the HF config exposes index_topk.
        cfg.use_ds_mla_sparse = hasattr(cfg.hf_config, "index_topk")

    return patched_init
# Monkey-patch EagleProposer.__init__ in place so every proposer instance
# constructed by vLLM picks up the Biren draft-model-config flags above.
EagleProposer.__init__ = wrapper_EagleProposer_init(
    EagleProposer.__init__)  # noqa: E501
@patch_to(EagleProposer)
def prepare_inputs(
    self,
    common_attn_metadata: SUPACommonAttentionMetadata,
    sampled_token_ids: list[list[int]],
    num_draft_tokens: list[int],
) -> tuple[SUPACommonAttentionMetadata, torch.Tensor]:
    """
    This function is used to prepare the inputs for speculative decoding.
    It updates to the common_attn_metadata to account for the rejected
    tokens (and newly sampled tokens). It also returns the token indices
    of the tokens that should be fed to the speculator.

    Args:
        common_attn_metadata: attention metadata from the just-finished
            target-model forward pass.
        sampled_token_ids: per-request token ids accepted by the
            rejection sampler.
        num_draft_tokens: per-request count of draft tokens proposed in
            the previous step.

    Returns:
        A ``(metadata, token_indices)`` pair where ``token_indices``
        selects, from the flattened target-model token stream, the
        tokens the speculator should consume.
    """
    # E.g.
    # common_attn_metadata.query_start_loc{_cpu}:
    #   [0, q1, q1 + q2, q1 + q2 + q3]
    # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
    # num_rejected_tokens: [n1, n2, n3]
    # This function computes the intermediate values:
    #   num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
    # And returns:
    # common_attn_metadata.query_start_loc{_cpu}:
    #   [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
    # common_attn_metadata.seq_lens{_cpu}:
    #   [s1 - n1, s2 - n2, s3 - n3]
    #   (NOTE(review): an earlier comment claimed "s_i - n_i + 1"; the
    #   code below subtracts n_i only — comment corrected to match code.)
    # token_indices: [0, 1, ..., q1 - n1 - 1,
    #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
    #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

    # A request with drafts rejects (n + 1 - accepted) tokens: n drafts
    # plus the bonus token, minus however many the sampler accepted.
    # Requests without drafts (n == 0) reject nothing.
    num_rejected_tokens = [
        n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
        for i, n in enumerate(num_draft_tokens)
    ]
    num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)

    device = common_attn_metadata.query_start_loc.device
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \
        - num_rejected_tokens

    # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
    new_query_len_per_req = (query_start_loc_cpu[1:] -
                             query_start_loc_cpu[:-1])
    # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
    new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
    new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

    # [q1 - n1, q2 - n2, q3 - n3] ->
    # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
    # Allocated pinned so the later .to(device, non_blocking=True) is a
    # true async H2D copy.
    new_query_start_loc_cpu = torch.zeros(query_start_loc_cpu.shape,
                                          dtype=torch.int32,
                                          pin_memory=is_pin_memory_available())
    # cumsum writes through the numpy view directly into the pinned
    # tensor's storage (element 0 stays 0 from torch.zeros).
    new_query_start_loc_np = new_query_start_loc_cpu.numpy()
    np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

    total_num_tokens = new_query_start_loc_np[-1]
    # Example assuming num_tokens_per_req_np = [2, 4, 3]
    # this implies that `new_query_start_locs` is:
    # [0, 2, 6, 9] ->
    # [0, 0, 2, 2, 2, 2, 6, 6, 6]
    #  _r1_  ____r2____  ___r3__
    new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
                                              new_num_tokens_per_req_np)
    # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
    # [0, 1, 0, 1, 2, 3, 0, 1, 2]
    #  _r1_  ____r2____  ___r3__
    # (NOTE(review): "token_offests" is a typo for "token_offsets" —
    # purely cosmetic, left unchanged here.)
    token_offests = self.token_arange_np[:total_num_tokens] \
        - new_query_start_locs_expanded

    # Expand starting positions to match token pattern
    # [0, q1, q1 + q2] ->
    # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
    #  _r1_  _____r2_______  ___________r3____________
    old_query_start_locs_expanded = np.repeat(query_start_loc_cpu[:-1].numpy(),
                                              new_num_tokens_per_req_np)
    # Final token indices are:
    # [0, 1,                                // req 1
    #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,      // req 2
    #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]  // req 3
    token_indices_np = token_offests + old_query_start_locs_expanded
    token_indices = torch.from_numpy(token_indices_np).to(device,
                                                          non_blocking=True)

    # seq_start_loc = torch.from_numpy(
    #     np.insert(np.add.accumulate(common_attn_metadata.seq_lens.cpu().numpy()), 0,
    #               0)).to(common_attn_metadata.query_start_loc, non_blocking=True)
    spec_common_attn_metadata = SUPACommonAttentionMetadata(
        query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
        seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
        query_start_loc_cpu=new_query_start_loc_cpu,
        seq_lens_cpu=new_seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=common_attn_metadata.num_reqs,
        num_actual_tokens=total_num_tokens,
        # NOTE(review): taken from the PRE-trim query lengths, i.e. an
        # upper bound on the trimmed per-request lengths. Confirm this is
        # intended rather than new_num_tokens_per_req.max().
        max_query_len=new_query_len_per_req.max().item(),
        max_seq_len=new_seq_lens_cpu.max().item(),
        block_table_tensor=common_attn_metadata.block_table_tensor,
        # Keep only the slots of the surviving (non-rejected) tokens.
        slot_mapping=common_attn_metadata.slot_mapping[token_indices],
        causal=True,
        # seq_start_loc=seq_start_loc
    )
    return spec_common_attn_metadata, token_indices
@patch_to(EagleProposer)
def prepare_inputs_padded(self,
                          common_attn_metadata: SUPACommonAttentionMetadata,
                          spec_decode_metadata: SpecDecodeMetadata,
                          valid_sampled_tokens_count: torch.Tensor) -> \
        tuple[SUPACommonAttentionMetadata, torch.Tensor, torch.Tensor]:
    """
    This function is used to prepare the inputs for speculative decoding
    It updates the common_attn_metadata for speculative decoding,
    but does not consider the rejected tokens. Instead, all tokens
    are included as inputs to the speculator, with the rejected tokens
    used as padding and filtered out later by `token_indices_to_sample`.
    No blocking CPU operations should be introduced in this function.

    Args:
        common_attn_metadata: attention metadata from the just-finished
            target-model forward pass.
        spec_decode_metadata: carries cu_num_draft_tokens, the cumulative
            per-request draft-token counts (GPU tensor).
        valid_sampled_tokens_count: per-request number of tokens the
            rejection sampler accepted (GPU tensor).

    Returns:
        ``(metadata, token_indices, token_indices_to_sample)`` where
        ``token_indices`` covers ALL tokens (rejected ones included as
        padding) and ``token_indices_to_sample`` points at the last
        accepted token of each request.
    """
    # Undo the cumulative sum to recover per-request draft counts:
    # [c1, c2, c3] -> [c1, c2 - c1, c3 - c2]
    num_draft_tokens_gpu = torch.cat([
        spec_decode_metadata.cu_num_draft_tokens[0:1],
        spec_decode_metadata.cu_num_draft_tokens[1:] -
        spec_decode_metadata.cu_num_draft_tokens[:-1]
    ])

    # Rejected = drafts + 1 (bonus token) - accepted; zero for requests
    # that proposed no drafts. Stays on device to avoid a sync.
    num_rejected_tokens_gpu = torch.where(
        num_draft_tokens_gpu > 0,
        num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
        torch.zeros_like(num_draft_tokens_gpu))

    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

    # [0, q1, q1 + q2, ...] -> per-request query lengths [q1, q2, ...]
    new_query_len_per_req = (query_start_loc_cpu[1:] -
                             query_start_loc_cpu[:-1])

    total_num_tokens = query_start_loc_cpu[-1].item()
    # Every token (rejected ones included) is fed to the speculator.
    token_indices = self.arange[:total_num_tokens]

    # NOTE(review): `.cpu()` below forces a device synchronization if
    # seq_lens lives on the GPU, which conflicts with this function's
    # "no blocking CPU operations" contract — seq_lens_cpu looks like the
    # intended source; confirm before changing.
    seq_start_loc = torch.from_numpy(
        np.insert(
            np.add.accumulate(common_attn_metadata.seq_lens.cpu().numpy()), 0,
            0)).to(common_attn_metadata.query_start_loc, non_blocking=True)
    # Metadata is passed through mostly unchanged since no tokens are
    # dropped here; only seq_start_loc is (re)computed.
    spec_common_attn_metadata = SUPACommonAttentionMetadata(
        query_start_loc=common_attn_metadata.query_start_loc,
        seq_lens=common_attn_metadata.seq_lens,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=common_attn_metadata.num_reqs,
        num_actual_tokens=total_num_tokens,
        max_query_len=new_query_len_per_req.max().item(),
        max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
        block_table_tensor=common_attn_metadata.block_table_tensor,
        # .long() because advanced indexing requires int64 indices.
        slot_mapping=common_attn_metadata.slot_mapping[token_indices.long()],
        causal=True,
        # context_lens=context_lens,
        # max_decode_seq_len=self.seq_lens.np[:num_reqs].max(),
        seq_start_loc=seq_start_loc)

    # Last accepted token of each request within the padded stream:
    # the request's final position minus its rejected-token count.
    token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \
        - num_rejected_tokens_gpu

    return spec_common_attn_metadata, token_indices, token_indices_to_sample
def wrapper_EagleProposer_propose(fn):
    """Decorate ``EagleProposer.propose`` to normalize ``last_token_indices``:
    default it to the final token of each request when absent, and cast it
    to int64 so it can be used for tensor indexing."""

    @wraps(fn)
    def wrapper(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: Optional[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embeds: Optional[list[torch.Tensor]] = None,
    ):
        # Default: the last token of each query span, i.e.
        # query_start_loc[1:] - 1.
        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        # Advanced indexing downstream needs int64 indices.
        last_token_indices = last_token_indices.long()

        return fn(self, target_token_ids, target_positions,
                  target_hidden_states, next_token_ids, last_token_indices,
                  common_attn_metadata, sampling_metadata, mm_embeds)

    return wrapper
# Monkey-patch EagleProposer.propose in place so the last_token_indices
# normalization above applies to every proposer instance.
EagleProposer.propose = wrapper_EagleProposer_propose(
    EagleProposer.propose)  # noqa: E501