Files

369 lines
15 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import VllmConfig, replace
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
)
# Sentinel slot id written into a slot mapping for tokens whose KV must not be
# saved (compute_new_slot_mapping uses it for positions at or past
# max_model_len and for rejected tokens).
PADDING_SLOT_ID = -1
@triton.jit
def eagle_prepare_inputs_padded_kernel(
    cu_num_draft_tokens_ptr,  # [num_reqs]
    valid_sampled_tokens_count_ptr,  # [num_reqs]
    query_start_loc_gpu_ptr,  # [num_reqs + 1]
    token_indices_to_sample_ptr,  # [num_reqs] (output)
    num_rejected_tokens_gpu_ptr,  # [num_reqs] (output)
    num_reqs,  # tl.int32
):
    """
    Fused kernel for Eagle prepare_input_padded, one program per request.

    For each request, this kernel:
      * derives its draft-token count from the inclusive prefix sum
        cu_num_draft_tokens;
      * computes the number of rejected tokens as
        num_draft + 1 - valid_sampled_count, where the valid count is one more
        than the number of accepted tokens. Requests without drafts reject
        nothing.
      * writes the index of the token to sample, which is the request's last
        query token moved back by the number of rejected tokens.
    Programs with an id >= num_reqs exit without writing.
    """
    pid = tl.program_id(axis=0)
    if pid >= num_reqs:
        return

    # Inclusive prefix sum: the previous entry is the running total before this
    # request. For request 0 the masked load yields 0 instead of reading ptr[-1].
    cu_end = tl.load(cu_num_draft_tokens_ptr + pid)
    cu_begin = tl.load(cu_num_draft_tokens_ptr + pid - 1, mask=pid > 0, other=0)
    n_draft = cu_end - cu_begin

    n_valid = tl.load(valid_sampled_tokens_count_ptr + pid)
    # Requests that drafted nothing have nothing to reject.
    n_rejected = tl.where(n_draft > 0, n_draft + 1 - n_valid, 0)

    # query_start_loc[pid + 1] is one past this request's last token.
    next_start = tl.load(query_start_loc_gpu_ptr + pid + 1)
    sample_idx = next_start - 1 - n_rejected

    tl.store(token_indices_to_sample_ptr + pid, sample_idx)
    tl.store(num_rejected_tokens_gpu_ptr + pid, n_rejected)
