Files
sglang/python/sglang/srt/layers/attention/flashinfer_backend.py
2025-10-14 23:52:04 +08:00

1580 lines
62 KiB
Python

from __future__ import annotations
"""
Support different attention backends.
Now there are two backends: FlashInfer and Triton.
FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""
import logging
import os
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import torch
from sglang.srt.environ import envs
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
get_int_env_var,
is_flashinfer_available,
is_sm100_supported,
next_power_of_2,
)
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
# Module-level logger for this backend.
logger = logging.getLogger(__name__)

# When torch.compile support is switched on through the environment, limit
# dynamo logging to errors and turn on dynamo's ``suppress_errors`` flag.
if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
    torch._logging.set_logs(dynamo=logging.ERROR)
    torch._dynamo.config.suppress_errors = True
if is_flashinfer_available():
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
fast_decode_plan,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import _get_range_buf, get_seq_lens
class WrapperDispatch(Enum):
    """Reason why the backend keeps two attention wrappers instead of one."""

    # Wrapper index is chosen from the layer's sliding-window size.
    SLIDING_WINDOW = auto()
    # Wrapper index is chosen from whether the layer is cross-attention.
    CROSS_ATTENTION = auto()
@dataclass
class MultiItemScoringParams:
    """Tensors describing a multi-item scoring request for the attention kernel.

    A multi-item sequence is a prompt followed by several items separated by a
    delimiter token; each item must attend in a way that respects the item
    boundaries.

    Attributes:
        prefix_len_ptr: uint32 1D tensor with the prefix length of each prompt;
            one entry per sequence in the batch.
        token_pos_in_items_ptr: uint16 1D tensor with the position of every
            token inside its item, restarting at 0 on each delimiter. For
            batch size > 1 the per-sequence tensors are zero-padded to a common
            length and concatenated.
        token_pos_in_items_len: Padded per-sequence length used for
            ``token_pos_in_items_ptr`` when batch size > 1.
        max_item_len_ptr: uint16 tensor holding, for each prompt in the batch,
            the maximum token length over its items.
    """

    prefix_len_ptr: Optional[torch.Tensor] = None
    token_pos_in_items_ptr: Optional[torch.Tensor] = None
    token_pos_in_items_len: int = 0
    max_item_len_ptr: Optional[torch.Tensor] = None

    def is_enabled(self) -> bool:
        """Return True when multi-item tensors are present (prefix lengths set)."""
        return self.prefix_len_ptr is not None
@dataclass
class DecodeMetadata:
    """Per-batch metadata consumed by ``forward_decode``."""

    # One decode wrapper per dispatch slot; indexed by the wrapper index of
    # the layer being executed.
    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
@dataclass
class PrefillMetadata:
    """Per-batch metadata consumed by ``forward_extend``."""

    # One paged prefill wrapper per dispatch slot.
    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
    # True when the ragged wrapper handles the newly extended tokens.
    use_ragged: bool
    # True when no request in the batch has a cached prefix.
    extend_no_prefix: bool
    # Multi-item scoring tensors, if any were computed for this batch.
    multi_item_params: Optional[MultiItemScoringParams] = None
# Single workspace buffer shared by every flashinfer wrapper in the process.
# It is allocated lazily by the first backend instance that needs it.
global_workspace_buffer = None

# Fast path: a CPU-side indptr that overrides the one flashinfer's plan
# function would otherwise compute, which removes some host-to-device copy
# overhead. It is set only for the duration of a plan call.
global_override_indptr_cpu = None
class FlashInferAttnBackend(AttentionBackend):
"""Flashinfer attention kernels."""
def __init__(
    self,
    model_runner: ModelRunner,
    skip_prefill: bool = False,
    kv_indptr_buf: Optional[torch.Tensor] = None,
    kv_last_page_len_buf: Optional[torch.Tensor] = None,
):
    """Create the FlashInfer wrappers and the buffers they share.

    Args:
        model_runner: Provides the model config, server args, device and the
            request-to-token pool whose size bounds the index buffers.
        skip_prefill: When True, no prefill/verify wrappers, no ``qo_indptr``
            buffers and no prefill indices updater are created.
        kv_indptr_buf: Optional caller-owned KV indptr buffer; only accepted
            when a single wrapper is in use.
        kv_last_page_len_buf: Optional caller-owned last-page-length buffer;
            only accepted when a single wrapper is in use.
    """
    super().__init__()

    server_args = model_runner.server_args
    model_config = model_runner.model_config
    device = model_runner.device
    attn_tp_size = get_attention_tp_size()

    # Keep the multi-item scoring delimiter on the backend for cheap access.
    self.multi_item_scoring_delimiter = server_args.multi_item_scoring_delimiter

    # Static model properties.
    self.decode_use_tensor_cores = should_use_tensor_core(
        kv_cache_dtype=model_runner.kv_cache_dtype,
        num_attention_heads=model_config.num_attention_heads // attn_tp_size,
        num_kv_heads=model_config.get_num_kv_heads(attn_tp_size),
    )
    self.max_context_len = model_config.context_len
    self.skip_prefill = skip_prefill
    self.is_multimodal = model_config.is_multimodal

    # Decide how many wrappers are needed and why.
    has_sliding_window = model_runner.sliding_window_size is not None
    assert not (
        has_sliding_window and model_config.is_encoder_decoder
    ), "Sliding window and cross attention are not supported together"
    if has_sliding_window:
        self.num_wrappers = 2
        self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
    elif model_config.is_encoder_decoder:
        self.num_wrappers = 2
        self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
    else:
        self.num_wrappers = 1
        self.dispatch_reason = None

    # Qwen2/Qwen3 (and MiMo) models require a larger flashinfer workspace.
    architectures = model_config.hf_config.architectures
    if any(
        arch in architectures
        for arch in ("Qwen2ForCausalLM", "Qwen3ForCausalLM", "MiMoForCausalLM")
    ):
        envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(512 * 1024 * 1024)

    # Deterministic inference: decode must use tensor cores, split tile sizes
    # come from environment variables, and kv split is disabled for cuda graph.
    # More information: https://github.com/flashinfer-ai/flashinfer/pull/1675
    self.enable_deterministic = server_args.enable_deterministic_inference
    if self.enable_deterministic:
        self.decode_use_tensor_cores = True
        self.prefill_split_tile_size = get_int_env_var(
            "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
        )
        self.decode_split_tile_size = get_int_env_var(
            "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
        )
        self.disable_cuda_graph_kv_split = True
        envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(2048 * 1024 * 1024)
    else:
        self.prefill_split_tile_size = None
        self.decode_split_tile_size = None
        self.disable_cuda_graph_kv_split = False

    # Allocate the process-wide workspace once; the size is read only after
    # the overrides above have been applied.
    global global_workspace_buffer
    if global_workspace_buffer is None:
        # different from flashinfer zero_init_global_workspace_buffer
        global_workspace_buffer = torch.empty(
            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
            dtype=torch.uint8,
            device=device,
        )
    self.workspace_buffer = global_workspace_buffer

    max_bs = model_runner.req_to_token_pool.size

    if kv_indptr_buf is not None:
        assert self.num_wrappers == 1
        self.kv_indptr = [kv_indptr_buf]
    else:
        self.kv_indptr = [
            torch.zeros((max_bs + 1,), dtype=torch.int32, device=device)
            for _ in range(self.num_wrappers)
        ]

    if kv_last_page_len_buf is not None:
        assert self.num_wrappers == 1
        self.kv_last_page_len = kv_last_page_len_buf
    else:
        self.kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device=device
        )

    if not self.skip_prefill:
        self.qo_indptr = [
            torch.zeros((max_bs + 1,), dtype=torch.int32, device=device)
            for _ in range(self.num_wrappers)
        ]

    fmha_backend = "cutlass" if is_sm100_supported() else "auto"
    self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
        self.workspace_buffer, "NHD", backend=fmha_backend
    )

    # One wrapper set per dispatch slot (e.g. sliding window vs. full
    # attention). All of them share the same workspace buffer.
    self.prefill_wrappers_paged = []
    self.prefill_wrappers_verify = []
    self.decode_wrappers = []
    for _ in range(self.num_wrappers):
        if not skip_prefill:
            paged_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                backend="fa2",
            )
            self.prefill_wrappers_paged.append(paged_wrapper)
            verify_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
            )
            self.prefill_wrappers_verify.append(verify_wrapper)
        decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            self.workspace_buffer,
            "NHD",
            use_tensor_cores=self.decode_use_tensor_cores,
        )
        self.decode_wrappers.append(decode_wrapper)

    # Index updaters; the prefill one is also used for verify.
    if not skip_prefill:
        self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
            model_runner, self
        )
    self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)

    # Metadata filled in by the init_forward_metadata* methods.
    self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
    self.decode_cuda_graph_metadata = {}
    self.prefill_cuda_graph_metadata = {}  # For verify
    self.draft_extend_cuda_graph_metadata = {}  # For draft extend
def _process_multi_item_scoring(
    self, forward_batch: ForwardBatch
) -> MultiItemScoringParams:
    """Derive the multi-item scoring tensors for a batch.

    Sequences are expected to look like ``prompt <delim> item <delim> item ...``.
    For every sequence that contains at least one delimiter this method

    * records the prefix length (index of the first delimiter plus the cached
      prefix length of that sequence, when available),
    * computes the position of every token inside its item, restarting at 0
      on each delimiter, and
    * rewrites ``forward_batch.positions`` in place for the tokens from the
      first delimiter onwards.

    Args:
        forward_batch: Batch providing ``input_ids``, ``positions``,
            ``forward_mode`` and, optionally, ``extend_prefix_lens`` and
            ``extend_seq_lens``.

    Returns:
        A populated ``MultiItemScoringParams``, or a default (disabled) one
        when no delimiter is configured, the batch is in decode mode, or no
        sequence contains a delimiter.

    Examples:
        Following the FlashInfer definition, 3 items of length 3, 2, 4 give
            token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]

        Single sequence
            "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
            - prefix_len_ptr: [7] (query length before the first delimiter)
            - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0]
            - token_pos_in_items_len: 7
            - max_item_len_ptr: [1]

        Batch of two
            Sequence 1 (items of length 2, 1): [0, 1, 2, 0, 1, 0]
            Sequence 2 (items of length 1, 3, 2): [0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
            After padding both to length 10:
            - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
            - token_pos_in_items_len: 10
            - max_item_len_ptr: [2, 3]
    """
    delimiter = self.multi_item_scoring_delimiter
    if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
        return MultiItemScoringParams()

    delimiter_mask = forward_batch.input_ids == delimiter
    prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
    extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)

    # Without per-sequence lengths (or with a single one) the whole batch is
    # treated as one sequence.
    if extend_seq_lens is None or len(extend_seq_lens) <= 1:
        extend_seq_lens = [forward_batch.input_ids.size(0)]

    prefix_lens = []
    per_seq_token_pos = []
    seq_start = 0
    for i, seq_len in enumerate(extend_seq_lens):
        seq_end = seq_start + seq_len
        mask = delimiter_mask[seq_start:seq_end]
        pos = forward_batch.positions[seq_start:seq_end]
        delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]
        if len(delimiter_indices) > 0:
            first_delim = delimiter_indices[0]

            # Prefix length, stored as a Python scalar.
            cached_len = prefix_cache_lens[i] if prefix_cache_lens is not None else 0
            prefix_len = first_delim + cached_len
            if torch.is_tensor(prefix_len):
                prefix_len = prefix_len.item()
            prefix_lens.append(prefix_len)

            # Relative positions inside each item after a delimiter.
            diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
            token_pos = (diff - pos[first_delim]).to(torch.uint16)
            per_seq_token_pos.append(token_pos)

            # Rewrite the batch positions in place.
            pos[first_delim:] = diff - 1
            forward_batch.positions[seq_start:seq_end] = pos
        seq_start = seq_end

    # Nothing to score: no sequence contained a delimiter.
    if not prefix_lens or not per_seq_token_pos:
        return MultiItemScoringParams()

    # Zero-pad every per-sequence tensor to the longest one so they can be
    # concatenated for batch processing.
    device = forward_batch.input_ids.device
    token_pos_in_items_len = max(t.numel() for t in per_seq_token_pos)
    padded_token_pos = []
    for t in per_seq_token_pos:
        padding = torch.zeros(
            token_pos_in_items_len - t.numel(),
            dtype=torch.uint16,
            device=device,
        )
        padded_token_pos.append(torch.cat([t, padding]))

    max_item_lens = [
        t.to(torch.int32).max().to(torch.uint16) for t in padded_token_pos
    ]
    return MultiItemScoringParams(
        prefix_len_ptr=torch.tensor(prefix_lens, dtype=torch.uint32, device=device),
        token_pos_in_items_ptr=torch.cat(padded_token_pos, dim=0),
        token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
        max_item_len_ptr=torch.stack(max_item_lens, dim=0),
    )
def init_forward_metadata(self, forward_batch: ForwardBatch):
    """Plan the wrappers for ``forward_batch`` and store ``self.forward_metadata``.

    Decode/idle batches plan the decode wrappers; draft-extend and
    target-verify batches plan the paged / verify prefill wrappers without the
    ragged path; every other batch is a regular extend, optionally with
    multi-item scoring.
    """
    mode = forward_batch.forward_mode

    if mode.is_decode_or_idle():
        self.indices_updater_decode.update(
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.seq_lens_cpu,
            forward_batch.seq_lens_sum,
            decode_wrappers=self.decode_wrappers,
            encoder_lens=forward_batch.encoder_lens,
            spec_info=forward_batch.spec_info,
            fixed_split_size=self.decode_split_tile_size,
            disable_split_kv=False,
        )
        self.forward_metadata = DecodeMetadata(self.decode_wrappers)
        return

    # Speculative decoding modes: same planning call, different wrapper set.
    spec_wrappers = None
    if mode.is_draft_extend():
        spec_wrappers = self.prefill_wrappers_paged
    elif mode.is_target_verify():
        spec_wrappers = self.prefill_wrappers_verify

    if spec_wrappers is not None:
        self.indices_updater_prefill.update(
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.seq_lens_cpu,
            forward_batch.seq_lens_sum,
            prefix_lens=None,
            prefill_wrappers=spec_wrappers,
            use_ragged=False,
            encoder_lens=forward_batch.encoder_lens,
            spec_info=forward_batch.spec_info,
        )
        self.forward_metadata = PrefillMetadata(spec_wrappers, False, False)
        return

    # Regular extend.
    prefix_lens = forward_batch.extend_prefix_lens
    scoring_enabled = self.multi_item_scoring_delimiter is not None

    if self.is_multimodal or scoring_enabled:
        # The ragged wrapper is disabled for multimodal and multi-item scoring:
        # 1. It does not take the multi-item parameters
        #    (prefix_len_ptr, token_pos_in_items_ptr, etc.)
        # 2. The paged wrapper gives the control over masking that item
        #    boundaries require
        # 3. Custom masking conflicts with the ragged wrapper's assumptions
        use_ragged = False
        extend_no_prefix = False
    else:
        use_ragged = not self.enable_deterministic
        extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

    # Multi-item scoring is processed here rather than in ForwardBatch.
    if scoring_enabled:
        multi_item_params = self._process_multi_item_scoring(forward_batch)
    else:
        multi_item_params = MultiItemScoringParams()

    self.indices_updater_prefill.update(
        forward_batch.req_pool_indices,
        forward_batch.seq_lens,
        forward_batch.seq_lens_cpu,
        forward_batch.seq_lens_sum,
        prefix_lens,
        prefill_wrappers=self.prefill_wrappers_paged,
        use_ragged=use_ragged,
        encoder_lens=forward_batch.encoder_lens,
        spec_info=None,
        fixed_split_size=self.prefill_split_tile_size,
        multi_item_params=multi_item_params,
    )
    self.forward_metadata = PrefillMetadata(
        self.prefill_wrappers_paged,
        use_ragged,
        extend_no_prefix,
        multi_item_params,
    )
def init_cuda_graph_state(
    self,
    max_bs: int,
    max_num_tokens: int,
    kv_indices_buf: Optional[torch.Tensor] = None,
):
    """Allocate the static buffers used while capturing/replaying CUDA graphs.

    Args:
        max_bs: Maximum batch size (not read by this method).
        max_num_tokens: Maximum number of tokens; together with
            ``self.max_context_len`` it sizes the KV index and mask buffers.
        kv_indices_buf: Optional pre-allocated KV index buffer to use for the
            first wrapper instead of allocating one.
    """
    if kv_indices_buf is not None:
        base_kv_indices = kv_indices_buf
    else:
        base_kv_indices = torch.zeros(
            (max_num_tokens * self.max_context_len,),
            dtype=torch.int32,
            device="cuda",
        )

    # The first wrapper uses the base buffer; every extra wrapper gets a copy.
    extra_copies = [base_kv_indices.clone() for _ in range(self.num_wrappers - 1)]
    self.cuda_graph_kv_indices = [base_kv_indices] + extra_copies

    # Touch the first element of each buffer so it is really allocated.
    for wrapper_id in range(self.num_wrappers):
        kv_indices = self.cuda_graph_kv_indices[wrapper_id]
        if len(kv_indices) > 0:
            kv_indices[0] = 0

    if not self.skip_prefill:
        self.cuda_graph_custom_mask = torch.zeros(
            (max_num_tokens * self.max_context_len),
            dtype=torch.uint8,
            device="cuda",
        )
        # Both indptr sets start as copies of the KV indptr buffers.
        self.cuda_graph_qk_indptr = [indptr.clone() for indptr in self.kv_indptr]
        self.cuda_graph_qo_indptr = [indptr.clone() for indptr in self.kv_indptr]
def init_forward_metadata_capture_cuda_graph(
    self,
    bs: int,
    num_tokens: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: ForwardMode,
    spec_info: Optional[SpecInput],
):
    """Create CUDA-graph wrappers for batch size ``bs`` and plan them once.

    The wrappers are remembered per batch size so that the replay path can
    re-plan them later.

    Raises:
        ValueError: If ``forward_mode`` is neither decode/idle, target-verify
            nor draft-extend.
    """
    if forward_mode.is_decode_or_idle():
        decode_wrappers = [
            BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_cuda_graph=True,
                use_tensor_cores=self.decode_use_tensor_cores,
                paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
                paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
                paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_tokens],
            )
            for i in range(self.num_wrappers)
        ]
        seq_lens_sum = seq_lens.sum().item()
        self.indices_updater_decode.update(
            req_pool_indices,
            seq_lens,
            seq_lens.cpu(),  # may add a little overhead in capture stage
            seq_lens_sum,
            decode_wrappers=decode_wrappers,
            encoder_lens=encoder_lens,
            spec_info=spec_info,
            fixed_split_size=None,
            disable_split_kv=self.disable_cuda_graph_kv_split,
        )
        self.decode_cuda_graph_metadata[bs] = decode_wrappers
        self.forward_metadata = DecodeMetadata(decode_wrappers)
        # Only after the initial plan above: route later begin_forward calls
        # on these wrappers through fast_decode_plan.
        for wrapper in decode_wrappers:
            wrapper.begin_forward = partial(fast_decode_plan, wrapper)
        return

    if forward_mode.is_target_verify():
        # Verify wrappers additionally carry the custom mask buffers.
        prefill_wrappers = [
            BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_cuda_graph=True,
                qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
                paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
                paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
                paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
                custom_mask_buf=self.cuda_graph_custom_mask,
                mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
            )
            for i in range(self.num_wrappers)
        ]
    elif forward_mode.is_draft_extend():
        prefill_wrappers = [
            BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                backend="fa2",
                use_cuda_graph=True,
                qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
                paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
                paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
                paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
            )
            for i in range(self.num_wrappers)
        ]
    else:
        raise ValueError(f"Invalid mode: {forward_mode=}")

    # Shared tail for target-verify and draft-extend.
    # NOTE(review): both modes store into prefill_cuda_graph_metadata keyed
    # only by bs, while draft_extend_cuda_graph_metadata stays unused —
    # confirm one backend instance never captures both modes.
    seq_lens_sum = seq_lens.sum().item()
    self.indices_updater_prefill.update(
        req_pool_indices,
        seq_lens,
        seq_lens.cpu(),  # may add a little overhead in capture stage
        seq_lens_sum,
        prefix_lens=None,
        prefill_wrappers=prefill_wrappers,
        use_ragged=False,
        encoder_lens=encoder_lens,
        spec_info=spec_info,
    )
    self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
    self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
def init_forward_metadata_replay_cuda_graph(
    self,
    bs: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: ForwardMode,
    spec_info: Optional[SpecInput],
    seq_lens_cpu: Optional[torch.Tensor],
):
    """Re-plan the wrappers captured for batch size ``bs`` before a replay.

    Every per-request input is truncated to its first ``bs`` entries.

    Raises:
        ValueError: If ``forward_mode`` is neither decode/idle, target-verify
            nor draft-extend.
    """
    # Optional inputs keep their None when absent.
    seq_lens_cpu_bs = seq_lens_cpu[:bs] if seq_lens_cpu is not None else None
    encoder_lens_bs = encoder_lens[:bs] if encoder_lens is not None else None

    if forward_mode.is_decode_or_idle():
        self.indices_updater_decode.update(
            req_pool_indices[:bs],
            seq_lens[:bs],
            seq_lens_cpu_bs,
            seq_lens_sum,
            decode_wrappers=self.decode_cuda_graph_metadata[bs],
            encoder_lens=encoder_lens_bs,
            spec_info=spec_info,
            fixed_split_size=None,
            disable_split_kv=self.disable_cuda_graph_kv_split,
        )
        return

    # Target-verify and draft-extend re-plan in exactly the same way.
    if forward_mode.is_target_verify() or forward_mode.is_draft_extend():
        self.indices_updater_prefill.update(
            req_pool_indices[:bs],
            seq_lens[:bs],
            seq_lens_cpu_bs,
            seq_lens_sum,
            prefix_lens=None,
            prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
            use_ragged=False,
            encoder_lens=encoder_lens_bs,
            spec_info=spec_info,
        )
        return

    raise ValueError("Invalid forward mode")
def get_cuda_graph_seq_len_fill_value(self):
    """Return the sequence length used to fill CUDA-graph slots (always 1)."""
    fill_value = 1
    return fill_value
def forward_extend(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
    save_kv_cache=True,
):
    """Run extend (prefill) attention for one layer.

    Without the ragged path, K/V are written to the cache first and the paged
    wrapper attends over the cache. With the ragged path, the ragged wrapper
    attends over the new K/V, the paged wrapper over the cache (when some
    request has a prefix), the two results are merged, and K/V are written to
    the cache afterwards.

    Returns:
        The attention output reshaped to
        ``(-1, layer.tp_q_head_num * layer.head_dim)``.
    """
    metadata = self.forward_metadata
    prefill_wrapper_paged = metadata.prefill_wrappers[self._get_wrapper_idx(layer)]

    if layer.is_cross_attention:
        cache_loc = forward_batch.encoder_out_cache_loc
    else:
        cache_loc = forward_batch.out_cache_loc

    logits_soft_cap = layer.logit_cap
    q = q.contiguous()
    q_heads = q.view(-1, layer.tp_q_head_num, layer.head_dim)
    kv_pool = forward_batch.token_to_kv_pool

    if not metadata.use_ragged:
        if k is not None:
            assert v is not None
            if save_kv_cache:
                kv_pool.set_kv_buffer(
                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                )

        # Sliding window attention is disabled for multi-item scoring:
        # - A window could cut across item boundaries
        # - Multi-item sequences need full attention to handle delimiter tokens
        # - The multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
        #   give more precise control than a simple window
        # - Item-aware masking takes precedence over window-based masking
        multi_item_params = metadata.multi_item_params
        if multi_item_params and multi_item_params.is_enabled():
            window_left = -1
        else:
            window_left = layer.sliding_window_size

        o = prefill_wrapper_paged.forward(
            q_heads,
            kv_pool.get_kv_buffer(layer.layer_id),
            causal=not layer.is_cross_attention,
            sm_scale=layer.scaling,
            window_left=window_left,
            logits_soft_cap=logits_soft_cap,
            # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
            k_scale=layer.k_scale_float,
            v_scale=layer.v_scale_float,
        )
    else:
        encoder_only = layer.attn_type == AttentionType.ENCODER_ONLY
        causal = not (layer.is_cross_attention or encoder_only)
        if save_kv_cache and encoder_only:
            # Encoder-only layers never write to the KV cache.
            save_kv_cache = False

        k_heads = k.view(-1, layer.tp_k_head_num, layer.head_dim)
        v_heads = v.view(-1, layer.tp_v_head_num, layer.head_dim)

        if metadata.extend_no_prefix:
            # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
            # The FlashInfer head_dim limitation itself is tracked here:
            # https://github.com/flashinfer-ai/flashinfer/issues/1048
            o = self.prefill_wrapper_ragged.forward(
                q_heads,
                k_heads,
                v_heads,
                causal=causal,
                sm_scale=layer.scaling,
                logits_soft_cap=logits_soft_cap,
            )
        else:
            # NOTE(review): this path passes causal=True to the ragged wrapper
            # regardless of the `causal` flag computed above — confirm that
            # non-causal layers never reach it.
            out_ragged, lse_ragged = self.prefill_wrapper_ragged.forward_return_lse(
                q_heads,
                k_heads,
                v_heads,
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=logits_soft_cap,
            )
            out_paged, lse_paged = prefill_wrapper_paged.forward_return_lse(
                q_heads,
                kv_pool.get_kv_buffer(layer.layer_id),
                causal=False,
                sm_scale=layer.scaling,
                logits_soft_cap=logits_soft_cap,
            )
            o, _ = merge_state(out_ragged, lse_ragged, out_paged, lse_paged)

        # On the ragged path the cache is written only after attention ran.
        if save_kv_cache:
            kv_pool.set_kv_buffer(
                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
            )

    return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def forward_decode(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
    save_kv_cache=True,
):
    """Run decode attention for one layer.

    When ``k`` is given (and ``save_kv_cache`` is set) the new K/V are written
    to the cache first; the decode wrapper then attends over the cache.

    Returns:
        The attention output reshaped to
        ``(-1, layer.tp_q_head_num * layer.head_dim)``.
    """
    wrapper_idx = self._get_wrapper_idx(layer)
    decode_wrapper = self.forward_metadata.decode_wrappers[wrapper_idx]

    if layer.is_cross_attention:
        cache_loc = forward_batch.encoder_out_cache_loc
    else:
        cache_loc = forward_batch.out_cache_loc

    kv_pool = forward_batch.token_to_kv_pool
    if k is not None:
        assert v is not None
        if save_kv_cache:
            kv_pool.set_kv_buffer(
                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
            )

    q_heads = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
    o = decode_wrapper.forward(
        q_heads,
        kv_pool.get_kv_buffer(layer.layer_id),
        sm_scale=layer.scaling,
        logits_soft_cap=layer.logit_cap,
        # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
        k_scale=layer.k_scale_float,
        v_scale=layer.v_scale_float,
    )
    return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def _get_wrapper_idx(self, layer: RadixAttention) -> int:
    """Return the index of the wrapper that serves ``layer``.

    With a single wrapper the index is always 0. With sliding-window dispatch
    the index is 1 for layers whose ``sliding_window_size`` is -1 and 0
    otherwise. With cross-attention dispatch the index is 1 for
    cross-attention layers and 0 otherwise.

    Raises:
        ValueError: If two wrappers exist but the dispatch reason is unknown.
    """
    if self.num_wrappers == 1:
        return 0

    # The comparisons below yield booleans; convert explicitly so callers
    # always receive a real integer index.
    if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
        return int(layer.sliding_window_size == -1)
    if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
        return int(bool(layer.is_cross_attention))

    raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
class FlashInferIndicesUpdaterDecode:
def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
    """Cache model constants and backend buffers, then pick the update routine.

    ``self.update`` is bound at runtime to the sliding-window, cross-attention
    or single-wrapper implementation according to the backend's dispatch
    reason.
    """
    model_config = model_runner.model_config
    attn_tp_size = get_attention_tp_size()

    # Model constants.
    self.num_qo_heads = model_config.num_attention_heads // attn_tp_size
    self.num_kv_heads = model_config.get_num_kv_heads(attn_tp_size)
    self.head_dim = model_config.head_dim
    self.data_type = model_runner.kv_cache_dtype
    self.q_data_type = model_runner.dtype
    self.sliding_window_size = model_runner.sliding_window_size
    self.attn_backend = attn_backend

    # Buffers shared with the backend and the memory pools.
    self.kv_indptr = attn_backend.kv_indptr
    self.kv_last_page_len = attn_backend.kv_last_page_len
    self.req_to_token = model_runner.req_to_token_pool.req_to_token
    self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

    # Select the update implementation.
    dispatch_reason = self.attn_backend.dispatch_reason
    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
        self.update = self.update_sliding_window
    elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
        self.update = self.update_cross_attention
    else:
        assert self.attn_backend.num_wrappers == 1
        self.update = self.update_single_wrapper
def update(
    self,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: Optional[torch.Tensor],
    seq_lens_sum: int,
    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
    encoder_lens: Optional[torch.Tensor],
    spec_info: Optional[SpecInput],
    fixed_split_size: Optional[int] = None,
    disable_split_kv: Optional[bool] = None,
):
    """Signature placeholder for type checking.

    ``__init__`` replaces this attribute with one of the concrete
    ``update_*`` methods, so this body only runs if that did not happen.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
def update_single_wrapper(
    self,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: Optional[torch.Tensor],
    seq_lens_sum: int,
    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
    encoder_lens: Optional[torch.Tensor],
    spec_info: Optional[SpecInput],
    fixed_split_size: Optional[int] = None,
    disable_split_kv: Optional[bool] = None,
):
    """Plan the only decode wrapper for the given batch.

    Args:
        req_pool_indices: Request pool indices of the batch.
        seq_lens: Per-request sequence lengths.
        seq_lens_cpu: Optional CPU copy of ``seq_lens``.
        seq_lens_sum: Sum of ``seq_lens``.
        decode_wrappers: Wrappers to plan; when empty or None the backend's
            own decode wrappers are used.
        encoder_lens: Unused by this variant.
        spec_info: Optional speculative-decoding input.
        fixed_split_size: Forwarded to ``call_begin_forward``.
        disable_split_kv: Forwarded to ``call_begin_forward``.
    """
    # This class never assigns ``self.decode_wrappers``, so the fallback has
    # to go through the owning backend, which does hold the wrappers.
    decode_wrappers = decode_wrappers or self.attn_backend.decode_wrappers
    self.call_begin_forward(
        decode_wrappers[0],
        req_pool_indices,
        seq_lens,
        seq_lens_sum,
        self.kv_indptr[0],
        None,
        spec_info,
        seq_lens_cpu,
        fixed_split_size=fixed_split_size,
        disable_split_kv=disable_split_kv,
    )
def update_sliding_window(
    self,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: Optional[torch.Tensor],
    seq_lens_sum: int,
    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
    encoder_lens: Optional[torch.Tensor],
    spec_info: Optional[SpecInput],
    fixed_split_size: Optional[int] = None,
    disable_split_kv: Optional[bool] = None,
):
    """Plan both decode wrappers of a sliding-window model.

    Wrapper 0 covers the sliding window: lengths are clamped to
    ``sliding_window_size + 1`` and the KV start index skips the tokens that
    fall outside the window. Wrapper 1 covers full attention and uses the
    lengths unchanged.

    ``seq_lens_cpu`` may be None; the clamped sum is then computed from the
    device tensor and no CPU lengths are forwarded for wrapper 0.

    NOTE(review): ``fixed_split_size`` and ``disable_split_kv`` are accepted
    but not forwarded to ``call_begin_forward`` here, unlike in
    ``update_single_wrapper`` — confirm whether that is intended.
    """
    assert self.sliding_window_size is not None
    max_window_len = self.sliding_window_size + 1

    for wrapper_id in range(2):
        if wrapper_id == 0:
            # Sliding window attention
            paged_kernel_lens_tmp = torch.clamp(seq_lens, max=max_window_len)
            if seq_lens_cpu is not None:
                seq_lens_cpu_tmp = torch.clamp(seq_lens_cpu, max=max_window_len)
                paged_kernel_lens_sum_tmp = seq_lens_cpu_tmp.sum().item()
            else:
                # Bind the name explicitly: it is read below on this very
                # iteration, and leaving it unset raised UnboundLocalError.
                seq_lens_cpu_tmp = None
                paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
            kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
        else:
            # Full attention
            paged_kernel_lens_tmp = seq_lens
            paged_kernel_lens_sum_tmp = seq_lens_sum
            seq_lens_cpu_tmp = seq_lens_cpu
            kv_start_idx_tmp = None

        # Only the sliding-window wrapper may need its KV indices translated
        # into the sliding-window pool.
        use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
            self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
        )

        self.call_begin_forward(
            decode_wrappers[wrapper_id],
            req_pool_indices,
            paged_kernel_lens_tmp,
            paged_kernel_lens_sum_tmp,
            self.kv_indptr[wrapper_id],
            kv_start_idx_tmp,
            spec_info,
            seq_lens_cpu=seq_lens_cpu_tmp,
            use_sliding_window_kv_pool=use_sliding_window_kv_pool,
        )
