Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
504
vllm/v1/spec_decode/multi_layer_eagle.py
Normal file
504
vllm/v1/spec_decode/multi_layer_eagle.py
Normal file
@@ -0,0 +1,504 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import (
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata
|
||||
|
||||
# Tile sizes for the Triton shift/gather kernel launches in this module.
# Hidden-state columns handled by one program along the hidden grid axis.
BLOCK_HIDDEN = 128
# Tokens of a request's window handled by one program along the token axis.
BLOCK_TOKENS = 128
|
||||
|
||||
|
||||
class MultiLayerEagleProposer(EagleProposer):
    """Eagle proposer variant driven by a multi-layer draft model.

    ``layer_num`` is read from the draft model's ``hf_text_config.n_predict``.
    It is used as the number of draft forward passes issued by ``dummy_run``
    (one per ``spec_step_idx``) and as the maximum per-request shift / cache
    depth in ``adjust_input``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        super().__init__(vllm_config, device, runner)

        # Falls back to 0 when the draft config has no ``n_predict``;
        # ``adjust_input`` asserts that the value is positive before use.
        self.layer_num: int = getattr(
            self.speculative_config.draft_model_config.hf_text_config,
            "n_predict",
            0,
        )
        self.num_speculative_tokens: int = (
            self.speculative_config.num_speculative_tokens
        )

    def adjust_input(
        self,
        batch_size: int,
        target_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
        """Shift each request's draft inputs and refresh the per-request cache.

        For every request a shift is computed as
        ``clamp(min(last_index - sample_index, first_token_position), 0)``.
        The sample index is advanced by that shift and ``seq_lens`` is reduced
        by it. The shift is then limited to the number of cached entries and
        handed to ``_multi_layer_eagle_shift_and_cache``, which fills the
        returned tensors and updates the cache tensors.

        Side effects (all in place):
          * ``token_indices_to_sample`` is advanced by the shift. It is not
            returned, so callers only observe the update through the tensor
            they passed in.
          * ``common_attn_metadata.seq_lens`` is reduced by the shift.
          * ``common_attn_metadata.slot_mapping`` is overwritten by the helper.
          * The cache tensors and ``cached_len`` inside
            ``multi_layer_eagle_metadata`` are updated by the helper.

        Args:
            batch_size: Number of requests; forwarded to the helper as the
                first grid dimension.
            target_token_ids: Packed token ids; cloned, not modified.
            target_positions: Packed positions; if 2-D, row 0 is used to
                compute the shift. Cloned, not modified.
            target_hidden_states: Packed hidden states; cloned, not modified.
            token_indices_to_sample: Per-request sample index, or ``None`` to
                use the last token of each request.
            common_attn_metadata: Attention metadata for the batch.
            multi_layer_eagle_metadata: Cache state; must not be ``None``.

        Returns:
            ``(token_ids, positions, hidden_states, common_attn_metadata)``
            where the first three are the shifted copies and the last is the
            same metadata object that was passed in.
        """
        assert multi_layer_eagle_metadata is not None
        if token_indices_to_sample is None:
            token_indices_to_sample = (
                common_attn_metadata.query_start_loc[1:] - 1
            )

        max_shift = self.layer_num
        assert max_shift > 0, (
            "layer_num must be positive; the draft model config is expected "
            "to provide hf_text_config.n_predict"
        )

        # The helper writes into these clones; the target_* inputs are only
        # read.
        prev_token_ids = target_token_ids.clone()
        prev_positions = target_positions.clone()
        prev_hidden_states = target_hidden_states.clone()
        slot_mapping = common_attn_metadata.slot_mapping

        start_token_indices = common_attn_metadata.query_start_loc[:-1]
        end_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        # With 2-D positions only the first row is used to look up the
        # position of each request's first token.
        pos_for_shift = (
            target_positions[0]
            if target_positions.dim() == 2
            else target_positions
        )
        start_token_pos = pos_for_shift[start_token_indices]

        # Distance from the sample index to the end of the window, bounded by
        # the position of the request's first token and floored at zero.
        shift = torch.minimum(
            end_token_indices - token_indices_to_sample,
            start_token_pos,
        )
        shift = torch.clamp(shift, min=0)

        # In-place on purpose: the updated indices are not part of the return
        # value.
        token_indices_to_sample.add_(shift)
        common_attn_metadata.seq_lens.sub_(shift)

        # The kernels read ``shift`` entries from the cache, so the shift
        # passed to them cannot exceed what is cached.
        cached_lens = multi_layer_eagle_metadata.cached_len
        shift = torch.minimum(shift, cached_lens)

        _multi_layer_eagle_shift_and_cache(
            batch_size=batch_size,
            max_shift=max_shift,
            src_token_ids=target_token_ids,
            dst_token_ids=prev_token_ids,
            src_positions=target_positions,
            dst_positions=prev_positions,
            src_hidden_states=target_hidden_states,
            dst_hidden_states=prev_hidden_states,
            # Same tensor as source and destination; the helper clones the
            # source side when it detects the aliasing.
            src_slot_mapping=slot_mapping,
            dst_slot_mapping=slot_mapping,
            start_token_indices=start_token_indices,
            end_token_indices=end_token_indices,
            token_indices_to_sample=token_indices_to_sample,
            shift=shift,
            cached_lens=cached_lens,
            cached_prev_token_ids=(
                multi_layer_eagle_metadata.cached_token_ids
            ),
            cached_prev_positions=(
                multi_layer_eagle_metadata.cached_positions
            ),
            cached_prev_hidden_states=(
                multi_layer_eagle_metadata.cached_hidden_states
            ),
            cached_slot_mappings=(
                multi_layer_eagle_metadata.cached_slot_mappings
            ),
            common_attn_metadata=common_attn_metadata,
        )

        return (
            prev_token_ids,
            prev_positions,
            prev_hidden_states,
            common_attn_metadata,
        )

    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """Reject the non-padded drafter-batch path.

        Raises:
            NotImplementedError: Always; this proposer does not support
                ``speculative_config.disable_padded_drafter_batch``.
        """
        # NotImplementedError is still an Exception subclass, so callers that
        # catch Exception keep working.
        raise NotImplementedError(
            "speculative_config.disable_padded_drafter_batch"
            " is not supported now for MultiLayerEagleProposer."
        )

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Run ``adjust_input`` and ``layer_num`` draft passes on dummy data.

        A single-request batch spanning the padded token count is built from
        the proposer's persistent buffers, pushed through ``adjust_input``,
        and then the draft model is called once per ``spec_step_idx`` in
        ``range(self.layer_num)``.

        Args:
            num_tokens: Requested token count before padding.
            use_cudagraphs: Forwarded to
                ``_determine_batch_execution_and_padding``.
            is_graph_capturing: Not read by this method.
            slot_mappings: Optional per-layer slot mappings supplied by the
                caller.
        """
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(
                num_tokens, use_cudagraphs=use_cudagraphs
            )
        )

        # If the caller's mappings mention one of the draft attention layers,
        # use the proposer's own slot mapping for the padded token count;
        # otherwise pass the caller's mappings through (or an empty dict).
        if (
            self._draft_attn_layer_names
            and slot_mappings is not None
            and next(iter(self._draft_attn_layer_names)) in slot_mappings
        ):
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}

        # One dummy request covering all ``num_input_tokens`` tokens.
        adjust_input_kwargs = {
            "batch_size": 1,
            "target_token_ids": self.input_ids[:num_input_tokens],
            "target_positions": self._get_positions(num_input_tokens),
            "target_hidden_states": self.hidden_states[:num_input_tokens],
            "token_indices_to_sample": torch.tensor(
                [num_input_tokens - 1],
                dtype=torch.int32,
                device=self.device,
            ),
            "common_attn_metadata": CommonAttentionMetadata(
                query_start_loc=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                query_start_loc_cpu=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device="cpu",
                ),
                key_start_loc=torch.tensor(
                    [0, num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                seq_lens=torch.tensor(
                    [num_input_tokens],
                    dtype=torch.int32,
                    device=self.device,
                ),
                seq_lens_np=np.array([num_input_tokens], dtype=np.int32),
                num_reqs=1,
                num_actual_tokens=num_input_tokens,
                max_query_len=self.num_speculative_tokens + 1,
                max_seq_len=self.max_model_len,
                block_table_tensor=torch.tensor(
                    [], dtype=torch.int32, device=self.device
                ),
                slot_mapping=self.arange[:num_input_tokens],
                logits_indices_padded=None,
                num_logits_indices=None,
                causal=True,
                encoder_seq_lens=None,
            ),
            "multi_layer_eagle_metadata": MultiLayerEagleMetadata.make_dummy(
                layer_num=self.layer_num,
                hidden_size=self.hidden_size,
                device=self.device,
            ),
        }
        # The return value is discarded; the call is made for its execution
        # only.
        self.adjust_input(**adjust_input_kwargs)

        for fwd_idx in range(self.layer_num):
            with set_forward_context(
                None,
                self.draft_vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=slot_mapping_dict,
            ):
                # Multimodal-capable drafts take embeddings instead of ids.
                if self.supports_mm_inputs:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]
                    inputs_embeds = None

                model_kwargs = {
                    "input_ids": input_ids,
                    "positions": self._get_positions(num_input_tokens),
                    "hidden_states": self.hidden_states[:num_input_tokens],
                    "inputs_embeds": inputs_embeds,
                    "spec_step_idx": fwd_idx,
                }

                self.model(**model_kwargs)