@triton.jit
def eagle_prepare_next_token_padded_kernel(
    sampled_token_ids_ptr,  # [num_reqs, num_sampled_tokens_per_req]
    discard_request_mask_ptr,  # [num_reqs]
    backup_next_token_ids_ptr,  # [num_reqs]
    next_token_ids_ptr,  # [num_reqs] (output)
    valid_sampled_tokens_count_ptr,  # [num_reqs] (output)
    vocab_size,  # tl.int32
    num_sampled_tokens_per_req,  # tl.int32 (num_spec_tokens + 1)
    num_reqs,  # tl.int32
    stride_sampled_token_ids,  # tl.int32 (stride for dim 0)
    BLOCK_SIZE_TOKENS: tl.constexpr,  # Power-of-2 >= num_sampled_tokens_per_req
):
    """
    Fused kernel for Eagle prepare_next_token_ids_padded, one program per request.

    This kernel computes the number of valid (1 + accepted) tokens for each
    request, and the corresponding "next" token id to sample from during
    speculative decoding. A sampled id is valid when it is not -1 and is below
    vocab_size. The next token is the valid token with the highest index in the
    request's row (the "last accepted token").

    The backup token is used, and 0 is written as the valid count, when the
    request is marked as discarded. The backup token is also used when none of
    the sampled ids in the row is valid; the valid count written is then 0.
    Programs with an id >= num_reqs exit without writing.
    """
    req_idx = tl.program_id(axis=0)
    if req_idx >= num_reqs:
        return

    # Check if this request is discarded.
    is_discarded = tl.load(discard_request_mask_ptr + req_idx)

    if is_discarded:
        # Discarded requests ignore their sampled row entirely.
        backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
        valid_count = tl.full((), 0, dtype=tl.uint32)
        tl.store(next_token_ids_ptr + req_idx, backup_token)
        tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
    else:
        # Count the number of valid tokens among the sampled tokens.
        token_offs = tl.arange(0, BLOCK_SIZE_TOKENS)
        token_mask = token_offs < num_sampled_tokens_per_req

        row_ptr = sampled_token_ids_ptr + req_idx * stride_sampled_token_ids
        # Lanes beyond the row length load -1 so they are never counted as valid.
        token_ids = tl.load(row_ptr + token_offs, mask=token_mask, other=-1)

        # Rejected tokens are -1, valid tokens are in [0, vocab_size)
        is_valid_mask = (token_ids != -1) & (token_ids < vocab_size) & token_mask
        valid_count = tl.sum(is_valid_mask)

        if valid_count > 0:
            # Guaranteed to be well-defined since
            # valid_count > 0 implies is_valid_mask is not empty
            last_valid_index = tl.max(tl.where(is_valid_mask, token_offs, -1))

            # Select the token at that index, using a sum trick since
            # we don't want to load again to access token_ids[last_valid_index].
            last_valid_token = tl.sum(
                tl.where(token_offs == last_valid_index, token_ids, 0)
            )
            tl.store(next_token_ids_ptr + req_idx, last_valid_token)
        else:
            # No valid tokens found, use backup token
            backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
            tl.store(next_token_ids_ptr + req_idx, backup_token)

        tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
def compute_new_slot_mapping(
    cad: CommonAttentionMetadata,
    new_positions: torch.Tensor,
    is_rejected_token_mask: torch.Tensor,
    block_size: int,
    num_new_tokens: int,
    max_model_len: int,
):
    """Translate token positions into KV-cache slot ids.

    Each request owns ``cad.naive_query_lens() + num_new_tokens`` consecutive
    entries of ``new_positions``. Presumably naive_query_lens() is the
    per-request query length; confirm against CommonAttentionMetadata. The
    position is resolved through that request's row of
    ``cad.block_table_tensor``:

        slot = block_table[req, pos // block_size] * block_size + pos % block_size

    Positions are clamped to ``max_model_len - 1`` before the lookup, so the
    block-table index stays in range. Entries whose position is
    ``>= max_model_len``, or that are set in ``is_rejected_token_mask``, are
    overwritten with PADDING_SLOT_ID so they are not saved to the KV cache.

    Returns:
        A tensor of slot ids with the same length as ``new_positions``.
    """
    num_reqs, max_blocks_per_req = cad.block_table_tensor.shape

    # One request id per entry of new_positions.
    owner_req = torch.repeat_interleave(
        torch.arange(num_reqs, device=cad.query_start_loc.device),
        cad.naive_query_lens() + num_new_tokens,
        output_size=len(new_positions),
    )

    # Clamped copy used only for indexing, so the gather cannot go out of bounds.
    safe_positions = new_positions.clamp(max=max_model_len - 1)
    flat_block_idx = owner_req * max_blocks_per_req + safe_positions // block_size
    physical_blocks = cad.block_table_tensor.view(-1)[flat_block_idx]
    slot_mapping = physical_blocks * block_size + safe_positions % block_size

    # Positions past the model length and rejected tokens must not write KV.
    return slot_mapping.masked_fill_(
        new_positions >= max_model_len, PADDING_SLOT_ID
    ).masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID)
