Files
enginex-biren-vllm/vllm_br/v1/spec_decode/eagle.py

266 lines
11 KiB
Python
Raw Normal View History

2026-03-10 13:31:25 +08:00
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from functools import wraps
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import torch
from fastcore.basics import patch_to
import vllm_br.envs as biren_envs
from vllm.logger import init_logger
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm_br.v1.worker.model_runner import SUPACommonAttentionMetadata
logger = init_logger(__name__)
PADDING_SLOT_ID = -1
def wrapper_EagleProposer_init(fn):
    """Decorate EagleProposer.__init__ to flag MLA support on the draft model.

    After the original constructor runs, the draft model config is stamped
    with the Biren weight type and the DeepSeek-MLA flags so downstream
    attention backends pick the MLA path.
    """

    # FIXME: temporary fix for enabling MLA in EagleProposer
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        cfg = self.draft_model_config
        cfg.weight_type = biren_envs.VLLM_BR_WEIGHT_TYPE
        cfg.use_ds_mla = True
        # Sparse MLA is enabled only when the HF config exposes `index_topk`.
        cfg.use_ds_mla_sparse = hasattr(cfg.hf_config, "index_topk")

    return wrapper


EagleProposer.__init__ = wrapper_EagleProposer_init(
    EagleProposer.__init__)  # noqa: E501
@patch_to(EagleProposer)
def prepare_inputs(
    self,
    common_attn_metadata: SUPACommonAttentionMetadata,
    sampled_token_ids: list[list[int]],
    num_draft_tokens: list[int],
) -> tuple[SUPACommonAttentionMetadata, torch.Tensor]:
    """
    Prepare the inputs for speculative decoding.

    Updates ``common_attn_metadata`` to account for the rejected draft
    tokens (and newly sampled tokens) and returns the token indices of
    the target tokens that should be fed to the speculator.

    E.g.
    common_attn_metadata.query_start_loc{_cpu}:
        [0, q1, q1 + q2, q1 + q2 + q3]
    common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
    num_rejected_tokens: [n1, n2, n3]
    This function computes the intermediate values:
        num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
    And returns:
    common_attn_metadata.query_start_loc{_cpu}:
        [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
    common_attn_metadata.seq_lens{_cpu}:
        [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
    token_indices: [0, 1, ..., q1 - n1 - 1,
                    q1, q1 + 1, ..., q1 + q2 - n2 - 1,
                    q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
    """
    # A request with draft tokens rejects (num_draft + 1) - num_accepted
    # tokens (the sampled list holds accepted tokens plus the bonus token);
    # requests without drafts reject nothing.
    num_rejected_tokens = [
        n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
        for i, n in enumerate(num_draft_tokens)
    ]
    num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)

    device = common_attn_metadata.query_start_loc.device
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \
        - num_rejected_tokens

    # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
    # NOTE: these are the ORIGINAL (pre-rejection) per-request query lens.
    query_len_per_req = (query_start_loc_cpu[1:] -
                         query_start_loc_cpu[:-1])
    # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
    new_num_tokens_per_req = query_len_per_req - num_rejected_tokens
    new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

    # [q1 - n1, q2 - n2, q3 - n3] ->
    # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
    new_query_start_loc_cpu = torch.zeros(
        query_start_loc_cpu.shape,
        dtype=torch.int32,
        pin_memory=is_pin_memory_available())
    new_query_start_loc_np = new_query_start_loc_cpu.numpy()
    np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
    total_num_tokens = new_query_start_loc_np[-1]

    # Example assuming num_tokens_per_req_np = [2, 4, 3]
    # this implies that `new_query_start_locs` is:
    # [0, 2, 6, 9] ->
    # [0, 0, 2, 2, 2, 2, 6, 6, 6]
    #  _r1_  ____r2____  ___r3__
    new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
                                              new_num_tokens_per_req_np)
    # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
    # [0, 1, 0, 1, 2, 3, 0, 1, 2]
    #  _r1_  ____r2____  ___r3__
    token_offsets = self.token_arange_np[:total_num_tokens] \
        - new_query_start_locs_expanded

    # Expand starting positions to match token pattern
    # [0, q1, q1 + q2] ->
    # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
    #  _r1_  _____r2_______  ___________r3____________
    old_query_start_locs_expanded = np.repeat(query_start_loc_cpu[:-1].numpy(),
                                              new_num_tokens_per_req_np)
    # Final token indices are:
    # [0, 1,                                  // req 1
    #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,        // req 2
    #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
    token_indices_np = token_offsets + old_query_start_locs_expanded
    token_indices = torch.from_numpy(token_indices_np).to(device,
                                                          non_blocking=True)

    spec_common_attn_metadata = SUPACommonAttentionMetadata(
        query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
        seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
        query_start_loc_cpu=new_query_start_loc_cpu,
        seq_lens_cpu=new_seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=common_attn_metadata.num_reqs,
        num_actual_tokens=total_num_tokens,
        # BUGFIX: use the post-rejection query lengths here. The previous
        # code took the max of the pre-rejection lengths, overstating
        # max_query_len relative to the query_start_loc emitted above.
        max_query_len=new_num_tokens_per_req.max().item(),
        max_seq_len=new_seq_lens_cpu.max().item(),
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping[token_indices],
        causal=True,
    )
    return spec_common_attn_metadata, token_indices