|
||||
|
||||
|
||||
def _multi_layer_eagle_shift_and_cache(
    *,
    batch_size: int,
    max_shift: int,
    src_token_ids: torch.Tensor,
    dst_token_ids: torch.Tensor,
    src_positions: torch.Tensor,
    dst_positions: torch.Tensor,
    src_hidden_states: torch.Tensor,
    dst_hidden_states: torch.Tensor,
    src_slot_mapping: torch.Tensor,
    dst_slot_mapping: torch.Tensor,
    start_token_indices: torch.Tensor,
    end_token_indices: torch.Tensor,
    token_indices_to_sample: torch.Tensor,
    shift: torch.Tensor,
    cached_lens: torch.Tensor,
    cached_prev_token_ids: torch.Tensor,
    cached_prev_positions: torch.Tensor,
    cached_prev_hidden_states: torch.Tensor,
    cached_slot_mappings: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
):
    """Launch the shift/gather kernels for every per-token array.

    Token ids, slot mappings and positions go through the 1D kernel (in that
    order), hidden states through the 2D kernel. Each launch writes the
    shifted data into its ``dst_*`` tensor and refreshes the matching
    ``cached_*`` tensor. Afterwards ``cached_lens`` is overwritten in place
    with the number of entries stored per request.

    Does nothing when ``batch_size`` is 0.
    """
    if batch_size == 0:
        return

    assert max_shift > 0
    # The kernels index these buffers with hand-computed flat offsets.
    for tensor in (
        cached_prev_positions,
        cached_prev_token_ids,
        cached_prev_hidden_states,
        cached_slot_mappings,
        src_hidden_states,
        dst_hidden_states,
    ):
        assert tensor.is_contiguous()

    # Source and destination may alias for the slot mapping; read from a
    # snapshot so the kernel does not consume values it has just written.
    if src_slot_mapping.data_ptr() == dst_slot_mapping.data_ptr():
        src_slot_mapping = src_slot_mapping.clone()

    # Per request: the last ``max_shift`` tokens ending at the sample index,
    # clipped to the start of the request's window.
    store_start = torch.maximum(
        start_token_indices,
        (token_indices_to_sample + 1 - max_shift),
    )
    store_lens = torch.clamp(
        token_indices_to_sample - store_start + 1,
        min=0,
        max=max_shift,
    )

    # Longest request window, taken from the CPU copy of query_start_loc,
    # decides how many token blocks each request gets in the grid.
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    window_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    max_window_len = int(window_lens_cpu.max().item())
    num_blocks = max(1, (max_window_len + BLOCK_TOKENS - 1) // BLOCK_TOKENS)

    padded_shift = triton.next_power_of_2(max_shift)
    per_request_args = (
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
    )

    grid_1d = (batch_size, num_blocks)
    for src, dst, cache in (
        (src_token_ids, dst_token_ids, cached_prev_token_ids),
        (src_slot_mapping, dst_slot_mapping, cached_slot_mappings),
        (src_positions, dst_positions, cached_prev_positions),
    ):
        _shift_and_gather_cache_1d_kernel[grid_1d](
            src,
            dst,
            cache,
            *per_request_args,
            MAX_SHIFT=max_shift,
            PADDED_SHIFT=padded_shift,
            BLOCK_TOKENS=BLOCK_TOKENS,
        )

    hidden_size = int(dst_hidden_states.shape[1])
    num_hidden_blocks = max(
        1, (hidden_size + BLOCK_HIDDEN - 1) // BLOCK_HIDDEN
    )
    grid_2d = (batch_size, num_blocks, num_hidden_blocks)

    _shift_and_gather_hidden_kernel[grid_2d](
        src_hidden_states,
        dst_hidden_states,
        cached_prev_hidden_states,
        *per_request_args,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=padded_shift,
        HIDDEN_SIZE=hidden_size,
        BLOCK_TOKENS=BLOCK_TOKENS,
        BLOCK_HIDDEN=BLOCK_HIDDEN,
        num_warps=4,
    )

    # Only updated after all launches: every kernel above takes the previous
    # lengths as an argument.
    cached_lens.copy_(store_lens)
|
||||
|
||||
|
||||
@triton.jit
def _shift_and_gather_cache_1d_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
):
    # Shift + gather for one packed 1D array (token ids, positions, slot
    # mappings, ...), followed by a refresh of that array's per-sequence cache.
    #
    # Grid: axis 0 selects the sequence, axis 1 selects a block of
    # BLOCK_TOKENS offsets inside that sequence's [start, end] window.
    #
    # For window offset i of a sequence:
    #   i <  shift : dst[start + i] = cached[cached_len - shift + i]
    #   i >= shift : dst[start + i] = src[start + i - shift]
    #
    # The cache row of a sequence starts at cached_ptr + seq * MAX_SHIFT.
    seq_id = tl.program_id(0)
    blk_id = tl.program_id(1)

    seq_start = tl.load(start_ptr + seq_id).to(tl.int32)
    seq_end = tl.load(end_ptr + seq_id).to(tl.int32)
    seq_shift = tl.load(shift_ptr + seq_id).to(tl.int32)
    cache_len = tl.load(cached_len_ptr + seq_id).to(tl.int32)

    # The prefix below reads the last `seq_shift` cached entries.
    assert cache_len >= seq_shift

    # Offsets of this block inside the window; the window is inclusive of
    # seq_end, hence the + 1.
    offs = blk_id * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
    in_window = offs < (seq_end - seq_start + 1)
    from_cache = offs < seq_shift

    cache_row = cached_ptr + seq_id * MAX_SHIFT
    cache_vals = tl.load(
        cache_row + (cache_len - seq_shift + offs),
        mask=in_window & from_cache,
        other=0,
    )
    src_vals = tl.load(
        src_ptr + (seq_start + offs - seq_shift),
        mask=in_window & ~from_cache,
        other=0,
    )

    shifted = tl.where(from_cache, cache_vals, src_vals)
    tl.store(dst_ptr + (seq_start + offs), shifted, mask=in_window)

    # Cache refresh: copy `tail_len` freshly written dst entries starting at
    # `tail_start` into the front of the cache row. Slots at or beyond
    # `tail_len` (but below MAX_SHIFT) are overwritten with 0.
    #
    # NOTE(review): this section is executed by every program, including all
    # token blocks of the same sequence, and it reads dst entries that may
    # have been written by other lanes or other programs while overwriting
    # cache slots that the prefix load above reads. Confirm that the launch
    # configuration makes this ordering safe.
    tail_start = tl.load(store_start_ptr + seq_id).to(tl.int32)
    tail_len = tl.load(store_len_ptr + seq_id).to(tl.int32)
    slot = tl.arange(0, PADDED_SHIFT)
    slot_valid = slot < MAX_SHIFT
    tail_vals = tl.load(
        dst_ptr + (tail_start + slot),
        mask=slot_valid & (slot < tail_len),
        other=0,
    )
    tl.store(cache_row + slot, tail_vals, mask=slot_valid)
