Files
bi_150-vllm/vllm/v1/spec_decode/multi_layer_eagle.py

505 lines
16 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
CommonAttentionMetadata,
)
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata
# Tile sizes for the Triton kernels at the bottom of this file:
# BLOCK_TOKENS tiles a sequence's token window, BLOCK_HIDDEN tiles the
# hidden dimension of the hidden-state tensors.
BLOCK_HIDDEN = 128
BLOCK_TOKENS = 128
class MultiLayerEagleProposer(EagleProposer):
    """Eagle proposer variant that runs several draft layers per step.

    The number of draft layers is taken from the draft model's HF text
    config (``n_predict``); the number of speculative tokens comes from
    the speculative config.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        super().__init__(vllm_config, device, runner)
        draft_hf_config = (
            self.speculative_config.draft_model_config.hf_text_config
        )
        # Draft layers per proposal step; 0 when the config omits it.
        self.layer_num: int = getattr(draft_hf_config, "n_predict", 0)
        self.num_speculative_tokens: int = (
            self.speculative_config.num_speculative_tokens
        )
    def adjust_input(
        self,
        batch_size: int,
        target_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
        """Shift each request's token window and refill from the layer cache.

        Computes a per-request ``shift`` (bounded by the distance from the
        sampled token to the window end and by the request's start
        position), then launches the Triton shift-and-cache kernels to
        produce shifted copies of the token ids, positions, hidden states
        and slot mappings.

        Mutates in place: ``token_indices_to_sample`` (advanced by
        ``shift``) and ``common_attn_metadata.seq_lens`` (reduced by
        ``shift``); the per-request caches inside
        ``multi_layer_eagle_metadata`` are rewritten by the kernels.

        Returns:
            (prev_token_ids, prev_positions, prev_hidden_states,
            common_attn_metadata) — the first three are shifted clones of
            the ``target_*`` inputs.
        """
        assert multi_layer_eagle_metadata is not None
        # Default sampling index: last token of each request's window.
        if token_indices_to_sample is None:
            token_indices_to_sample = (
                common_attn_metadata.query_start_loc[1:] - 1
            )
        # At most one cached token per draft layer can be shifted in.
        MAX_SHIFT = self.layer_num
        assert MAX_SHIFT > 0
        # Clones act as the kernels' destination buffers; the originals are
        # read as sources.
        prev_token_ids = target_token_ids.clone()
        prev_positions = target_positions.clone()
        prev_hidden_states = target_hidden_states.clone()
        slot_mapping = common_attn_metadata.slot_mapping
        # Inclusive [start, end] token indices of each request's window.
        start_token_indices = common_attn_metadata.query_start_loc[:-1]
        end_token_indices = common_attn_metadata.query_start_loc[1:] - 1
        # Positions may be 2-D (assumes row 0 carries the linear position,
        # e.g. for multi-dim RoPE layouts — TODO confirm against callers).
        pos_for_shift = (
            target_positions[0]
            if target_positions.dim() == 2
            else target_positions
        )
        start_token_pos = pos_for_shift[start_token_indices]
        # Shift is limited by how far the sampled token sits from the window
        # end and by the absolute start position (cannot shift before pos 0).
        shift = torch.minimum(
            end_token_indices - token_indices_to_sample,
            start_token_pos,
        )
        shift = torch.clamp(shift, min=0)
        token_indices_to_sample.add_(shift)
        common_attn_metadata.seq_lens.sub_(shift)
        cached_lens = multi_layer_eagle_metadata.cached_len
        # NOTE(review): the sampling indices / seq_lens above were adjusted
        # with the unclamped shift, while the kernels below use the shift
        # clamped to the available cache length — looks intentional
        # (cache may be short on early steps) but worth confirming.
        shift = torch.minimum(shift, cached_lens)
        _multi_layer_eagle_shift_and_cache(
            batch_size=batch_size,
            max_shift=MAX_SHIFT,
            src_token_ids=target_token_ids,
            dst_token_ids=prev_token_ids,
            src_positions=target_positions,
            dst_positions=prev_positions,
            src_hidden_states=target_hidden_states,
            dst_hidden_states=prev_hidden_states,
            # src and dst alias on purpose; the helper clones the source.
            src_slot_mapping=slot_mapping,
            dst_slot_mapping=slot_mapping,
            start_token_indices=start_token_indices,
            end_token_indices=end_token_indices,
            token_indices_to_sample=token_indices_to_sample,
            shift=shift,
            cached_lens=cached_lens,
            cached_prev_token_ids=(
                multi_layer_eagle_metadata.cached_token_ids
            ),
            cached_prev_positions=(
                multi_layer_eagle_metadata.cached_positions
            ),
            cached_prev_hidden_states=(
                multi_layer_eagle_metadata.cached_hidden_states
            ),
            cached_slot_mappings=(
                multi_layer_eagle_metadata.cached_slot_mappings
            ),
            common_attn_metadata=common_attn_metadata,
        )
        return (
            prev_token_ids,
            prev_positions,
            prev_hidden_states,
            common_attn_metadata,
        )
def prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: list[list[int]],
num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
raise Exception(
"speculative_config.disable_padded_drafter_batch"
" is not supported now for MultiLayerEagleProposer."
)
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,  # unused in this body; kept for interface parity — TODO confirm base class contract
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Run ``layer_num`` draft forward passes on dummy inputs.

        Builds a single-request batch of ``num_tokens`` placeholder tokens,
        runs ``adjust_input`` once, then executes the draft model
        ``layer_num`` times — presumably for warmup / CUDA-graph capture
        (the cudagraph runtime mode is resolved below).
        """
        # Resolve padded token count and cudagraph execution mode for this
        # batch size.
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(
                num_tokens, use_cudagraphs=use_cudagraphs
            )
        )
        # If the caller-provided slot mappings already cover the draft
        # attention layers, rebuild them sized to the padded token count;
        # otherwise use whatever was passed (or an empty dict).
        if (
            self._draft_attn_layer_names
            and slot_mappings is not None
            and next(iter(self._draft_attn_layer_names)) in slot_mappings
        ):
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}
        # One dummy request spanning [0, num_input_tokens); sample at the
        # last token.
        adjust_input_kwargs = {
            "batch_size": 1,
            "target_token_ids": self.input_ids[:num_input_tokens],
            "target_positions": self._get_positions(num_input_tokens),
            "target_hidden_states": self.hidden_states[:num_input_tokens],
            "token_indices_to_sample": torch.tensor(
                [num_input_tokens - 1],
                dtype=torch.int32,
                device=self.device,
            ),
            "common_attn_metadata": CommonAttentionMetadata(
                query_start_loc=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                query_start_loc_cpu=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device="cpu",
                ),
                key_start_loc=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                seq_lens=torch.tensor(
                    [num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                seq_lens_np=np.array([num_input_tokens], dtype=np.int32),
                num_reqs=1,
                num_actual_tokens=num_input_tokens,
                max_query_len=self.num_speculative_tokens + 1,
                max_seq_len=self.max_model_len,
                # Empty block table: no real KV pages for a dummy run.
                block_table_tensor=torch.tensor(
                    [], dtype=torch.int32, device=self.device
                ),
                # Identity slot mapping over the dummy tokens.
                slot_mapping=self.arange[:num_input_tokens],
                logits_indices_padded=None,
                num_logits_indices=None,
                causal=True,
                encoder_seq_lens=None,
            ),
            "multi_layer_eagle_metadata": MultiLayerEagleMetadata.make_dummy(
                layer_num=self.layer_num,
                hidden_size=self.hidden_size,
                device=self.device,
            ),
        }
        self.adjust_input(**adjust_input_kwargs)
        # One forward pass per draft layer; spec_step_idx selects the layer.
        for fwd_idx in range(self.layer_num):
            with set_forward_context(
                None,
                self.draft_vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=slot_mapping_dict,
            ):
                # Multimodal models consume embeddings; text models token ids.
                if self.supports_mm_inputs:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]
                    inputs_embeds = None
                model_kwargs = {
                    "input_ids": input_ids,
                    "positions": self._get_positions(num_input_tokens),
                    "hidden_states": self.hidden_states[:num_input_tokens],
                    "inputs_embeds": inputs_embeds,
                    "spec_step_idx": fwd_idx,
                }
                self.model(**model_kwargs)
