################################################################################ # Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ################################################################################ from functools import wraps # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Optional import numpy as np import torch from fastcore.basics import patch_to import vllm_br.envs as biren_envs from vllm.logger import init_logger from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm_br.v1.worker.model_runner import SUPACommonAttentionMetadata logger = init_logger(__name__) PADDING_SLOT_ID = -1 def wrapper_EagleProposer_init(fn): # FIXME: temporary fix for enabling MLA in EagleProposer @wraps(fn) def wrapper(self, *args, **kwargs): fn(self, *args, **kwargs) self.draft_model_config.weight_type = biren_envs.VLLM_BR_WEIGHT_TYPE self.draft_model_config.use_ds_mla = True self.draft_model_config.use_ds_mla_sparse = hasattr( self.draft_model_config.hf_config, "index_topk") return wrapper EagleProposer.__init__ = wrapper_EagleProposer_init( EagleProposer.__init__) # noqa: E501 @patch_to(EagleProposer) 
def prepare_inputs(
    self,
    common_attn_metadata: SUPACommonAttentionMetadata,
    sampled_token_ids: list[list[int]],
    num_draft_tokens: list[int],
) -> tuple[SUPACommonAttentionMetadata, torch.Tensor]:
    """
    This function is used to prepare the inputs for speculative decoding.
    It updates to the common_attn_metadata to account for the rejected
    tokens (and newly sampled tokens). It also returns the token indices
    of the tokens that should be fed to the speculator.
    """
    # E.g.
    #  common_attn_metadata.query_start_loc{_cpu}:
    #       [0, q1, q1 + q2, q1 + q2 + q3]
    #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
    #  num_rejected_tokens: [n1, n2, n3]
    # This function computes the intermediate values:
    #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
    # And returns:
    #  common_attn_metadata.query_start_loc{_cpu}:
    #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
    #  common_attn_metadata.seq_lens{_cpu}:
    #       [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
    #  token_indices: [0, 1, ..., q1 - n1 - 1,
    #                  q1, q1 + 1, ..., q1 + q2 - n2 - 1,
    #                  q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

    # A request with n > 0 draft tokens had (n + 1 - len(sampled)) of them
    # rejected; requests with no draft tokens reject nothing.
    num_rejected_tokens = [
        n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
        for i, n in enumerate(num_draft_tokens)
    ]
    num_rejected_tokens = torch.tensor(num_rejected_tokens,
                                       dtype=torch.int32)

    device = common_attn_metadata.query_start_loc.device
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    # NOTE(review): this computes [s1 - n1, ...]; the "+ 1" shown in the
    # comment block above is not applied here -- confirm which is intended.
    new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \
        - num_rejected_tokens

    # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
    new_query_len_per_req = (query_start_loc_cpu[1:] -
                             query_start_loc_cpu[:-1])
    # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
    new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
    new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

    # [q1 - n1, q2 - n2, q3 - n3] ->
    # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
    new_query_start_loc_cpu = torch.zeros(
        query_start_loc_cpu.shape,
        dtype=torch.int32,
        pin_memory=is_pin_memory_available())
    # cumsum writes in place into the pinned buffer; entry 0 stays zero.
    new_query_start_loc_np = new_query_start_loc_cpu.numpy()
    np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
    total_num_tokens = new_query_start_loc_np[-1]

    # Example assuming num_tokens_per_req_np = [2, 4, 3]
    # this implies that `new_query_start_locs` is:
    #  [0, 2, 6, 9] ->
    #  [0, 0, 2, 2, 2, 2, 6, 6, 6]
    #   _r1_  ____r2____  ___r3__
    new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
                                              new_num_tokens_per_req_np)
    # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
    # [0, 1, 0, 1, 2, 3, 0, 1, 2]
    #  _r1_  ____r2____  ___r3__
    # NOTE(review): "token_offests" is a typo for "token_offsets"
    # (local name only, harmless).
    token_offests = self.token_arange_np[:total_num_tokens] \
        - new_query_start_locs_expanded
    # Expand starting positions to match token pattern
    # [0, q1, q1 + q2] ->
    # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
    #  _r1_  _____r2_______  ___________r3____________
    old_query_start_locs_expanded = np.repeat(
        query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
    # Final token indices are:
    # [0, 1,                                   // req 1
    #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,         // req 2
    #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]  // req 3
    token_indices_np = token_offests + old_query_start_locs_expanded
    token_indices = torch.from_numpy(token_indices_np).to(
        device, non_blocking=True)

    # seq_start_loc = torch.from_numpy(
    #     np.insert(np.add.accumulate(common_attn_metadata.seq_lens.cpu().numpy()), 0,
    #     0)).to(common_attn_metadata.query_start_loc, non_blocking=True)
    spec_common_attn_metadata = SUPACommonAttentionMetadata(
        query_start_loc=new_query_start_loc_cpu.to(device,
                                                   non_blocking=True),
        seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
        query_start_loc_cpu=new_query_start_loc_cpu,
        seq_lens_cpu=new_seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=common_attn_metadata.num_reqs,
        num_actual_tokens=total_num_tokens,
        # NOTE(review): max over the PRE-rejection per-request query
        # lengths, not new_num_tokens_per_req -- confirm intended.
        max_query_len=new_query_len_per_req.max().item(),
        max_seq_len=new_seq_lens_cpu.max().item(),
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping[token_indices],
        causal=True,
        # seq_start_loc is intentionally disabled here: its computation
        # above is commented out; re-enabling one requires the other.
        # seq_start_loc=seq_start_loc
    )
    return spec_common_attn_metadata, token_indices


@patch_to(EagleProposer)
def prepare_inputs_padded(self,
                          common_attn_metadata: SUPACommonAttentionMetadata,
                          spec_decode_metadata: SpecDecodeMetadata,
                          valid_sampled_tokens_count: torch.Tensor) -> \
        tuple[SUPACommonAttentionMetadata, torch.Tensor, torch.Tensor]:
    """
    This function is used to prepare the inputs for speculative decoding
    It updates the common_attn_metadata for speculative decoding,
    but does not consider the rejected tokens.
    Instead, all tokens are included as inputs to the speculator,
    with the rejected tokens used as padding and filtered out later
    by `token_indices_to_sample`.
    No blocking CPU operations should be introduced in this function.
    """
    # Recover per-request draft-token counts from their cumulative sums:
    # [c1, c2, c3] -> [c1, c2 - c1, c3 - c2].
    num_draft_tokens_gpu = torch.cat([
        spec_decode_metadata.cu_num_draft_tokens[0:1],
        spec_decode_metadata.cu_num_draft_tokens[1:] -
        spec_decode_metadata.cu_num_draft_tokens[:-1]
    ])

    # Same rejection rule as prepare_inputs, evaluated on-device.
    num_rejected_tokens_gpu = torch.where(
        num_draft_tokens_gpu > 0,
        num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
        torch.zeros_like(num_draft_tokens_gpu))

    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

    new_query_len_per_req = (query_start_loc_cpu[1:] -
                             query_start_loc_cpu[:-1])

    # All tokens are kept (rejected ones act as padding).
    total_num_tokens = query_start_loc_cpu[-1].item()
    token_indices = self.arange[:total_num_tokens]

    # Cumulative sequence starts [0, s1, s1 + s2, ...] for the SUPA
    # metadata, built on CPU and copied back asynchronously.
    # NOTE(review): if seq_lens lives on the GPU, the .cpu() below is a
    # synchronizing copy, which conflicts with the docstring's
    # "no blocking CPU operations" requirement -- confirm.
    seq_start_loc = torch.from_numpy(
        np.insert(
            np.add.accumulate(common_attn_metadata.seq_lens.cpu().numpy()),
            0, 0)).to(common_attn_metadata.query_start_loc,
                      non_blocking=True)

    spec_common_attn_metadata = SUPACommonAttentionMetadata(
        query_start_loc=common_attn_metadata.query_start_loc,
        seq_lens=common_attn_metadata.seq_lens,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=common_attn_metadata.num_reqs,
        num_actual_tokens=total_num_tokens,
        max_query_len=new_query_len_per_req.max().item(),
        max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping[token_indices.long()],
        causal=True,
        # context_lens=context_lens,
        # max_decode_seq_len=self.seq_lens.np[:num_reqs].max(),
        seq_start_loc=seq_start_loc)

    # Last accepted token of each request: one before the next request's
    # start, shifted back by the number of rejected tokens.
    token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \
        - num_rejected_tokens_gpu

    return spec_common_attn_metadata, token_indices, token_indices_to_sample


def wrapper_EagleProposer_propose(fn):
    """Wrap ``EagleProposer.propose`` so that ``last_token_indices`` defaults
    to the last token of every request and is coerced to int64 before the
    wrapped method runs.
    """

    @wraps(fn)
    def wrapper(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: Optional[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embeds: Optional[list[torch.Tensor]] = None,
    ):
        if last_token_indices is None:
            # Default: the final token of each request's query span.
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
        # Force int64 so the indices are usable for tensor indexing.
        last_token_indices = last_token_indices.long()
        return fn(
            self,
            # [num_tokens]
            target_token_ids,
            # [num_tokens]
            target_positions,
            # [num_tokens, hidden_size]
            target_hidden_states,
            # [batch_size]
            next_token_ids,
            last_token_indices,
            common_attn_metadata,
            sampling_metadata,
            mm_embeds)

    return wrapper


# Monkey-patch applied once at import time.
EagleProposer.propose = wrapper_EagleProposer_propose(
    EagleProposer.propose)  # noqa: E501