|
||||
|
||||
|
||||
@triton.jit
def _shift_and_gather_hidden_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_HIDDEN: tl.constexpr,
):
    # Shift + gather for hidden states, followed by a refresh of the
    # per-sequence hidden-state cache.
    #
    # Layout (row-major, row stride HIDDEN_SIZE):
    #   - src_ptr / dst_ptr: [num_tokens, hidden_size]
    #   - cached_ptr:        [batch_size, MAX_SHIFT, hidden_size]
    #
    # Grid: axis 0 = sequence, axis 1 = block of BLOCK_TOKENS window offsets,
    # axis 2 = block of BLOCK_HIDDEN hidden columns.
    seq_id = tl.program_id(0)
    blk_id = tl.program_id(1)
    hid_id = tl.program_id(2)

    seq_start = tl.load(start_ptr + seq_id).to(tl.int32)
    seq_end = tl.load(end_ptr + seq_id).to(tl.int32)
    seq_shift = tl.load(shift_ptr + seq_id).to(tl.int32)
    cache_len = tl.load(cached_len_ptr + seq_id).to(tl.int32)

    # The prefix below reads the last `seq_shift` cached rows.
    assert cache_len >= seq_shift

    # Row (token) offsets inside the window and column offsets of this tile.
    tok_offs = blk_id * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
    cols = hid_id * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)

    # The window is inclusive of seq_end, hence the + 1.
    row_ok = tok_offs < (seq_end - seq_start + 1)
    col_ok = cols < HIDDEN_SIZE
    tile_ok = row_ok[:, None] & col_ok[None, :]

    # Rows with offset < shift come from the cache, the rest from src.
    from_cache = tok_offs < seq_shift

    cache_base = cached_ptr + seq_id * HIDDEN_SIZE * MAX_SHIFT
    cache_rows = cache_len - seq_shift + tok_offs
    cache_vals = tl.load(
        cache_base + cache_rows[:, None] * HIDDEN_SIZE + cols[None, :],
        mask=tile_ok & from_cache[:, None],
        other=0,
    )

    src_rows = seq_start + tok_offs - seq_shift
    src_vals = tl.load(
        src_ptr + src_rows[:, None] * HIDDEN_SIZE + cols[None, :],
        mask=tile_ok & ~from_cache[:, None],
        other=0,
    )

    out_rows = seq_start + tok_offs
    shifted = tl.where(from_cache[:, None], cache_vals, src_vals)
    tl.store(
        dst_ptr + out_rows[:, None] * HIDDEN_SIZE + cols[None, :],
        shifted,
        mask=tile_ok,
    )

    # Cache refresh: copy `tail_len` freshly written dst rows starting at
    # `tail_start` into the first rows of this sequence's cache, for this
    # tile's columns only. Rows at or beyond `tail_len` are left untouched.
    #
    # NOTE(review): this section is executed by every program, including all
    # token blocks of the same sequence, and it reads dst rows that may have
    # been written by other lanes or other programs while overwriting cache
    # rows that the prefix load above reads. Confirm that the launch
    # configuration makes this ordering safe.
    tail_start = tl.load(store_start_ptr + seq_id).to(tl.int32)
    tail_len = tl.load(store_len_ptr + seq_id).to(tl.int32)
    slot = tl.arange(0, PADDED_SHIFT)
    slot_ok = (slot < MAX_SHIFT) & (slot < tail_len)
    tail_rows = tail_start + slot
    tail_tile_ok = slot_ok[:, None] & col_ok[None, :]
    tail_vals = tl.load(
        dst_ptr + tail_rows[:, None] * HIDDEN_SIZE + cols[None, :],
        mask=tail_tile_ok,
        other=0,
    )
    tl.store(
        cache_base + slot[:, None] * HIDDEN_SIZE + cols[None, :],
        tail_vals,
        mask=tail_tile_ok,
    )
|
||||
Reference in New Issue
Block a user