def update_cross_attention(
    self,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: Optional[torch.Tensor],
    seq_lens_sum: int,
    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
    encoder_lens: Optional[torch.Tensor],
    spec_info: Optional[SpecInput],
    fixed_split_size: Optional[int] = None,
    disable_split_kv: Optional[bool] = None,
):
    """Plan both decode wrappers of an encoder-decoder model.

    Wrapper 0 handles normal attention: lengths are ``seq_lens`` and the KV
    start index is ``encoder_lens``. Wrapper 1 handles cross attention:
    lengths are ``encoder_lens``, the KV start index is zero and the length
    sum is recomputed from ``encoder_lens``.

    NOTE(review): ``fixed_split_size`` and ``disable_split_kv`` are accepted
    but not forwarded. ``seq_lens_cpu`` is forwarded unchanged for the
    cross-attention wrapper although its lengths are ``encoder_lens`` —
    confirm both are intended.
    """
    # Wrapper 0: normal attention.
    self.call_begin_forward(
        decode_wrappers[0],
        req_pool_indices,
        seq_lens,
        seq_lens_sum,
        self.kv_indptr[0],
        encoder_lens,
        spec_info,
        seq_lens_cpu=seq_lens_cpu,
    )

    # Wrapper 1: cross attention.
    cross_kv_start_idx = torch.zeros_like(encoder_lens)
    encoder_lens_sum = encoder_lens.sum().item()
    self.call_begin_forward(
        decode_wrappers[1],
        req_pool_indices,
        encoder_lens,
        encoder_lens_sum,
        self.kv_indptr[1],
        cross_kv_start_idx,
        spec_info,
        seq_lens_cpu=seq_lens_cpu,
    )
