# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Multi-layer EAGLE speculative-decoding proposer.

Extends :class:`EagleProposer` for draft models with several stacked
prediction layers (``n_predict`` in the HF text config).  Before drafting,
each sequence's packed input window is shifted right by a small per-sequence
amount and the head of the window is refilled from values cached on the
previous round; the Triton kernels at the bottom of this file implement that
shift-and-cache step for token ids, positions, slot mappings, and hidden
states.
"""

from typing import Any

import numpy as np
import torch

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
    CommonAttentionMetadata,
)
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata

# Tile sizes for the Triton kernels below: BLOCK_TOKENS programs tile the
# sequence (token) axis, BLOCK_HIDDEN tiles the hidden-dimension axis.
BLOCK_HIDDEN = 128
BLOCK_TOKENS = 128


class MultiLayerEagleProposer(EagleProposer):
    """EAGLE proposer variant for multi-layer (MTP-style) draft models.

    ``dummy_run`` issues one draft forward per layer, passing the layer
    index as ``spec_step_idx``; ``adjust_input`` prepares the right-shifted
    "previous" token/position/hidden-state views the draft layers consume.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        super().__init__(vllm_config, device, runner)
        # Number of draft prediction layers, taken from the draft model's
        # HF text config (``n_predict``); defaults to 0 when the attribute
        # is absent (adjust_input asserts it is > 0 before use).
        self.layer_num: int = getattr(
            self.speculative_config.draft_model_config.hf_text_config,
            "n_predict",
            0,
        )
        self.num_speculative_tokens: int = (
            self.speculative_config.num_speculative_tokens
        )

    def adjust_input(
        self,
        batch_size: int,
        target_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
        """Build right-shifted ``prev_*`` inputs for the draft layers.

        For each sequence window ``[start, end]`` in the packed batch, the
        output window is the source window shifted right by a per-sequence
        ``shift`` (at most ``self.layer_num``), with the head refilled from
        the cache written on the previous call.  ``token_indices_to_sample``
        and ``common_attn_metadata.seq_lens`` are adjusted in place to match.

        Returns ``(prev_token_ids, prev_positions, prev_hidden_states,
        common_attn_metadata)``; the metadata object is the (mutated) input.
        """
        assert multi_layer_eagle_metadata is not None
        if token_indices_to_sample is None:
            # Default: sample at the last token of each sequence window.
            token_indices_to_sample = (
                common_attn_metadata.query_start_loc[1:] - 1
            )
        MAX_SHIFT = self.layer_num
        assert MAX_SHIFT > 0
        # Outputs start as copies of the targets; the kernels overwrite each
        # sequence window with the shifted view.
        prev_token_ids = target_token_ids.clone()
        prev_positions = target_positions.clone()
        prev_hidden_states = target_hidden_states.clone()
        slot_mapping = common_attn_metadata.slot_mapping
        # Per-sequence window boundaries (inclusive end) in the packed batch.
        start_token_indices = common_attn_metadata.query_start_loc[:-1]
        end_token_indices = common_attn_metadata.query_start_loc[1:] - 1
        # Positions may be 2-D (row 0 holds the flat positions, e.g. for
        # M-RoPE-style layouts — TODO confirm with callers) or flat 1-D;
        # use a flat view for the shift computation.
        pos_for_shift = (
            target_positions[0]
            if target_positions.dim() == 2
            else target_positions
        )
        start_token_pos = pos_for_shift[start_token_indices]
        # Shift by the number of tokens after the sample index, but never by
        # more than the window's starting position (no tokens before pos 0).
        shift = torch.minimum(
            end_token_indices - token_indices_to_sample,
            start_token_pos,
        )
        shift = torch.clamp(shift, min=0)
        # In-place metadata adjustment to account for the shift.
        token_indices_to_sample.add_(shift)
        common_attn_metadata.seq_lens.sub_(shift)
        cached_lens = multi_layer_eagle_metadata.cached_len
        # The kernels can only refill as many head tokens as were cached on
        # the previous call; the kernels assert cached_len >= shift.
        # NOTE(review): token_indices_to_sample/seq_lens above were adjusted
        # with the *pre-clamp* shift — presumably cached_lens >= shift always
        # holds by this point; confirm against the callers.
        shift = torch.minimum(shift, cached_lens)
        _multi_layer_eagle_shift_and_cache(
            batch_size=batch_size,
            max_shift=MAX_SHIFT,
            src_token_ids=target_token_ids,
            dst_token_ids=prev_token_ids,
            src_positions=target_positions,
            dst_positions=prev_positions,
            src_hidden_states=target_hidden_states,
            dst_hidden_states=prev_hidden_states,
            # src and dst alias the same tensor here; the helper clones the
            # source when it detects that, so the shift is safe in place.
            src_slot_mapping=slot_mapping,
            dst_slot_mapping=slot_mapping,
            start_token_indices=start_token_indices,
            end_token_indices=end_token_indices,
            token_indices_to_sample=token_indices_to_sample,
            shift=shift,
            cached_lens=cached_lens,
            cached_prev_token_ids=(
                multi_layer_eagle_metadata.cached_token_ids
            ),
            cached_prev_positions=(
                multi_layer_eagle_metadata.cached_positions
            ),
            cached_prev_hidden_states=(
                multi_layer_eagle_metadata.cached_hidden_states
            ),
            cached_slot_mappings=(
                multi_layer_eagle_metadata.cached_slot_mappings
            ),
            common_attn_metadata=common_attn_metadata,
        )
        return (
            prev_token_ids,
            prev_positions,
            prev_hidden_states,
            common_attn_metadata,
        )

    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """Unconditionally unsupported for this proposer.

        The non-padded drafter-batch path (enabled via
        ``speculative_config.disable_padded_drafter_batch``) is not
        implemented for the multi-layer proposer.
        """
        raise Exception(
            "speculative_config.disable_padded_drafter_batch"
            " is not supported now for MultiLayerEagleProposer."
        )

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Warmup/CUDA-graph-capture pass: one draft forward per layer.

        Builds a single-sequence dummy batch of ``num_input_tokens`` tokens
        (after padding chosen by ``_determine_batch_execution_and_padding``),
        runs ``adjust_input`` on it, then executes the draft model once per
        prediction layer with ``spec_step_idx`` set to the layer index.
        """
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(
                num_tokens, use_cudagraphs=use_cudagraphs
            )
        )
        # NOTE(review): when the provided slot_mappings already cover the
        # draft attention layers, a fresh mapping sized to num_input_tokens
        # is built instead of using them; otherwise the provided dict (or an
        # empty one) is passed through. Confirm this branch is not inverted.
        if (
            self._draft_attn_layer_names
            and slot_mappings is not None
            and next(iter(self._draft_attn_layer_names)) in slot_mappings
        ):
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}
        adjust_input_kwargs = {
            "batch_size": 1,
            "target_token_ids": self.input_ids[:num_input_tokens],
            "target_positions": self._get_positions(num_input_tokens),
            "target_hidden_states": self.hidden_states[:num_input_tokens],
            # Sample at the last token of the single dummy sequence.
            "token_indices_to_sample": torch.tensor(
                [num_input_tokens - 1],
                dtype=torch.int32,
                device=self.device,
            ),
            # Minimal metadata describing one sequence spanning all tokens.
            "common_attn_metadata": CommonAttentionMetadata(
                query_start_loc=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                query_start_loc_cpu=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device="cpu",
                ),
                key_start_loc=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                seq_lens=torch.tensor(
                    [num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                seq_lens_np=np.array([num_input_tokens], dtype=np.int32),
                num_reqs=1,
                num_actual_tokens=num_input_tokens,
                max_query_len=self.num_speculative_tokens + 1,
                max_seq_len=self.max_model_len,
                # Empty block table: no real KV-cache pages in a dummy run.
                block_table_tensor=torch.tensor(
                    [], dtype=torch.int32, device=self.device
                ),
                slot_mapping=self.arange[:num_input_tokens],
                logits_indices_padded=None,
                num_logits_indices=None,
                causal=True,
                encoder_seq_lens=None,
            ),
            "multi_layer_eagle_metadata": MultiLayerEagleMetadata.make_dummy(
                layer_num=self.layer_num,
                hidden_size=self.hidden_size,
                device=self.device,
            ),
        }
        self.adjust_input(**adjust_input_kwargs)
        # One draft forward per prediction layer; spec_step_idx tells the
        # model which layer to run.
        for fwd_idx in range(self.layer_num):
            with set_forward_context(
                None,
                self.draft_vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=slot_mapping_dict,
            ):
                if self.supports_mm_inputs:
                    # Multimodal path feeds embeddings instead of token ids.
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]
                    inputs_embeds = None
                model_kwargs = {
                    "input_ids": input_ids,
                    "positions": self._get_positions(num_input_tokens),
                    "hidden_states": self.hidden_states[:num_input_tokens],
                    "inputs_embeds": inputs_embeds,
                    "spec_step_idx": fwd_idx,
                }
                self.model(**model_kwargs)


def _multi_layer_eagle_shift_and_cache(
    *,
    batch_size: int,
    max_shift: int,
    src_token_ids: torch.Tensor,
    dst_token_ids: torch.Tensor,
    src_positions: torch.Tensor,
    dst_positions: torch.Tensor,
    src_hidden_states: torch.Tensor,
    dst_hidden_states: torch.Tensor,
    src_slot_mapping: torch.Tensor,
    dst_slot_mapping: torch.Tensor,
    start_token_indices: torch.Tensor,
    end_token_indices: torch.Tensor,
    token_indices_to_sample: torch.Tensor,
    shift: torch.Tensor,
    cached_lens: torch.Tensor,
    cached_prev_token_ids: torch.Tensor,
    cached_prev_positions: torch.Tensor,
    cached_prev_hidden_states: torch.Tensor,
    cached_slot_mappings: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
) -> None:
    """Launch the shift-and-cache kernels for all per-token arrays.

    Each sequence window in the ``dst_*`` arrays becomes its ``src_*``
    window shifted right by ``shift[i]``, head-filled from the per-sequence
    caches.  Afterwards the up-to-``max_shift`` tokens ending at
    ``token_indices_to_sample[i]`` are written back into the caches, and
    ``cached_lens`` is updated in place to the new cached lengths.
    """
    if batch_size == 0:
        return
    assert max_shift > 0
    # The kernels index the caches and hidden states with flat strides, so
    # contiguous layouts are required.
    assert cached_prev_positions.is_contiguous()
    assert cached_prev_token_ids.is_contiguous()
    assert cached_prev_hidden_states.is_contiguous()
    assert cached_slot_mappings.is_contiguous()
    assert src_hidden_states.is_contiguous()
    assert dst_hidden_states.is_contiguous()
    # The caller passes the same tensor for src and dst (in-place shift of
    # the slot mapping); clone the source so kernel reads cannot race with
    # the shifted writes.
    if src_slot_mapping.data_ptr() == dst_slot_mapping.data_ptr():
        src_slot_mapping = src_slot_mapping.clone()
    # Cache window: the last (up to max_shift) tokens ending at the sample
    # index, clipped to the sequence's own window start.
    store_start = torch.maximum(
        start_token_indices,
        (token_indices_to_sample + 1 - max_shift),
    )
    store_lens = torch.clamp(
        token_indices_to_sample - store_start + 1,
        min=0,
        max=max_shift,
    )
    # The longest sequence window decides how many BLOCK_TOKENS tiles each
    # kernel launch needs along the token axis.
    max_window_len = int(
        (
            common_attn_metadata.query_start_loc_cpu[1:]
            - common_attn_metadata.query_start_loc_cpu[:-1]
        )
        .max()
        .item()
    )
    num_blocks = max(1, (max_window_len + BLOCK_TOKENS - 1) // BLOCK_TOKENS)
    # One 1-D launch per packed array: token ids, slot mappings, positions.
    _shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
        src_token_ids,
        dst_token_ids,
        cached_prev_token_ids,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=triton.next_power_of_2(max_shift),
        BLOCK_TOKENS=BLOCK_TOKENS,
    )
    _shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
        src_slot_mapping,
        dst_slot_mapping,
        cached_slot_mappings,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=triton.next_power_of_2(max_shift),
        BLOCK_TOKENS=BLOCK_TOKENS,
    )
    _shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
        src_positions,
        dst_positions,
        cached_prev_positions,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=triton.next_power_of_2(max_shift),
        BLOCK_TOKENS=BLOCK_TOKENS,
    )
    hidden_size = int(dst_hidden_states.shape[1])
    num_hidden_blocks = max(
        1, (hidden_size + BLOCK_HIDDEN - 1) // BLOCK_HIDDEN
    )
    # Hidden states add a third grid axis tiling the hidden dimension.
    _shift_and_gather_hidden_kernel[
        (batch_size, num_blocks, num_hidden_blocks)
    ](
        src_hidden_states,
        dst_hidden_states,
        cached_prev_hidden_states,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=triton.next_power_of_2(max_shift),
        HIDDEN_SIZE=hidden_size,
        BLOCK_TOKENS=BLOCK_TOKENS,
        BLOCK_HIDDEN=BLOCK_HIDDEN,
        num_warps=4,
    )
    # The kernels read the *old* cached_lens on device; only advance it on
    # the host after all launches have been queued.
    cached_lens.copy_(store_lens)
    return


@triton.jit
def _shift_and_gather_cache_1d_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
):
    # Per-sequence "shift + gather" for packed 1D arrays (token ids,
    # positions, slot mappings, ...).
    #
    # For a single sequence (0-based index i within its window):
    # - Prefix (i < shift):
    #     dst[start + i] = cached[cached_len - shift + i]
    # - Body (i >= shift):
    #     dst[start + i] = src[start + i - shift]
    #
    # Grid: (batch_size, num_token_blocks); each program handles one
    # BLOCK_TOKENS tile of one sequence window.
    pid_seq = tl.program_id(0)
    pid_blk = tl.program_id(1)
    start = tl.load(start_ptr + pid_seq).to(tl.int32)
    end = tl.load(end_ptr + pid_seq).to(tl.int32)
    shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
    cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)
    # Guaranteed by the host-side shift = min(shift, cached_lens).
    assert cached_len >= shift
    base = pid_blk * BLOCK_TOKENS
    k = tl.arange(0, BLOCK_TOKENS)
    offs = base + k
    dst_idx = start + offs
    window_len = end - start + 1
    mask = offs < window_len
    # Cache layout: [batch_size, MAX_SHIFT], one row per sequence; the last
    # `shift` cached entries feed the head of the output window.
    base_cached = cached_ptr + pid_seq * MAX_SHIFT
    cached_idx = cached_len - shift + offs
    cached_mask = offs < shift
    val_cached = tl.load(
        base_cached + cached_idx, mask=mask & cached_mask, other=0
    )
    src_idx = start + offs - shift
    val_src = tl.load(src_ptr + src_idx, mask=mask & ~cached_mask, other=0)
    val = tl.where(cached_mask, val_cached, val_src)
    tl.store(dst_ptr + dst_idx, val, mask=mask)
    # Refill the cache with the freshly written dst tokens starting at
    # store_start.  Entries beyond store_len are zero-filled; they are never
    # read because the host sets cached_lens = store_lens afterwards.
    # NOTE(review): every token-block program repeats this store, and the
    # tl.load below reads dst values that may be written by a *different*
    # program in the same launch with no synchronization — verify the store
    # window is always covered by this program's tile, or that the repeated
    # identical writes make this benign.
    store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
    store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
    m = tl.arange(0, PADDED_SHIFT)
    store_mask = m < MAX_SHIFT
    dst_idx = store_start + m
    val = tl.load(
        dst_ptr + dst_idx, mask=store_mask & (m < store_len), other=0
    )
    tl.store(base_cached + m, val, mask=store_mask)


@triton.jit
def _shift_and_gather_hidden_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_HIDDEN: tl.constexpr,
):
    # Per-sequence "shift + gather" for hidden states.
    # Layout:
    # - src_ptr / dst_ptr: [num_tokens, hidden_size]
    # - cached_ptr: [batch_size, MAX_SHIFT, hidden_size]
    #
    # Same prefix/body rule as the 1D kernel, applied per hidden element.
    # Grid: (batch_size, num_token_blocks, num_hidden_blocks).
    pid_seq = tl.program_id(0)
    pid_blk = tl.program_id(1)
    pid_hid = tl.program_id(2)
    start = tl.load(start_ptr + pid_seq).to(tl.int32)
    end = tl.load(end_ptr + pid_seq).to(tl.int32)
    shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
    cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)
    # Guaranteed by the host-side shift = min(shift, cached_lens).
    assert cached_len >= shift
    base = pid_blk * BLOCK_TOKENS
    k = tl.arange(0, BLOCK_TOKENS)
    tok_offs = base + k
    dst_tok = start + tok_offs
    # Hidden-dimension offsets handled by this program.
    n = pid_hid * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)
    dst_ptrs = dst_ptr + dst_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    window_len = end - start + 1
    tok_mask = tok_offs < window_len
    n_mask = n < HIDDEN_SIZE
    mask = tok_mask[:, None] & n_mask[None, :]
    base_cached = cached_ptr + pid_seq * HIDDEN_SIZE * MAX_SHIFT
    cached_tok = cached_len - shift + tok_offs
    cached_ptrs = (
        base_cached + cached_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    )
    cached_mask = tok_offs < shift
    val_cached = tl.load(
        cached_ptrs, mask=mask & cached_mask[:, None], other=0
    )
    src_tok = start + tok_offs - shift
    src_ptrs = src_ptr + src_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    val_src = tl.load(src_ptrs, mask=mask & ~cached_mask[:, None], other=0)
    val = tl.where(cached_mask[:, None], val_cached, val_src)
    tl.store(dst_ptrs, val, mask=mask)
    # Cache refill, analogous to the 1D kernel (but masked to store_len on
    # both load and store, so stale cache entries beyond store_len are kept
    # rather than zeroed — harmless since cached_lens is set to store_lens).
    # NOTE(review): same cross-program dst re-read hazard as the 1D kernel;
    # verify synchronization assumptions.
    store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
    store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
    m = tl.arange(0, PADDED_SHIFT)
    m_mask = (m < MAX_SHIFT) & (m < store_len)
    store_tok = store_start + m
    dst_ptrs = dst_ptr + store_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    store_ptrs = (
        base_cached + m[:, None] * HIDDEN_SIZE + n[None, :] * 1
    )
    mask = m_mask[:, None] & n_mask[None, :]
    val = tl.load(dst_ptrs, mask=mask, other=0)
    tl.store(store_ptrs, val, mask=mask)