@patch_to(EagleProposer)
def prepare_inputs_padded(self,
                          common_attn_metadata: SUPACommonAttentionMetadata,
                          spec_decode_metadata: SpecDecodeMetadata,
                          valid_sampled_tokens_count: torch.Tensor) -> \
        tuple[SUPACommonAttentionMetadata, torch.Tensor, torch.Tensor]:
    """
    Prepare the inputs for speculative decoding.

    Updates the common_attn_metadata for speculative decoding,
    but does not consider the rejected tokens. Instead, all tokens
    are included as inputs to the speculator, with the rejected tokens
    used as padding and filtered out later by `token_indices_to_sample`.
    No blocking CPU operations should be introduced in this function.
    """
    # cu_num_draft_tokens is a cumulative sum; diff it (on device) to
    # recover the per-request draft-token counts.
    num_draft_tokens_gpu = torch.cat([
        spec_decode_metadata.cu_num_draft_tokens[0:1],
        spec_decode_metadata.cu_num_draft_tokens[1:] -
        spec_decode_metadata.cu_num_draft_tokens[:-1]
    ])
    # A request with draft tokens rejects
    # (num_draft + 1) - num_valid_sampled tokens; others reject none.
    num_rejected_tokens_gpu = torch.where(
        num_draft_tokens_gpu > 0,
        num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
        torch.zeros_like(num_draft_tokens_gpu))

    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    query_len_per_req = (query_start_loc_cpu[1:] -
                         query_start_loc_cpu[:-1])
    total_num_tokens = query_start_loc_cpu[-1].item()
    # All tokens are forwarded; rejected positions act as padding.
    token_indices = self.arange[:total_num_tokens]

    # BUGFIX: build seq_start_loc from the already-available host copy
    # (seq_lens_cpu) instead of `seq_lens.cpu()`, which forced a blocking
    # device-to-host copy and violated this function's own contract that
    # no blocking CPU operations are introduced here.
    seq_start_loc = torch.from_numpy(
        np.insert(
            np.add.accumulate(common_attn_metadata.seq_lens_cpu.numpy()), 0,
            0)).to(common_attn_metadata.query_start_loc, non_blocking=True)

    spec_common_attn_metadata = SUPACommonAttentionMetadata(
        query_start_loc=common_attn_metadata.query_start_loc,
        seq_lens=common_attn_metadata.seq_lens,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=common_attn_metadata.num_reqs,
        num_actual_tokens=total_num_tokens,
        max_query_len=query_len_per_req.max().item(),
        max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping[token_indices.long()],
        causal=True,
        seq_start_loc=seq_start_loc)

    # For each request, sample from the last valid (non-rejected)
    # position of its query span; stays on device (no sync).
    token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \
        - num_rejected_tokens_gpu
    return spec_common_attn_metadata, token_indices, token_indices_to_sample
def wrapper_EagleProposer_propose(fn):
    """Decorate EagleProposer.propose to normalize `last_token_indices`.

    Supplies a default (the last token of each request's query span) when
    the caller passes None, and casts the indices to int64 before
    delegating to the wrapped implementation.
    """

    @wraps(fn)
    def wrapper(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: Optional[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embeds: Optional[list[torch.Tensor]] = None,
    ):
        if last_token_indices is None:
            # Default: the final position of each query span.
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
        # Ensure int64 dtype for use as an index tensor.
        last_token_indices = last_token_indices.long()
        return fn(self, target_token_ids, target_positions,
                  target_hidden_states, next_token_ids, last_token_indices,
                  common_attn_metadata, sampling_metadata, mm_embeds)

    return wrapper


EagleProposer.propose = wrapper_EagleProposer_propose(
    EagleProposer.propose)  # noqa: E501