def create_vllm_config_for_draft_model(
    target_model_vllm_config: VllmConfig,
) -> VllmConfig:
    """Build the VllmConfig used to load the draft model.

    The vllm_config is configured for the target model, e.g. its quant_config
    and parallel_config. But the draft model is potentially quantized
    differently, and has potentially different tensor_parallel_size. This
    function creates a new vllm_config configured for the drafter. The result
    is useful when loading the draft model with get_model().

    In every case the draft model_config and draft parallel_config are applied
    and quant_config is cleared. The rank placed in the draft parallel config is
    the target's rank by default. When the speculative config reports that a
    partial-PP remap is needed (the drafter runs locally on the last target PP
    stage), the remapped rank is used instead.

    Args:
        target_model_vllm_config: Config of the target model. Its
            ``speculative_config`` must be set.

    Returns:
        A new VllmConfig describing the draft model.

    Raises:
        AssertionError: If ``speculative_config`` is not set.
    """
    old = target_model_vllm_config
    assert old.speculative_config is not None, "speculative_config is not set"
    old_spec_config = old.speculative_config

    # Returning the target config unchanged would make get_model() load the
    # target's model_config/quant_config for the drafter. A new config is
    # therefore always built; only the rank differs between the two paths.
    if old_spec_config.needs_partial_pp_draft_remap(old.parallel_config):
        draft_rank = old_spec_config.resolve_partial_pp_draft_rank(
            old.parallel_config
        )
    else:
        draft_rank = old.parallel_config.rank

    new_parallel_config = replace(
        old_spec_config.draft_parallel_config, rank=draft_rank
    )
    new: VllmConfig = replace(
        old,
        # Not inherited from the target; presumably recomputed for the draft
        # model when the config is rebuilt (verify in VllmConfig).
        quant_config=None,
        parallel_config=new_parallel_config,
        model_config=old_spec_config.draft_model_config,
    )
    return new
def extend_all_queries_by_N(
    common_attn_metadata: CommonAttentionMetadata,
    N: int,
    arange: torch.Tensor,
    new_slot_mapping: torch.Tensor,
) -> CommonAttentionMetadata:
    """
    Return a copy of the metadata in which every request has N more query
    tokens and N more sequence tokens.

    Typical use: parallel drafting in speculative decoding, where each sequence
    is extended by N tokens and all of them are predicted in one pass.

    Args:
        common_attn_metadata: The metadata to extend. It is not modified; a new
            object is built via its ``replace`` method.
        N: Number of tokens appended to every request.
        arange: Device tensor holding 0, 1, 2, ...; at least as long as
            ``query_start_loc``.
        new_slot_mapping: Slot mapping for the extended batch. It is computed by
            the caller because it needs more information than this metadata has.

    Returns:
        The extended CommonAttentionMetadata.
    """
    meta = common_attn_metadata

    # Boundary i of query_start_loc moves right by i * N:
    # [+0, +N, +2N, ..., +batch_size * N].
    gpu_shift = N * arange[: len(meta.query_start_loc)]
    cpu_shift = N * torch.arange(len(meta.query_start_loc_cpu), dtype=torch.int32)

    # Every request gains N tokens, so the batch gains batch_size * N tokens.
    added_tokens = meta.batch_size() * N

    return meta.replace(
        query_start_loc=meta.query_start_loc + gpu_shift,
        query_start_loc_cpu=meta.query_start_loc_cpu + cpu_shift,
        seq_lens=meta.seq_lens + N,
        num_actual_tokens=meta.num_actual_tokens + added_tokens,
        max_query_len=meta.max_query_len + N,
        max_seq_len=meta.max_seq_len + N,
        slot_mapping=new_slot_mapping,
    )