def call_begin_forward(
    self,
    wrapper: BatchDecodeWithPagedKVCacheWrapper,
    req_pool_indices: torch.Tensor,
    paged_kernel_lens: torch.Tensor,
    paged_kernel_lens_sum: int,
    kv_indptr: torch.Tensor,
    kv_start_idx: Optional[torch.Tensor],
    spec_info: Optional[SpecInput],
    seq_lens_cpu: Optional[torch.Tensor],
    use_sliding_window_kv_pool: bool = False,
    fixed_split_size: Optional[int] = None,
    disable_split_kv: Optional[bool] = None,
):
    """Build the KV indptr/indices for one wrapper and call its ``begin_forward``.

    Args:
        wrapper: Decode wrapper to plan.
        req_pool_indices: Request pool indices of the batch.
        paged_kernel_lens: Per-request KV lengths for this wrapper.
        paged_kernel_lens_sum: Sum of ``paged_kernel_lens``; sizes the KV
            index tensor when the wrapper is not in CUDA-graph mode.
        kv_indptr: Backend-owned indptr buffer, filled in place when
            ``spec_info`` is None.
        kv_start_idx: Optional per-request start offsets passed to the index
            kernel.
        spec_info: When given, its ``kv_indptr`` / ``kv_indices`` are used
            instead of computing them.
        seq_lens_cpu: Optional CPU lengths; when given (and no override is
            already active) a CPU indptr override is installed for the
            duration of the plan call.
        use_sliding_window_kv_pool: Translate KV indices from the full pool
            to the sliding-window pool before planning.
        fixed_split_size: Forwarded to ``begin_forward``.
        disable_split_kv: Forwarded to ``begin_forward``; None means False.
    """
    global global_override_indptr_cpu

    if spec_info is None:
        bs = len(req_pool_indices)
        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
        kv_indptr = kv_indptr[: bs + 1]

        if wrapper.is_cuda_graph_enabled:
            # Directly write to the cuda graph input buffer
            kv_indices = wrapper._paged_kv_indices_buf
        else:
            kv_indices = torch.empty(
                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
            )

        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            kv_indptr,
            kv_start_idx,
            kv_indices,
            self.req_to_token.shape[1],
        )
    else:
        kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
        bs = kv_indptr.shape[0] - 1

    if use_sliding_window_kv_pool:
        kv_last_index = kv_indptr[-1]
        kv_indices[:kv_last_index] = (
            self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
                kv_indices[:kv_last_index]
            )
        )

    # Install the CPU indptr override only if nobody else already did.
    locally_override = (
        seq_lens_cpu is not None and global_override_indptr_cpu is None
    )
    try:
        if locally_override:
            global_override_indptr_cpu = torch.empty_like(kv_indptr, device="cpu")
            global_override_indptr_cpu[0] = 0
            global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(
                seq_lens_cpu, dim=0
            )

        plan_kwargs = {
            "data_type": self.data_type,
            "q_data_type": self.q_data_type,
            "non_blocking": True,
            "fixed_split_size": fixed_split_size,
            "disable_split_kv": (
                disable_split_kv if disable_split_kv is not None else False
            ),
        }

        # Only a wrapper whose begin_forward was replaced by
        # partial(fast_decode_plan, ...) accepts the override keyword; the
        # original begin_forward must not receive it.
        wrapper_uses_fast_decode_plan = (
            hasattr(wrapper.begin_forward, "func")
            and wrapper.begin_forward.func == fast_decode_plan
        )
        if wrapper_uses_fast_decode_plan:
            plan_kwargs["global_override_indptr_cpu"] = global_override_indptr_cpu

        wrapper.begin_forward(
            kv_indptr,
            kv_indices,
            self.kv_last_page_len[:bs],
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
            **plan_kwargs,
        )
    finally:
        # Always clear an override installed here, even if planning raised,
        # so a stale indptr can never leak into a later call.
        if locally_override:
            global_override_indptr_cpu = None