def _multi_layer_eagle_shift_and_cache(
    *,
    batch_size: int,
    max_shift: int,
    src_token_ids: torch.Tensor,
    dst_token_ids: torch.Tensor,
    src_positions: torch.Tensor,
    dst_positions: torch.Tensor,
    src_hidden_states: torch.Tensor,
    dst_hidden_states: torch.Tensor,
    src_slot_mapping: torch.Tensor,
    dst_slot_mapping: torch.Tensor,
    start_token_indices: torch.Tensor,
    end_token_indices: torch.Tensor,
    token_indices_to_sample: torch.Tensor,
    shift: torch.Tensor,
    cached_lens: torch.Tensor,
    cached_prev_token_ids: torch.Tensor,
    cached_prev_positions: torch.Tensor,
    cached_prev_hidden_states: torch.Tensor,
    cached_slot_mappings: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
):
    """Launch the per-sequence shift-and-cache Triton kernels.

    For every sequence the token ids, slot mappings, positions and hidden
    states are shifted right by ``shift`` within the sequence window, the
    freed prefix is refilled from the per-request caches, and the window
    tail ending at ``token_indices_to_sample`` is written back into those
    caches. ``cached_lens`` is updated in place with the new cache lengths.
    """
    if batch_size == 0:
        return
    assert max_shift > 0
    # The kernels address the caches as contiguous
    # [batch, MAX_SHIFT(, hidden)] buffers, so contiguity is required.
    assert cached_prev_positions.is_contiguous()
    assert cached_prev_token_ids.is_contiguous()
    assert cached_prev_hidden_states.is_contiguous()
    assert cached_slot_mappings.is_contiguous()
    assert src_hidden_states.is_contiguous()
    assert dst_hidden_states.is_contiguous()
    # The caller may pass the same tensor as source and destination for the
    # slot mapping; snapshot the source so kernel writes cannot clobber the
    # values still to be read.
    if src_slot_mapping.data_ptr() == dst_slot_mapping.data_ptr():
        src_slot_mapping = src_slot_mapping.clone()
    # Cache window to persist: the last (at most) max_shift tokens up to and
    # including the sampled token, clipped to the window start.
    store_start = torch.maximum(
        start_token_indices,
        (token_indices_to_sample + 1 - max_shift),
    )
    store_lens = torch.clamp(
        token_indices_to_sample - store_start + 1,
        min=0,
        max=max_shift,
    )
    qsl_cpu = common_attn_metadata.query_start_loc_cpu
    max_window_len = int((qsl_cpu[1:] - qsl_cpu[:-1]).max().item())
    # At least one token tile even for empty windows.
    num_blocks = max(1, (max_window_len + BLOCK_TOKENS - 1) // BLOCK_TOKENS)
    grid_1d = (batch_size, num_blocks)
    padded_shift = triton.next_power_of_2(max_shift)
    # One identical launch per flat array, in order: token ids, slot
    # mappings, positions.
    for src_buf, dst_buf, cache_buf in (
        (src_token_ids, dst_token_ids, cached_prev_token_ids),
        (src_slot_mapping, dst_slot_mapping, cached_slot_mappings),
        (src_positions, dst_positions, cached_prev_positions),
    ):
        _shift_and_gather_cache_1d_kernel[grid_1d](
            src_buf,
            dst_buf,
            cache_buf,
            start_token_indices,
            end_token_indices,
            shift,
            cached_lens,
            store_start,
            store_lens,
            MAX_SHIFT=max_shift,
            PADDED_SHIFT=padded_shift,
            BLOCK_TOKENS=BLOCK_TOKENS,
        )
    hidden_size = int(dst_hidden_states.shape[1])
    num_hidden_blocks = max(
        1, (hidden_size + BLOCK_HIDDEN - 1) // BLOCK_HIDDEN
    )
    _shift_and_gather_hidden_kernel[
        (batch_size, num_blocks, num_hidden_blocks)
    ](
        src_hidden_states,
        dst_hidden_states,
        cached_prev_hidden_states,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=padded_shift,
        HIDDEN_SIZE=hidden_size,
        BLOCK_TOKENS=BLOCK_TOKENS,
        BLOCK_HIDDEN=BLOCK_HIDDEN,
        num_warps=4,
    )
    # The store-back phase wrote store_lens entries per sequence; record the
    # new cache lengths.
    cached_lens.copy_(store_lens)
@triton.jit
def _shift_and_gather_cache_1d_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
):
    # Per-sequence "shift + gather" for packed 1D arrays (token ids, positions,
    # slot mappings, ...).
    #
    # For a single sequence (0-based index i within its window):
    # - Prefix (i < shift):
    #     dst[start + i] = cached[cached_len - shift + i]
    # - Body (i >= shift):
    #     dst[start + i] = src[start + i - shift]
    #
    # Grid: (num_seqs, num_token_tiles); each program handles one
    # BLOCK_TOKENS-wide tile of one sequence's window.
    pid_seq = tl.program_id(0)
    pid_blk = tl.program_id(1)
    start = tl.load(start_ptr + pid_seq).to(tl.int32)
    end = tl.load(end_ptr + pid_seq).to(tl.int32)
    shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
    cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)
    # Sanity check: the cache must hold at least `shift` entries to refill
    # the prefix.
    assert cached_len >= shift
    base = pid_blk * BLOCK_TOKENS
    k = tl.arange(0, BLOCK_TOKENS)
    offs = base + k
    dst_idx = start + offs
    # `end` is inclusive (query_start_loc[1:] - 1), hence the +1.
    window_len = end - start + 1
    mask = offs < window_len
    # cached_ptr layout: [num_seqs, MAX_SHIFT], row-major.
    base_cached = cached_ptr + pid_seq * MAX_SHIFT
    # Prefix refill reads the *newest* `shift` cache entries.
    cached_idx = cached_len - shift + offs
    cached_mask = offs < shift
    val_cached = tl.load(
        base_cached + cached_idx, mask=mask & cached_mask, other=0
    )
    src_idx = start + offs - shift
    val_src = tl.load(src_ptr + src_idx, mask=mask & ~cached_mask, other=0)
    val = tl.where(cached_mask, val_cached, val_src)
    tl.store(dst_ptr + dst_idx, val, mask=mask)
    # Store-back: persist dst[store_start : store_start + store_len] into
    # the cache for the next call. All MAX_SHIFT cache slots are written;
    # slots beyond store_len are zero-filled via the load mask / other=0.
    #
    # NOTE(review): this phase runs in *every* token-tile program and reads
    # dst values that a different concurrently-running tile may still be
    # writing above; with windows longer than BLOCK_TOKENS this looks racy —
    # confirm whether an ordering guarantee or a pid_blk guard exists
    # upstream.
    store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
    store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
    m = tl.arange(0, PADDED_SHIFT)
    store_mask = m < MAX_SHIFT
    dst_idx = store_start + m
    val = tl.load(
        dst_ptr + dst_idx, mask=store_mask & (m < store_len), other=0
    )
    tl.store(base_cached + m, val, mask=store_mask)
