Deterministic Mode: Add 1-stage triton kernel for prefill (#11147)
Co-authored-by: Minglei Zhu <mingleizhu1122@gmail.com>
Co-authored-by: Binyao Jiang <bijiang@linkedin.com>
@@ -64,13 +64,19 @@ class TritonAttnBackend(AttentionBackend):
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            build_unified_kv_indices,
            extend_attention_fwd,
            extend_attention_fwd_unified,
        )

        super().__init__()

        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
        self.extend_attention_fwd_unified = torch.compiler.disable(
            extend_attention_fwd_unified
        )
        self.build_unified_kv_indices = torch.compiler.disable(build_unified_kv_indices)

        # Parse args
        self.skip_prefill = skip_prefill
@@ -794,6 +800,7 @@ class TritonAttnBackend(AttentionBackend):
        else:
            o = torch.empty_like(q)

        # Save KV cache first (must do this before unified kernel)
        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
@@ -805,6 +812,13 @@ class TritonAttnBackend(AttentionBackend):
        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
            causal = False

        # Deterministic mode: use unified 1-stage kernel
        if self.enable_deterministic:
            return self._forward_extend_unified(
                q, o, layer, forward_batch, causal, logits_soft_cap, sinks
            )

        # Normal mode: use original 2-stage kernel
        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
            sliding_window_size = (
                layer.sliding_window_size
@@ -841,6 +855,127 @@ class TritonAttnBackend(AttentionBackend):
        )
        return o

    def _forward_extend_unified(
        self,
        q: torch.Tensor,
        o: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        causal: bool,
        logits_soft_cap: float,
        sinks: Optional[torch.Tensor],
    ):
        """
        Unified 1-stage extend attention for deterministic inference.
        Both prefix and extend KV are accessed through unified kv_indices.
        """
        bs = forward_batch.batch_size

        # Determine sliding window settings
        if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
            sliding_window_size = layer.sliding_window_size
            # Note: for unified kernel, we use full kv_indptr (not window)
            prefix_kv_indptr = self.forward_metadata.window_kv_indptr
            prefix_kv_indices = self.forward_metadata.window_kv_indices
            # Compute window start positions (absolute position of first key in window)
            # window_start_pos = seq_len - window_len
            window_kv_lens = prefix_kv_indptr[1 : bs + 1] - prefix_kv_indptr[:bs]
            # Handle TARGET_VERIFY mode where extend_prefix_lens might not be set
            if forward_batch.extend_prefix_lens is not None:
                window_start_pos = (
                    forward_batch.extend_prefix_lens[:bs] - window_kv_lens
                )
            else:
                # Infer from spec_info: prefix_len = seq_len - draft_token_num
                if forward_batch.spec_info is not None and hasattr(
                    forward_batch.spec_info, "draft_token_num"
                ):
                    extend_prefix_lens = (
                        forward_batch.seq_lens[:bs]
                        - forward_batch.spec_info.draft_token_num
                    )
                    window_start_pos = extend_prefix_lens - window_kv_lens
                else:
                    window_start_pos = None
        else:
            sliding_window_size = -1
            prefix_kv_indptr = self.forward_metadata.kv_indptr
            prefix_kv_indices = self.forward_metadata.kv_indices
            window_start_pos = None

        # Build unified kv_indices using fused Triton kernel
        extend_kv_indices = forward_batch.out_cache_loc

        # Handle cases where extend_seq_lens or extend_start_loc might not be set
        # In speculative decoding, we can infer these from spec_info or compute them
        if forward_batch.extend_seq_lens is None:
            # TARGET_VERIFY mode: infer extend_seq_lens from spec_info
            if forward_batch.spec_info is not None and hasattr(
                forward_batch.spec_info, "draft_token_num"
            ):
                draft_token_num = forward_batch.spec_info.draft_token_num
                extend_seq_lens = torch.full(
                    (bs,), draft_token_num, dtype=torch.int32, device=self.device
                )
            else:
                raise RuntimeError(
                    "extend_seq_lens is None but cannot infer from spec_info. "
                    "This should not happen in TARGET_VERIFY mode."
                )
        else:
            extend_seq_lens = forward_batch.extend_seq_lens

        # Check extend_start_loc separately - it might be None even when extend_seq_lens is set
        if forward_batch.extend_start_loc is None:
            # Compute extend_start_loc from extend_seq_lens
            # extend_start_loc[i] = sum(extend_seq_lens[0:i])
            extend_start_loc = torch.cat(
                [
                    torch.zeros(1, dtype=torch.int32, device=self.device),
                    torch.cumsum(extend_seq_lens[:-1], dim=0),
                ]
            )
        else:
            extend_start_loc = forward_batch.extend_start_loc

        unified_kv_indptr, unified_kv_indices, prefix_lens = (
            self.build_unified_kv_indices(
                prefix_kv_indptr,
                prefix_kv_indices,
                extend_start_loc,
                extend_seq_lens,
                extend_kv_indices,
                bs,
            )
        )

        # Convert prefix_lens to int32 for the kernel
        prefix_lens = prefix_lens.to(torch.int32)

        # Call unified kernel
        self.extend_attention_fwd_unified(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            self.forward_metadata.qo_indptr,
            unified_kv_indptr,
            unified_kv_indices,
            prefix_lens,
            self.forward_metadata.max_extend_len,
            custom_mask=self.forward_metadata.custom_mask,
            mask_indptr=self.forward_metadata.mask_indptr,
            sm_scale=layer.scaling,
            logit_cap=logits_soft_cap,
            is_causal=causal,
            sliding_window_size=sliding_window_size,
            sinks=sinks,
            window_start_pos=window_start_pos,
            xai_temperature_len=layer.xai_temperature_len,
        )

        return o

    def forward_decode(
        self,
        q: torch.Tensor,
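A note on why a unified 1-stage pass helps determinism: floating-point addition is not associative, so the attention output depends on the order in which KV contributions are accumulated. The unified kernel added below accumulates each sequence's prefix and extend KV in one fixed front-to-back pass over kv_indices, so the reduction order does not depend on how the prefix and extend segments would otherwise be split. A toy illustration of that order sensitivity (illustrative values, not part of the patch):

import torch

torch.manual_seed(0)
x = torch.randn(1 << 16, dtype=torch.float32)
full = x.sum()                                   # one fixed reduction order
split = x[: 1 << 15].sum() + x[1 << 15 :].sum()  # same numbers, different grouping
print((full - split).item())                     # usually a tiny difference in the last bits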
@@ -32,12 +32,182 @@ if _is_cuda:
_is_hip = is_hip()


def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):
    """
    Get block sizes and configuration for extend attention kernels.

    Args:
        Lq: Query head dimension
        Lv: Value head dimension

    Returns:
        tuple: (BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps)
    """
    # Determine BLOCK_DMODEL and BLOCK_DPE based on head dimension
    if Lq == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lq == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    elif Lq == 192:
        BLOCK_DMODEL = 128
        BLOCK_DPE = 64
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lq)
        BLOCK_DPE = 0

    BLOCK_DV = triton.next_power_of_2(Lv)

    # Determine BLOCK_M, BLOCK_N, and num_warps based on hardware
    if _is_hip:
        BLOCK_M, BLOCK_N = (64, 64)
        num_warps = 4
    else:
        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
            # Hopper architecture (H100, etc.)
            if Lq <= 256:
                BLOCK_M, BLOCK_N = (128, 64)
            else:
                BLOCK_M, BLOCK_N = (32, 64)
        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
            # Ampere architecture (A100, etc.)
            # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
            if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                if Lq <= 128:
                    BLOCK_M, BLOCK_N = (64, 128)
                elif Lq <= 256:
                    BLOCK_M, BLOCK_N = (64, 64)
                else:
                    BLOCK_M, BLOCK_N = (32, 32)
            else:
                if Lq <= 128:
                    BLOCK_M, BLOCK_N = (128, 128)
                elif Lq <= 256:
                    BLOCK_M, BLOCK_N = (64, 64)
                else:
                    BLOCK_M, BLOCK_N = (32, 64)
        else:
            # Older architectures
            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

        num_warps = 4 if Lq <= 64 else 8

    return BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def _copy_unified_indices_kernel(
    # Input buffers
    prefix_kv_indptr,
    prefix_kv_indices,
    extend_start_loc,
    extend_seq_lens,
    extend_kv_indices,
    unified_kv_indptr,
    # Output buffer
    unified_kv_indices,
    # Size
    bs,
):
    """
    Triton kernel to copy indices to unified buffer (parallel per sequence).
    Each thread block processes one sequence with vectorized loads/stores.
    """
    pid = tl.program_id(0)

    if pid >= bs:
        return

    # Load sequence info
    prefix_start = tl.load(prefix_kv_indptr + pid)
    prefix_end = tl.load(prefix_kv_indptr + pid + 1)
    extend_start = tl.load(extend_start_loc + pid)
    extend_len = tl.load(extend_seq_lens + pid)

    prefix_len = prefix_end - prefix_start
    unified_start = tl.load(unified_kv_indptr + pid)

    # Copy indices in vectorized chunks
    BLOCK_SIZE: tl.constexpr = 128

    # Process prefix indices
    for block_start in range(0, prefix_len, BLOCK_SIZE):
        offs = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offs < prefix_len

        src_idx = prefix_start + offs
        dst_idx = unified_start + offs

        vals = tl.load(prefix_kv_indices + src_idx, mask=mask, other=0)
        tl.store(unified_kv_indices + dst_idx, vals, mask=mask)

    # Process extend indices
    for block_start in range(0, extend_len, BLOCK_SIZE):
        offs = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offs < extend_len

        src_idx = extend_start + offs
        dst_idx = unified_start + prefix_len + offs

        vals = tl.load(extend_kv_indices + src_idx, mask=mask, other=0)
        tl.store(unified_kv_indices + dst_idx, vals, mask=mask)


def build_unified_kv_indices(
    prefix_kv_indptr: torch.Tensor,
    prefix_kv_indices: torch.Tensor,
    extend_start_loc: torch.Tensor,
    extend_seq_lens: torch.Tensor,
    extend_kv_indices: torch.Tensor,
    bs: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build unified KV indices efficiently:
    - Use PyTorch's optimized cumsum (NVIDIA CUB) for indptr
    - Use Triton kernel for parallel index copying

    Returns:
        (unified_kv_indptr, unified_kv_indices, prefix_lens)
    """
    device = prefix_kv_indptr.device

    prefix_lens = prefix_kv_indptr[1 : bs + 1] - prefix_kv_indptr[:bs]

    # Create unified_kv_indptr avoiding direct assignment (for CUDA graph compatibility)
    unified_lens = prefix_lens + extend_seq_lens[:bs]
    unified_kv_indptr = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device=device),
            torch.cumsum(unified_lens, dim=0),
        ]
    )

    max_unified_len = len(prefix_kv_indices) + len(extend_kv_indices)

    unified_kv_indices = torch.empty(max_unified_len, dtype=torch.int64, device=device)

    # Launch Triton kernel for parallel index copying
    _copy_unified_indices_kernel[(bs,)](
        prefix_kv_indptr,
        prefix_kv_indices,
        extend_start_loc,
        extend_seq_lens,
        extend_kv_indices,
        unified_kv_indptr,
        unified_kv_indices,
        bs,
    )

    return unified_kv_indptr, unified_kv_indices, prefix_lens
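To make the unified layout concrete, here is a small pure-Python sketch of what build_unified_kv_indices produces for a toy batch (values are made up for illustration): each sequence's cached prefix slots are followed immediately by its newly allocated extend slots.

prefix_kv_indptr = [0, 3, 5]              # two sequences, prefix lens 3 and 2
prefix_kv_indices = [10, 11, 12, 40, 41]  # KV-pool slots holding the cached prefix
extend_seq_lens = [2, 1]                  # new tokens per sequence
extend_start_loc = [0, 2]
extend_kv_indices = [100, 101, 102]       # out_cache_loc slots of the new tokens

unified = []
for i in range(2):
    unified += prefix_kv_indices[prefix_kv_indptr[i] : prefix_kv_indptr[i + 1]]
    s = extend_start_loc[i]
    unified += extend_kv_indices[s : s + extend_seq_lens[i]]
print(unified)  # [10, 11, 12, 100, 101, 40, 41, 102]
# The real helper returns unified_kv_indptr = [0, 5, 8] and prefix_lens = [3, 2] as tensors.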


@triton.jit
def _fwd_kernel(
    Q_Extend,
@@ -402,50 +572,10 @@ def extend_attention_fwd(
        v_extend.shape[-1],
    )

    if Lq == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lq == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    elif Lq == 192:
        BLOCK_DMODEL = 128
        BLOCK_DPE = 64
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lq)
        BLOCK_DPE = 0
    BLOCK_DV = triton.next_power_of_2(Lv)

    if _is_hip:
        BLOCK_M, BLOCK_N = (64, 64)
        num_warps = 4

    else:
        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
            if Lq <= 256:
                BLOCK_M, BLOCK_N = (128, 64)
            else:
                BLOCK_M, BLOCK_N = (32, 64)
        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
            # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
            if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                if Lq <= 128:
                    BLOCK_M, BLOCK_N = (64, 128)
                elif Lq <= 256:
                    BLOCK_M, BLOCK_N = (64, 64)
                else:
                    BLOCK_M, BLOCK_N = (32, 32)
            else:
                if Lq <= 128:
                    BLOCK_M, BLOCK_N = (128, 128)
                elif Lq <= 256:
                    BLOCK_M, BLOCK_N = (64, 64)
                else:
                    BLOCK_M, BLOCK_N = (32, 64)
        else:
            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

        num_warps = 4 if Lk <= 64 else 8
    # Get block sizes and configuration
    BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (
        _get_block_sizes_for_extend_attention(Lq, Lv)
    )

    sm_scale = sm_scale or 1.0 / (Lq**0.5)
    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
@@ -548,3 +678,368 @@ def redundant_attention(
        pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
        o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]
        pt += cur_seq_len_extend


@triton.jit
def _fwd_kernel_unified(
    Q,
    O,
    K_Buffer,
    V_Buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    prefix_lens,
    mask_ptr,
    mask_indptr,
    sink_ptr,
    window_start_pos,
    sm_scale,
    kv_group_num,
    stride_qbs,
    stride_qh,
    stride_obs,
    stride_oh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    SLIDING_WINDOW_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    xai_temperature_len: tl.constexpr,
    Lq: tl.constexpr,
    Lv: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    USE_CUSTOM_MASK: tl.constexpr,
    HAS_SINK: tl.constexpr,
):
    """
    Unified 1-stage kernel for deterministic extend attention.
    Both prefix and extend KV are accessed through the unified kv_indices.
    """
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)
    cur_block_m = tl.program_id(2)
    cur_kv_head = cur_head // kv_group_num

    # Load sequence information
    cur_seq_q_start_idx = tl.load(qo_indptr + cur_seq)
    cur_seq_q_len = tl.load(qo_indptr + cur_seq + 1) - cur_seq_q_start_idx
    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
    cur_seq_kv_len = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
    cur_seq_prefix_len = tl.load(prefix_lens + cur_seq)

    # Load window start position for sliding window attention
    # This is the absolute position of the first key in the window (0 if no sliding window)
    cur_window_start = 0
    if SLIDING_WINDOW_SIZE > 0:
        cur_window_start = tl.load(window_start_pos + cur_seq)

    # Load custom mask start index if using custom mask (for speculative decoding)
    if USE_CUSTOM_MASK:
        cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    offs_m = tl.arange(0, BLOCK_M)
    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_q_len
    mask_d = offs_d < Lq
    mask_dv = offs_dv < Lv

    # XAI temperature handling
    if xai_temperature_len > 0:
        offs_qidx = cur_seq_prefix_len + cur_block_m * BLOCK_M + offs_m
        xai_temperature_reg = tl.where(
            offs_qidx < xai_temperature_len,
            1.0,
            xai_temperature_len / (offs_qidx + 1.0),
        )

    # Load Q
    offs_q = (
        (cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    q = tl.load(Q + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0)

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        offs_qpe = (
            (cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs
            + cur_head * stride_qh
            + offs_dpe[None, :]
        )
        qpe = tl.load(Q + offs_qpe, mask=mask_m[:, None], other=0.0)

    # Initialize accumulators
    offs_n = tl.arange(0, BLOCK_N)
    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

    # Unified loop: process all KV tokens (prefix + extend)
    for start_n in range(0, cur_seq_kv_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_seq_kv_len

        # Compute mask
        final_mask = mask_m[:, None] & mask_n[None, :]

        # Apply custom mask if provided
        if USE_CUSTOM_MASK:
            custom_mask = tl.load(
                mask_ptr
                + cur_seq_mask_start_idx
                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_kv_len
                + start_n
                + offs_n[None, :],
                mask=(mask_m[:, None] & mask_n[None, :]),
                other=0,
            )
            final_mask &= custom_mask

        # Apply causal mask for extend part
        if IS_CAUSAL and not USE_CUSTOM_MASK:
            # Determine if current KV block is in extend region
            # Only apply causal mask when both Q and K are in extend region
            q_idx = cur_block_m * BLOCK_M + offs_m[:, None]
            k_idx_in_total = start_n + offs_n[None, :]

            # Causal mask: q_idx >= (k_idx - prefix_len) when k_idx >= prefix_len
            # For prefix region (k_idx < prefix_len), no causal mask
            k_is_extend = k_idx_in_total >= cur_seq_prefix_len
            k_idx_in_extend = k_idx_in_total - cur_seq_prefix_len
            causal_mask = tl.where(
                k_is_extend,
                q_idx >= k_idx_in_extend,
                True,  # No causal mask for prefix
            )
            final_mask &= causal_mask

        if SLIDING_WINDOW_SIZE > 0:
            # Sliding window mask with correct absolute positions
            # Q absolute position: window_start + prefix_len + q_position_in_extend
            q_abs_pos = (
                cur_window_start
                + cur_seq_prefix_len
                + cur_block_m * BLOCK_M
                + offs_m[:, None]
            )

            # K absolute position: window_start + k_index_in_unified_array
            k_abs_pos = cur_window_start + start_n + offs_n[None, :]

            # Sliding window: query can attend to keys within window_size
            window_mask = q_abs_pos <= (k_abs_pos + SLIDING_WINDOW_SIZE)
            final_mask &= window_mask

        # Check if we can skip this tile
        SKIP_TILE = False
        if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:
            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0

        if not SKIP_TILE:
            # Load KV indices
            offs_kv_loc = tl.load(
                kv_indices + cur_seq_kv_start_idx + start_n + offs_n,
                mask=mask_n,
                other=0,
            )

            # Load K
            offs_buf_k = (
                offs_kv_loc[None, :] * stride_buf_kbs
                + cur_kv_head * stride_buf_kh
                + offs_d[:, None]
            )
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(mask_n[None, :]) & (mask_d[:, None]),
                other=0.0,
            )

            # Compute QK
            qk = tl.dot(q.to(k.dtype), k)
            if BLOCK_DPE > 0:
                offs_kpe = (
                    offs_kv_loc[None, :] * stride_buf_kbs
                    + cur_kv_head * stride_buf_kh
                    + offs_dpe[:, None]
                )
                kpe = tl.load(
                    K_Buffer + offs_kpe,
                    mask=mask_n[None, :],
                    other=0.0,
                )
                qk += tl.dot(qpe.to(kpe.dtype), kpe)

            qk *= sm_scale

            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            if xai_temperature_len > 0:
                qk *= xai_temperature_reg[:, None]

            qk = tl.where(final_mask, qk, float("-inf"))

            # Online softmax
            row_max = tl.max(qk, 1)
            row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
            n_e_max = tl.maximum(row_max_fixed, e_max)

            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            deno = deno * re_scale + tl.sum(p, 1)

            # Load V
            offs_buf_v = (
                offs_kv_loc[:, None] * stride_buf_vbs
                + cur_kv_head * stride_buf_vh
                + offs_dv[None, :]
            )
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=mask_n[:, None] & mask_dv[None, :],
                other=0.0,
            )
            p = p.to(v.dtype)
            acc = acc * re_scale[:, None] + tl.dot(p, v)

            e_max = n_e_max

    # Handle sink tokens
    if HAS_SINK:
        cur_sink = tl.load(sink_ptr + cur_head)
        deno += tl.exp(cur_sink - e_max)

    # Store output
    offs_o = (
        (cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_dv[None, :]
    )
    tl.store(
        O + offs_o,
        acc / deno[:, None],
        mask=mask_m[:, None] & mask_dv[None, :],
    )
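For readers unfamiliar with the online-softmax update used above (e_max, re_scale, deno, acc), a small PyTorch sketch with toy values checks that streaming over key blocks in a fixed order reproduces the one-shot softmax-weighted sum:

import torch

torch.manual_seed(0)
scores = torch.randn(12)      # one query row against 12 keys
values = torch.randn(12, 4)   # matching toy V rows
ref = torch.softmax(scores, dim=0) @ values   # one-shot reference

# Streaming update over blocks of 4 keys, mirroring the kernel's accumulators.
e_max = torch.tensor(float("-inf"))
deno = torch.tensor(0.0)
acc = torch.zeros(4)
for start in range(0, 12, 4):
    qk = scores[start : start + 4]
    v = values[start : start + 4]
    n_e_max = torch.maximum(qk.max(), e_max)
    re_scale = torch.exp(e_max - n_e_max)
    p = torch.exp(qk - n_e_max)
    deno = deno * re_scale + p.sum()
    acc = acc * re_scale + p @ v
    e_max = n_e_max

print(torch.allclose(acc / deno, ref))  # True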


def extend_attention_fwd_unified(
    q,
    o,
    k_buffer,
    v_buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    prefix_lens,
    max_len_extend,
    custom_mask=None,
    mask_indptr=None,
    sm_scale=None,
    logit_cap=0.0,
    is_causal=True,
    sliding_window_size=-1,
    sinks=None,
    window_start_pos=None,
    xai_temperature_len=-1,
):
    """
    Unified 1-stage extend attention for deterministic inference.

    Args:
        q: Query tensor [num_tokens, num_heads, head_dim]
        o: Output tensor [num_tokens, num_heads, head_dim]
        k_buffer: Key cache buffer
        v_buffer: Value cache buffer
        qo_indptr: Query offsets [batch_size + 1]
        kv_indptr: KV offsets [batch_size + 1] (includes both prefix and extend)
        kv_indices: Unified KV indices (both prefix and extend)
        prefix_lens: Prefix length for each sequence [batch_size]
        max_len_extend: Maximum extend length
        custom_mask: Custom attention mask (for speculative decoding tree attention)
        mask_indptr: Mask offsets [batch_size + 1]
        sm_scale: Softmax scale
        logit_cap: Logit capping value
        is_causal: Whether to apply causal mask
        sliding_window_size: Sliding window size (-1 for no sliding window)
        sinks: Sink tokens
        window_start_pos: Absolute position of first key in sliding window [batch_size]
            (None if sliding window not used)
        xai_temperature_len: XAI temperature length
    """
    Lq, Lv = q.shape[-1], v_buffer.shape[-1]

    # Get block sizes and configuration
    BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (
        _get_block_sizes_for_extend_attention(Lq, Lv)
    )

    sm_scale = sm_scale or 1.0 / (Lq**0.5)
    batch_size, head_num = qo_indptr.shape[0] - 1, q.shape[1]
    kv_group_num = q.shape[1] // k_buffer.shape[1]

    USE_CUSTOM_MASK = custom_mask is not None
    HAS_SINK = sinks is not None

    # For sliding window attention, window_start_pos tracks the absolute position
    # of the first key in each sequence's window
    if sliding_window_size > 0 and window_start_pos is None:
        # If not provided, assume window starts at position 0
        window_start_pos = torch.zeros(batch_size, dtype=torch.int32, device=q.device)

    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
    num_stages = 1

    extra_kargs = {}
    if _is_hip:
        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}

    _fwd_kernel_unified[grid](
        q,
        o,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        prefix_lens,
        custom_mask,
        mask_indptr,
        sinks,
        window_start_pos,
        sm_scale,
        kv_group_num,
        q.stride(0),
        q.stride(1),
        o.stride(0),
        o.stride(1),
        k_buffer.stride(0),
        k_buffer.stride(1),
        v_buffer.stride(0),
        v_buffer.stride(1),
        SLIDING_WINDOW_SIZE=sliding_window_size,
        logit_cap=logit_cap,
        xai_temperature_len=xai_temperature_len,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        Lq=Lq,
        Lv=Lv,
        IS_CAUSAL=is_causal,
        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
        HAS_SINK=HAS_SINK,
        num_warps=num_warps,
        num_stages=num_stages,
        **extra_kargs,
    )
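A note on the custom_mask / mask_indptr arguments: the kernel reads the mask as one flattened row-major [q_len, kv_len] block per sequence, offset by mask_indptr, which is why it loads mask_ptr + mask_start + q_idx * kv_len + k_idx. A toy sketch of that layout (sizes and mask contents are assumptions for illustration):

import torch

# One sequence: 2 cached prefix tokens plus 3 new extend tokens (kv_len = 5, q_len = 3).
q_len, kv_len, prefix_len = 3, 5, 2
mask = torch.zeros(q_len, kv_len, dtype=torch.bool)
mask[:, :prefix_len] = True                                    # every query may attend to the prefix
mask[:, prefix_len:] = torch.tril(torch.ones(q_len, q_len)).bool()  # causal over the extend part

custom_mask = mask.flatten()                  # what mask_ptr points at for this sequence
mask_indptr = torch.tensor([0, q_len * kv_len])
# The kernel reads custom_mask[mask_indptr[i] + q_idx * kv_len + k_idx].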

@@ -1431,8 +1431,8 @@ class ServerArgs:
                    f"but you explicitly specified '{self.attention_backend}'."
                )

            # Currently, only FA3 supports radix cache. Support for other backends is in progress
            if self.attention_backend != "fa3":
            # Currently, only FA3 and Triton support radix cache. Support for other backends is in progress
            if self.attention_backend not in ["fa3", "triton"]:
                self.disable_radix_cache = True
                logger.warning(
                    f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future."
@@ -424,4 +424,7 @@ if __name__ == "__main__":
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()

    if args.sampling_seed is None:
        args.sampling_seed = 42

    test_deterministic(args)