class FlashInferIndicesUpdaterPrefill:
    """Build index metadata for prefill (extend) batches and plan FlashInfer wrappers.

    ``__init__`` rebinds ``self.update`` to one of ``update_single_wrapper``,
    ``update_sliding_window`` or ``update_cross_attention`` depending on
    ``attn_backend.dispatch_reason``. Each of these derives the per-wrapper KV
    lengths and then delegates to ``call_begin_forward``, which fills the
    indptr/indices buffers and calls ``begin_forward`` on the wrappers.
    """

    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
        # Parse Constants
        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.sliding_window_size = model_runner.sliding_window_size
        self.attn_backend = attn_backend

        # Buffers and wrappers (shared with the owning attention backend)
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.qo_indptr = attn_backend.qo_indptr
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged

        # Dispatch the update function
        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
            self.update = self.update_sliding_window
        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
            self.update = self.update_cross_attention
        else:
            assert self.attn_backend.num_wrappers == 1
            self.update = self.update_single_wrapper

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_cpu: Optional[torch.Tensor],
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInput],
        fixed_split_size: Optional[int] = None,
        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        """Signature placeholder; replaced per instance in ``__init__``.

        The parameter list mirrors the three concrete ``update_*`` methods so
        that type checkers see the same keywords callers may pass at runtime.

        Raises:
            NotImplementedError: always, if the placeholder itself is invoked.
        """
        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()

    def update_single_wrapper(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_cpu: Optional[torch.Tensor],
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInput],
        fixed_split_size: Optional[int] = None,
        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        """Plan ``prefill_wrappers[0]`` (and the ragged wrapper when ``use_ragged``).

        With ``use_ragged`` the paged wrapper covers only ``prefix_lens``;
        otherwise it covers the full ``seq_lens``. ``seq_lens_cpu`` and
        ``encoder_lens`` are accepted for signature parity and are not read.
        """
        if use_ragged:
            # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
            # and forward_batch.extend_seq_lens_cpu
            paged_kernel_lens = prefix_lens
            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
        else:
            paged_kernel_lens = seq_lens
            paged_kernel_lens_sum = seq_lens_sum

        self.call_begin_forward(
            self.prefill_wrapper_ragged,
            prefill_wrappers[0],
            req_pool_indices,
            paged_kernel_lens,
            paged_kernel_lens_sum,
            seq_lens,
            prefix_lens,
            None,
            self.kv_indptr[0],
            self.qo_indptr[0],
            use_ragged,
            spec_info,
            fixed_split_size=fixed_split_size,
            multi_item_params=multi_item_params,
        )

    def update_sliding_window(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_cpu: Optional[torch.Tensor],
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInput],
        fixed_split_size: Optional[int] = None,
        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        """Plan two wrappers: index 0 for window attention, index 1 for full attention.

        For wrapper 0 the per-request KV length is
        ``min(seq_lens, sliding_window_size + seq_lens - prefix_lens)`` and the
        KV indices are translated to the SWA pool when the allocator is a
        ``SWATokenToKVPoolAllocator``. For both wrappers ``kv_start_idx`` is
        ``seq_lens - paged_kernel_lens``.
        """
        # NOTE(review): fixed_split_size is accepted but not forwarded to
        # call_begin_forward on this path (unlike update_single_wrapper) --
        # confirm whether that is intentional.
        for wrapper_id in range(2):
            if wrapper_id == 0:
                # window attention use paged only
                # Plain scalar arithmetic; avoids building a CPU tensor per call.
                paged_kernel_lens = torch.minimum(
                    seq_lens,
                    seq_lens - prefix_lens + self.sliding_window_size,
                )
                paged_kernel_lens_sum = paged_kernel_lens.sum().item()
            else:
                # full attention
                paged_kernel_lens = seq_lens
                paged_kernel_lens_sum = seq_lens_sum

            kv_start_idx = seq_lens - paged_kernel_lens
            use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
                self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
            )

            self.call_begin_forward(
                self.prefill_wrapper_ragged,
                prefill_wrappers[wrapper_id],
                req_pool_indices,
                paged_kernel_lens,
                paged_kernel_lens_sum,
                seq_lens,
                prefix_lens,
                kv_start_idx,
                self.kv_indptr[wrapper_id],
                self.qo_indptr[wrapper_id],
                use_ragged,
                spec_info,
                use_sliding_window_kv_pool=use_sliding_window_kv_pool,
                multi_item_params=multi_item_params,
            )

    def update_cross_attention(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_cpu: Optional[torch.Tensor],
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInput],
        fixed_split_size: Optional[int] = None,
        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        """Plan two wrappers: index 0 for normal attention, index 1 for cross attention.

        Wrapper 0 uses ``seq_lens`` with ``kv_start_idx = encoder_lens``;
        wrapper 1 uses ``encoder_lens`` with a zero ``kv_start_idx``.
        """
        # NOTE(review): fixed_split_size is accepted but not forwarded to
        # call_begin_forward on this path -- confirm whether that is intentional.
        for wrapper_id in range(2):
            if wrapper_id == 0:
                # normal attention
                paged_kernel_lens = seq_lens
                kv_start_idx = encoder_lens
                paged_kernel_lens_sum = seq_lens_sum
            else:
                # cross attention
                paged_kernel_lens = encoder_lens
                kv_start_idx = torch.zeros_like(encoder_lens)
                paged_kernel_lens_sum = paged_kernel_lens.sum().item()

            self.call_begin_forward(
                self.prefill_wrapper_ragged,
                prefill_wrappers[wrapper_id],
                req_pool_indices,
                paged_kernel_lens,
                paged_kernel_lens_sum,
                seq_lens,
                prefix_lens,
                kv_start_idx,
                self.kv_indptr[wrapper_id],
                self.qo_indptr[wrapper_id],
                use_ragged,
                spec_info,
                multi_item_params=multi_item_params,
            )

    def call_begin_forward(
        self,
        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        seq_lens: torch.Tensor,
        prefix_lens: torch.Tensor,
        kv_start_idx: torch.Tensor,
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        use_ragged: bool,
        spec_info: Optional[SpecInput],
        use_sliding_window_kv_pool: bool = False,
        fixed_split_size: Optional[int] = None,
        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        """Fill the index buffers and call ``begin_forward`` on the wrappers.

        Without ``spec_info`` the KV/QO indptr buffers are written in place
        from cumulative sums and the KV indices are produced by
        ``create_flashinfer_kv_indices_triton``. With ``spec_info`` all four
        of kv_indices, kv_indptr, qo_indptr and custom_mask come from
        ``spec_info.generate_attn_arg_prefill``.

        The ragged wrapper is planned only when ``use_ragged``; the paged
        wrapper is always planned. When multi-item scoring is enabled the
        generic ``custom_mask`` is not passed to the paged wrapper.
        """
        bs = len(seq_lens)
        if spec_info is None:
            assert len(seq_lens) == len(req_pool_indices)
            # Normal extend
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            # NOTE(review): the buffer is over-allocated by 256 entries; the
            # reason for the padding is not visible here.
            kv_indices = torch.empty(
                paged_kernel_lens_sum + 256,
                dtype=torch.int32,
                device=req_pool_indices.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                kv_start_idx,
                kv_indices,
                self.req_to_token.shape[1],
            )
            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        else:
            assert isinstance(spec_info, SpecInput)
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    paged_kernel_lens,
                    paged_kernel_lens_sum,
                    self.req_to_token,
                )
            )

        # extend part
        if use_ragged:
            # qo_indptr is passed for both the query and the KV indptr.
            wrapper_ragged.begin_forward(
                qo_indptr,
                qo_indptr,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                q_data_type=self.q_data_type,
            )

        if use_sliding_window_kv_pool:
            # Rewrite the used part of kv_indices in place with SWA-pool locations.
            kv_last_index = kv_indptr[-1]
            kv_indices[:kv_last_index] = (
                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
                    kv_indices[:kv_last_index]
                )
            )

        # cached part
        # Conditionally set multi-item parameters
        if multi_item_params is not None and multi_item_params.is_enabled():
            # Multi-item scoring is active - use specialized parameters and disable generic custom_mask
            use_custom_mask = None
            prefix_len_ptr = multi_item_params.prefix_len_ptr
            token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
            token_pos_in_items_len = multi_item_params.token_pos_in_items_len
            max_item_len_ptr = multi_item_params.max_item_len_ptr
        else:
            # No multi-item scoring - use standard parameters
            use_custom_mask = custom_mask
            prefix_len_ptr = None
            token_pos_in_items_ptr = None
            token_pos_in_items_len = 0
            max_item_len_ptr = None

        wrapper_paged.begin_forward(
            qo_indptr,
            kv_indptr,
            kv_indices,
            self.kv_last_page_len[:bs],
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
            q_data_type=self.q_data_type,
            kv_data_type=self.data_type,
            custom_mask=use_custom_mask,
            non_blocking=True,
            fixed_split_size=fixed_split_size,
            prefix_len_ptr=prefix_len_ptr,
            token_pos_in_items_ptr=token_pos_in_items_ptr,
            token_pos_in_items_len=token_pos_in_items_len,
            max_item_len_ptr=max_item_len_ptr,
        )