@triton.jit
def _shift_and_gather_hidden_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_HIDDEN: tl.constexpr,
):
    # Per-sequence "shift + gather" for hidden states.
    # Layout:
    # - src_ptr / dst_ptr: [num_tokens, hidden_size]
    # - cached_ptr: [batch_size, MAX_SHIFT, hidden_size]
    #
    # Grid: (num_seqs, num_token_tiles, num_hidden_tiles); each program
    # covers a BLOCK_TOKENS x BLOCK_HIDDEN tile of one sequence's window.
    pid_seq = tl.program_id(0)
    pid_blk = tl.program_id(1)
    pid_hid = tl.program_id(2)
    start = tl.load(start_ptr + pid_seq).to(tl.int32)
    end = tl.load(end_ptr + pid_seq).to(tl.int32)
    shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
    cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)
    # Sanity check: cache must hold at least `shift` rows for the refill.
    assert cached_len >= shift
    base = pid_blk * BLOCK_TOKENS
    k = tl.arange(0, BLOCK_TOKENS)
    tok_offs = base + k
    dst_tok = start + tok_offs
    # Hidden-dimension offsets for this program's tile.
    n = pid_hid * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)
    dst_ptrs = dst_ptr + dst_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    # `end` is inclusive, hence the +1.
    window_len = end - start + 1
    tok_mask = tok_offs < window_len
    n_mask = n < HIDDEN_SIZE
    mask = tok_mask[:, None] & n_mask[None, :]
    base_cached = cached_ptr + pid_seq * HIDDEN_SIZE * MAX_SHIFT
    # Prefix rows (tok_offs < shift) come from the newest `shift` cache rows.
    cached_tok = cached_len - shift + tok_offs
    cached_ptrs = (
        base_cached + cached_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    )
    cached_mask = tok_offs < shift
    val_cached = tl.load(
        cached_ptrs, mask=mask & cached_mask[:, None], other=0
    )
    # Body rows come from src shifted left by `shift`.
    src_tok = start + tok_offs - shift
    src_ptrs = src_ptr + src_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    val_src = tl.load(src_ptrs, mask=mask & ~cached_mask[:, None], other=0)
    val = tl.where(cached_mask[:, None], val_cached, val_src)
    tl.store(dst_ptrs, val, mask=mask)
    # Store-back: persist the dst rows [store_start, store_start+store_len)
    # (this program's hidden slice only) into the cache for the next call.
    #
    # NOTE(review): like the 1-D kernel, this runs in every token-tile
    # program and reads dst rows that another tile may still be writing;
    # looks racy for windows longer than BLOCK_TOKENS — confirm upstream.
    store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
    store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
    m = tl.arange(0, PADDED_SHIFT)
    m_mask = (m < MAX_SHIFT) & (m < store_len)
    store_tok = store_start + m
    dst_ptrs = dst_ptr + store_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    store_ptrs = (
        base_cached + m[:, None] * HIDDEN_SIZE + n[None, :] * 1
    )
    mask = m_mask[:, None] & n_mask[None, :]
    val = tl.load(dst_ptrs, mask=mask, other=0)
    tl.store(store_ptrs, val, mask=mask)