# Unified copy/expand kernel
@triton.jit
def copy_and_expand_eagle_inputs_kernel(
    # (Padded) Inputs from the target model
    target_token_ids_ptr,  # [total_tokens_in_batch]
    target_positions_ptr,  # [total_tokens_in_batch]
    next_token_ids_ptr,  # [num_reqs]
    # Outputs to the drafting buffers
    out_input_ids_ptr,  # [total_draft_tokens_in_batch] (output)
    out_positions_ptr,  # [total_draft_tokens_in_batch] (output)
    out_is_rejected_token_mask_ptr,  # [total_draft_tokens_in_batch] (output)
    out_is_masked_token_mask_ptr,  # [total_draft_tokens_in_batch] (output)
    out_new_token_indices_ptr,  # [num_padding_slots_per_request * num_reqs] (output)
    out_hidden_state_mapping_ptr,  # [total_tokens_in_batch] (output, only written when shift_input_ids)
    # Input metadata
    query_start_loc_ptr,  # [num_reqs + 1], last value is the total num input tokens
    query_end_loc_ptr,  # [num_reqs], index of the request's last non-rejected token (inclusive)
    padding_token_id,  # tl.int32
    parallel_drafting_token_id,  # tl.int32
    # Sizing info
    total_input_tokens,  # tl.int32
    num_padding_slots_per_request,  # tl.int32 (bonus slot + parallel drafting slots)
    shift_input_ids,  # tl.bool
    BLOCK_SIZE_TOKENS: tl.constexpr,  # Blocks along token dim to handle prefills
):
    """
    Copy and expand inputs from the target model to the drafting buffers for Eagle
    speculative decoding. This kernel handles padding slots and parallel drafting
    tokens, if enabled.

    Launch grid: axis 0 indexes the request and axis 1 indexes a chunk of
    BLOCK_SIZE_TOKENS output slots within that request. The request index is not
    bounds-checked, so axis 0 presumably has exactly num_reqs programs. Axis 1
    must cover every request's output slots (TODO confirm against the launcher).

    Per request, with P = num_padding_slots_per_request and nv = num_valid_tokens,
    the output slots j are:
      [0, nv)         input tokens, taken from the request's second token onward
                      when shift_input_ids is set
      j == nv         bonus token from next_token_ids
      (nv, nv + P)    parallel_drafting_token_id, flagged in the masked-token mask
      [nv + P, end)   padding_token_id with position 0, flagged in the
                      rejected-token mask (one slot per rejected input token)
    Each request occupies input_len + P output slots, or input_len + P - 1 when
    shift_input_ids drops its first token. Non-rejected slots get position
    start_pos + j, where start_pos is the position of the request's first input
    token. out_new_token_indices receives the output index of the P bonus and
    parallel slots of each request. When shift_input_ids is set,
    out_hidden_state_mapping maps every input token index, including rejected
    ones, to the output index at the same offset from the request's start.
    """
    request_idx = tl.program_id(axis=0)
    token_batch_idx = tl.program_id(axis=1)

    # Load query locations
    query_start_loc = tl.load(query_start_loc_ptr + request_idx)
    next_query_start_loc = tl.load(query_start_loc_ptr + request_idx + 1)
    query_end_loc = tl.load(query_end_loc_ptr + request_idx)

    # Calculate number of valid tokens to copy and input offset
    # With shift_input_ids=True, we skip the first token
    # Output layout: each request gets (input_len + num_padding_slots_per_request) slots
    # But with shift, we lose one token per request
    if shift_input_ids:
        num_valid_tokens = query_end_loc - query_start_loc
        input_offset = 1
        output_start = query_start_loc + request_idx * (
            num_padding_slots_per_request - 1
        )
    else:
        num_valid_tokens = query_end_loc - query_start_loc + 1
        input_offset = 0
        output_start = query_start_loc + request_idx * num_padding_slots_per_request

    # Number of rejected tokens from previous speculation
    num_rejected = next_query_start_loc - query_end_loc - 1

    # Total output tokens for this request
    total_output_tokens = (
        num_valid_tokens + num_padding_slots_per_request + num_rejected
    )

    # Process tokens in this block
    j = token_batch_idx * BLOCK_SIZE_TOKENS + tl.arange(0, BLOCK_SIZE_TOKENS)

    # Compute masks for different output regions:
    # [0, num_valid_tokens): valid tokens copied from input
    # [num_valid_tokens]: bonus token from next_token_ids
    # (num_valid_tokens, num_valid_tokens + num_padding_slots_per_request):
    #     parallel drafting slots
    # [num_valid_tokens + num_padding_slots_per_request, total_output_tokens):
    #     rejected slots
    in_bounds = j < total_output_tokens
    is_valid_region = j < num_valid_tokens
    is_bonus_region = j == num_valid_tokens
    is_parallel_draft_region = (j > num_valid_tokens) & (
        j < num_valid_tokens + num_padding_slots_per_request
    )
    is_rejected_region = j >= num_valid_tokens + num_padding_slots_per_request

    # Compute output indices
    out_idx = output_start + j

    # For valid tokens, compute input index
    in_idx = query_start_loc + input_offset + j
    # Clamp to avoid out-of-bounds access (masked loads still need valid addresses)
    in_idx_clamped = tl.minimum(in_idx, total_input_tokens - 1)

    # Load input tokens (masked to valid region)
    token_ids = tl.load(
        target_token_ids_ptr + in_idx_clamped, mask=is_valid_region & in_bounds, other=0
    )

    # Load the starting position for this request (first position in the sequence)
    start_pos = tl.load(target_positions_ptr + query_start_loc)

    # Load bonus token for this request
    bonus_token = tl.load(next_token_ids_ptr + request_idx)

    # Build final token_ids based on region
    token_ids = tl.where(is_bonus_region, bonus_token, token_ids)
    token_ids = tl.where(
        is_parallel_draft_region, parallel_drafting_token_id, token_ids
    )
    token_ids = tl.where(is_rejected_region, padding_token_id, token_ids)

    # Build final positions:
    # Positions are NOT shifted - they start from the first input position and increment
    # Output position j gets start_pos + j
    # (e.g., input positions [5,6,7] -> output [5,6,7,8,9,...])
    positions = start_pos + j
    # Rejected positions are don't-care, set to 0
    positions = tl.where(is_rejected_region, 0, positions)

    # Compute output masks
    is_rejected_out = is_rejected_region & in_bounds
    is_masked_out = is_parallel_draft_region & in_bounds

    # Compute indices of new tokens (bonus + parallel drafting) for sampling
    # New tokens are at positions
    # [num_valid_tokens, num_valid_tokens + num_padding_slots_per_request)
    is_new_token_region = (j >= num_valid_tokens) & (
        j < num_valid_tokens + num_padding_slots_per_request
    )
    new_token_local_idx = (
        j - num_valid_tokens
    )  # 0 for bonus, 1, 2, ... for parallel drafting
    new_token_out_idx = (
        request_idx * num_padding_slots_per_request + new_token_local_idx
    )

    # Compute hidden state mapping (source index -> destination index)
    # This maps each input position to its corresponding output position
    # Hidden states don't get shifted, so we map all input tokens (including rejected)
    # The mapping is only produced in shift mode; otherwise the buffer is untouched.
    if shift_input_ids:
        num_input_tokens_this_request = next_query_start_loc - query_start_loc
        is_input_region = j < num_input_tokens_this_request
        src_idx = query_start_loc + j
        tl.store(out_hidden_state_mapping_ptr + src_idx, out_idx, mask=is_input_region)

    # Store outputs
    tl.store(out_input_ids_ptr + out_idx, token_ids, mask=in_bounds)
    tl.store(out_positions_ptr + out_idx, positions, mask=in_bounds)
    tl.store(out_is_rejected_token_mask_ptr + out_idx, is_rejected_out, mask=in_bounds)
    tl.store(out_is_masked_token_mask_ptr + out_idx, is_masked_out, mask=in_bounds)
    tl.store(
        out_new_token_indices_ptr + new_token_out_idx,
        out_idx,
        mask=is_new_token_region & in_bounds,
    )