class FlashInferMultiStepDraftBackend:
    """
    Wrap multiple flashinfer attention backends as one for multiple consecutive
    draft decoding steps.

    One ``FlashInferAttnBackend`` is created per draft step except the last
    (``speculative_num_steps - 1`` in total). They share one
    ``kv_last_page_len`` buffer, and step ``i`` uses row ``i`` of the
    ``kv_indptr`` buffer.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
        self.page_size = model_runner.page_size

        max_bs = model_runner.req_to_token_pool.size * self.topk
        # One indptr row per draft step.
        self.kv_indptr = torch.zeros(
            (
                self.speculative_num_steps,
                max_bs + 1,
            ),
            dtype=torch.int32,
            device=model_runner.device,
        )
        self.kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device=model_runner.device
        )

        self.attn_backends: List[FlashInferAttnBackend] = []
        for i in range(self.speculative_num_steps - 1):
            self.attn_backends.append(
                FlashInferAttnBackend(
                    model_runner,
                    skip_prefill=True,
                    kv_indptr_buf=self.kv_indptr[i],
                    kv_last_page_len_buf=self.kv_last_page_len,
                )
            )

        # NOTE(review): this indexes an empty list when speculative_num_steps <= 1;
        # confirm callers never construct the backend in that configuration.
        self.max_context_len = self.attn_backends[0].max_context_len

        # Cached variables for generate_draft_decode_kv_indices
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

    def common_template(
        self,
        forward_batch: ForwardBatch,
        kv_indices_buffer: torch.Tensor,
        call_fn: Callable,
    ):
        """Generate draft KV indices once, then run ``call_fn`` for each draft step.

        Args:
            forward_batch: Batch whose ``spec_info`` must be a draft input; its
                ``kv_indptr`` / ``kv_indices`` are overwritten for every step.
            kv_indices_buffer: 2-D buffer indexed by step; row ``i`` is sliced
                to ``seq_lens_sum * topk + bs * (i + 1)`` entries for step ``i``.
            call_fn: Invoked as ``call_fn(i, forward_batch)`` for each of the
                ``speculative_num_steps - 1`` steps.
        """
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
        seq_lens_sum = forward_batch.seq_lens_sum

        # Presumably fills kv_indices_buffer and self.kv_indptr for all steps
        # in a single launch -- the kernel body is not visible here.
        self.generate_draft_decode_kv_indices[
            (self.speculative_num_steps, num_seqs, self.topk)
        ](
            forward_batch.req_pool_indices,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.seq_lens,
            kv_indices_buffer,
            self.kv_indptr,
            forward_batch.positions,
            self.pool_len,
            kv_indices_buffer.shape[1],
            self.kv_indptr.shape[1],
            next_power_of_2(num_seqs),
            next_power_of_2(self.speculative_num_steps),
            next_power_of_2(bs),
            self.page_size,
        )

        assert forward_batch.spec_info is not None
        assert forward_batch.spec_info.is_draft_input()

        # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
        indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
        global global_override_indptr_cpu

        try:
            for i in range(self.speculative_num_steps - 1):
                forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
                forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                    : seq_lens_sum * self.topk + bs * (i + 1)
                ]
                global_override_indptr_cpu = indptr_cpu_whole[i]
                call_fn(i, forward_batch)
        finally:
            # Always clear the module-level override, even if call_fn raises:
            # the decode planner in this module only builds its own CPU indptr
            # when the global is None, so a stale value would be reused.
            global_override_indptr_cpu = None

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Allocate a fresh KV-indices buffer and initialise every per-step backend."""
        kv_indices = torch.empty(
            (
                self.speculative_num_steps,
                forward_batch.batch_size * self.topk * self.max_context_len,
            ),
            dtype=torch.int32,
            device="cuda",
        )

        def call_fn(i, forward_batch):
            # common_template hands over slices of shared buffers; clone them
            # so the backend receives independent copies.
            forward_batch.spec_info.kv_indptr = (
                forward_batch.spec_info.kv_indptr.clone()
            )
            forward_batch.spec_info.kv_indices = (
                forward_batch.spec_info.kv_indices.clone()
            )
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, kv_indices, call_fn)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Allocate the CUDA-graph KV-indices buffer and pass row ``i`` to backend ``i``."""
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,
            device="cuda",
        )

        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        """Run the capture-time metadata init on every per-step backend (decode mode)."""

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        """Run the replay-time metadata init on every per-step backend (decode mode)."""

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.seq_lens_cpu,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
def should_use_tensor_core(
kv_cache_dtype: torch.dtype,
num_attention_heads: int,
num_kv_heads: int,
) -> bool:
"""
Determine whether to use tensor cores for attention computation.
Args:
kv_cache_dtype: Data type of the KV cache
num_attention_heads: Number of attention heads
num_kv_heads: Number of key/value heads
Returns:
bool: Whether to use tensor cores
"""
# Try to use environment variable first
env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
if env_override is not None:
return env_override.lower() == "true"
# Try to use _grouped_size_compiled_for_decode_kernels if available
# This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
try:
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
if not _grouped_size_compiled_for_decode_kernels(
num_attention_heads,
num_kv_heads,
):
return True
else:
return False
except (ImportError, AttributeError):
pass
# Calculate GQA group size
gqa_group_size = num_attention_heads // num_kv_heads
# For Flashinfer, a GQA group size of at least 4 is needed to efficiently
# use Tensor Cores, as it fuses the head group with the token dimension in MMA.
if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
return True
elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
return gqa_group_size >= 4
else:
return False