diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 731758286..b67acbc3f 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -64,13 +64,19 @@ class TritonAttnBackend(AttentionBackend): decode_attention_fwd, ) from sglang.srt.layers.attention.triton_ops.extend_attention import ( + build_unified_kv_indices, extend_attention_fwd, + extend_attention_fwd_unified, ) super().__init__() self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd) self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd) + self.extend_attention_fwd_unified = torch.compiler.disable( + extend_attention_fwd_unified + ) + self.build_unified_kv_indices = torch.compiler.disable(build_unified_kv_indices) # Parse args self.skip_prefill = skip_prefill @@ -794,6 +800,7 @@ class TritonAttnBackend(AttentionBackend): else: o = torch.empty_like(q) + # Save KV cache first (must do this before unified kernel) if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( layer, forward_batch.out_cache_loc, k, v @@ -805,6 +812,13 @@ class TritonAttnBackend(AttentionBackend): if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY: causal = False + # Deterministic mode: use unified 1-stage kernel + if self.enable_deterministic: + return self._forward_extend_unified( + q, o, layer, forward_batch, causal, logits_soft_cap, sinks + ) + + # Normal mode: use original 2-stage kernel if layer.sliding_window_size is not None and layer.sliding_window_size > -1: sliding_window_size = ( layer.sliding_window_size @@ -841,6 +855,127 @@ class TritonAttnBackend(AttentionBackend): ) return o + def _forward_extend_unified( + self, + q: torch.Tensor, + o: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + causal: bool, + logits_soft_cap: float, + sinks: Optional[torch.Tensor], + ): + """ + Unified 1-stage extend attention for deterministic inference. + Both prefix and extend KV are accessed through unified kv_indices. 
+ """ + bs = forward_batch.batch_size + + # Determine sliding window settings + if layer.sliding_window_size is not None and layer.sliding_window_size > -1: + sliding_window_size = layer.sliding_window_size + # Note: for unified kernel, we use full kv_indptr (not window) + prefix_kv_indptr = self.forward_metadata.window_kv_indptr + prefix_kv_indices = self.forward_metadata.window_kv_indices + # Compute window start positions (absolute position of first key in window) + # window_start_pos = seq_len - window_len + window_kv_lens = prefix_kv_indptr[1 : bs + 1] - prefix_kv_indptr[:bs] + # Handle TARGET_VERIFY mode where extend_prefix_lens might not be set + if forward_batch.extend_prefix_lens is not None: + window_start_pos = ( + forward_batch.extend_prefix_lens[:bs] - window_kv_lens + ) + else: + # Infer from spec_info: prefix_len = seq_len - draft_token_num + if forward_batch.spec_info is not None and hasattr( + forward_batch.spec_info, "draft_token_num" + ): + extend_prefix_lens = ( + forward_batch.seq_lens[:bs] + - forward_batch.spec_info.draft_token_num + ) + window_start_pos = extend_prefix_lens - window_kv_lens + else: + window_start_pos = None + else: + sliding_window_size = -1 + prefix_kv_indptr = self.forward_metadata.kv_indptr + prefix_kv_indices = self.forward_metadata.kv_indices + window_start_pos = None + + # Build unified kv_indices using fused Triton kernel + extend_kv_indices = forward_batch.out_cache_loc + + # Handle cases where extend_seq_lens or extend_start_loc might not be set + # In speculative decoding, we can infer these from spec_info or compute them + if forward_batch.extend_seq_lens is None: + # TARGET_VERIFY mode: infer extend_seq_lens from spec_info + if forward_batch.spec_info is not None and hasattr( + forward_batch.spec_info, "draft_token_num" + ): + draft_token_num = forward_batch.spec_info.draft_token_num + extend_seq_lens = torch.full( + (bs,), draft_token_num, dtype=torch.int32, device=self.device + ) + else: + raise RuntimeError( + "extend_seq_lens is None but cannot infer from spec_info. " + "This should not happen in TARGET_VERIFY mode." 
+ ) + else: + extend_seq_lens = forward_batch.extend_seq_lens + + # Check extend_start_loc separately - it might be None even when extend_seq_lens is set + if forward_batch.extend_start_loc is None: + # Compute extend_start_loc from extend_seq_lens + # extend_start_loc[i] = sum(extend_seq_lens[0:i]) + extend_start_loc = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=self.device), + torch.cumsum(extend_seq_lens[:-1], dim=0), + ] + ) + else: + extend_start_loc = forward_batch.extend_start_loc + + unified_kv_indptr, unified_kv_indices, prefix_lens = ( + self.build_unified_kv_indices( + prefix_kv_indptr, + prefix_kv_indices, + extend_start_loc, + extend_seq_lens, + extend_kv_indices, + bs, + ) + ) + + # Convert prefix_lens to int32 for the kernel + prefix_lens = prefix_lens.to(torch.int32) + + # Call unified kernel + self.extend_attention_fwd_unified( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), + forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), + self.forward_metadata.qo_indptr, + unified_kv_indptr, + unified_kv_indices, + prefix_lens, + self.forward_metadata.max_extend_len, + custom_mask=self.forward_metadata.custom_mask, + mask_indptr=self.forward_metadata.mask_indptr, + sm_scale=layer.scaling, + logit_cap=logits_soft_cap, + is_causal=causal, + sliding_window_size=sliding_window_size, + sinks=sinks, + window_start_pos=window_start_pos, + xai_temperature_len=layer.xai_temperature_len, + ) + + return o + def forward_decode( self, q: torch.Tensor, diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index e91467743..62132a340 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -32,12 +32,182 @@ if _is_cuda: _is_hip = is_hip() +def _get_block_sizes_for_extend_attention(Lq: int, Lv: int): + """ + Get block sizes and configuration for extend attention kernels. + + Args: + Lq: Query head dimension + Lv: Value head dimension + + Returns: + tuple: (BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps) + """ + # Determine BLOCK_DMODEL and BLOCK_DPE based on head dimension + if Lq == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lq == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + elif Lq == 192: + BLOCK_DMODEL = 128 + BLOCK_DPE = 64 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lq) + BLOCK_DPE = 0 + + BLOCK_DV = triton.next_power_of_2(Lv) + + # Determine BLOCK_M, BLOCK_N, and num_warps based on hardware + if _is_hip: + BLOCK_M, BLOCK_N = (64, 64) + num_warps = 4 + else: + if _is_cuda and CUDA_CAPABILITY[0] >= 9: + # Hopper architecture (H100, etc.) + if Lq <= 256: + BLOCK_M, BLOCK_N = (128, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) + elif _is_cuda and CUDA_CAPABILITY[0] >= 8: + # Ampere architecture (A100, etc.) 
+ # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K) + if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6: + if Lq <= 128: + BLOCK_M, BLOCK_N = (64, 128) + elif Lq <= 256: + BLOCK_M, BLOCK_N = (64, 64) + else: + BLOCK_M, BLOCK_N = (32, 32) + else: + if Lq <= 128: + BLOCK_M, BLOCK_N = (128, 128) + elif Lq <= 256: + BLOCK_M, BLOCK_N = (64, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) + else: + # Older architectures + BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) + + num_warps = 4 if Lq <= 64 else 8 + + return BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps + + @triton.jit def tanh(x): # Tanh is just a scaled sigmoid return 2 * tl.sigmoid(2 * x) - 1 +@triton.jit +def _copy_unified_indices_kernel( + # Input buffers + prefix_kv_indptr, + prefix_kv_indices, + extend_start_loc, + extend_seq_lens, + extend_kv_indices, + unified_kv_indptr, + # Output buffer + unified_kv_indices, + # Size + bs, +): + """ + Triton kernel to copy indices to unified buffer (parallel per sequence). + Each thread block processes one sequence with vectorized loads/stores. + """ + pid = tl.program_id(0) + + if pid >= bs: + return + + # Load sequence info + prefix_start = tl.load(prefix_kv_indptr + pid) + prefix_end = tl.load(prefix_kv_indptr + pid + 1) + extend_start = tl.load(extend_start_loc + pid) + extend_len = tl.load(extend_seq_lens + pid) + + prefix_len = prefix_end - prefix_start + unified_start = tl.load(unified_kv_indptr + pid) + + # Copy indices in vectorized chunks + BLOCK_SIZE: tl.constexpr = 128 + + # Process prefix indices + for block_start in range(0, prefix_len, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < prefix_len + + src_idx = prefix_start + offs + dst_idx = unified_start + offs + + vals = tl.load(prefix_kv_indices + src_idx, mask=mask, other=0) + tl.store(unified_kv_indices + dst_idx, vals, mask=mask) + + # Process extend indices + for block_start in range(0, extend_len, BLOCK_SIZE): + offs = block_start + tl.arange(0, BLOCK_SIZE) + mask = offs < extend_len + + src_idx = extend_start + offs + dst_idx = unified_start + prefix_len + offs + + vals = tl.load(extend_kv_indices + src_idx, mask=mask, other=0) + tl.store(unified_kv_indices + dst_idx, vals, mask=mask) + + +def build_unified_kv_indices( + prefix_kv_indptr: torch.Tensor, + prefix_kv_indices: torch.Tensor, + extend_start_loc: torch.Tensor, + extend_seq_lens: torch.Tensor, + extend_kv_indices: torch.Tensor, + bs: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Build unified KV indices efficiently: + - Use PyTorch's optimized cumsum (NVIDIA CUB) for indptr + - Use Triton kernel for parallel index copying + + Returns: + (unified_kv_indptr, unified_kv_indices, prefix_lens) + """ + device = prefix_kv_indptr.device + + prefix_lens = prefix_kv_indptr[1 : bs + 1] - prefix_kv_indptr[:bs] + + # Create unified_kv_indptr avoiding direct assignment (for CUDA graph compatibility) + unified_lens = prefix_lens + extend_seq_lens[:bs] + unified_kv_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(unified_lens, dim=0), + ] + ) + + max_unified_len = len(prefix_kv_indices) + len(extend_kv_indices) + + unified_kv_indices = torch.empty(max_unified_len, dtype=torch.int64, device=device) + + # Launch Triton kernel for parallel index copying + _copy_unified_indices_kernel[(bs,)]( + prefix_kv_indptr, + prefix_kv_indices, + extend_start_loc, + extend_seq_lens, + extend_kv_indices, + unified_kv_indptr, + unified_kv_indices, + bs, + ) + + 
return unified_kv_indptr, unified_kv_indices, prefix_lens + + @triton.jit def _fwd_kernel( Q_Extend, @@ -402,50 +572,10 @@ def extend_attention_fwd( v_extend.shape[-1], ) - if Lq == 576: - BLOCK_DMODEL = 512 - BLOCK_DPE = 64 - elif Lq == 288: - BLOCK_DMODEL = 256 - BLOCK_DPE = 32 - elif Lq == 192: - BLOCK_DMODEL = 128 - BLOCK_DPE = 64 - else: - BLOCK_DMODEL = triton.next_power_of_2(Lq) - BLOCK_DPE = 0 - BLOCK_DV = triton.next_power_of_2(Lv) - - if _is_hip: - BLOCK_M, BLOCK_N = (64, 64) - num_warps = 4 - - else: - if _is_cuda and CUDA_CAPABILITY[0] >= 9: - if Lq <= 256: - BLOCK_M, BLOCK_N = (128, 64) - else: - BLOCK_M, BLOCK_N = (32, 64) - elif _is_cuda and CUDA_CAPABILITY[0] >= 8: - # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K) - if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6: - if Lq <= 128: - BLOCK_M, BLOCK_N = (64, 128) - elif Lq <= 256: - BLOCK_M, BLOCK_N = (64, 64) - else: - BLOCK_M, BLOCK_N = (32, 32) - else: - if Lq <= 128: - BLOCK_M, BLOCK_N = (128, 128) - elif Lq <= 256: - BLOCK_M, BLOCK_N = (64, 64) - else: - BLOCK_M, BLOCK_N = (32, 64) - else: - BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) - - num_warps = 4 if Lk <= 64 else 8 + # Get block sizes and configuration + BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = ( + _get_block_sizes_for_extend_attention(Lq, Lv) + ) sm_scale = sm_scale or 1.0 / (Lq**0.5) batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1] @@ -548,3 +678,368 @@ def redundant_attention( pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i] o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr] pt += cur_seq_len_extend + + +@triton.jit +def _fwd_kernel_unified( + Q, + O, + K_Buffer, + V_Buffer, + qo_indptr, + kv_indptr, + kv_indices, + prefix_lens, + mask_ptr, + mask_indptr, + sink_ptr, + window_start_pos, + sm_scale, + kv_group_num, + stride_qbs, + stride_qh, + stride_obs, + stride_oh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + SLIDING_WINDOW_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + xai_temperature_len: tl.constexpr, + Lq: tl.constexpr, + Lv: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + USE_CUSTOM_MASK: tl.constexpr, + HAS_SINK: tl.constexpr, +): + """ + Unified 1-stage kernel for deterministic extend attention. + Both prefix and extend KV are accessed through the unified kv_indices. 
+ """ + cur_seq = tl.program_id(0) + cur_head = tl.program_id(1) + cur_block_m = tl.program_id(2) + cur_kv_head = cur_head // kv_group_num + + # Load sequence information + cur_seq_q_start_idx = tl.load(qo_indptr + cur_seq) + cur_seq_q_len = tl.load(qo_indptr + cur_seq + 1) - cur_seq_q_start_idx + cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq) + cur_seq_kv_len = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx + cur_seq_prefix_len = tl.load(prefix_lens + cur_seq) + + # Load window start position for sliding window attention + # This is the absolute position of the first key in the window (0 if no sliding window) + cur_window_start = 0 + if SLIDING_WINDOW_SIZE > 0: + cur_window_start = tl.load(window_start_pos + cur_seq) + + # Load custom mask start index if using custom mask (for speculative decoding) + if USE_CUSTOM_MASK: + cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + offs_m = tl.arange(0, BLOCK_M) + mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_q_len + mask_d = offs_d < Lq + mask_dv = offs_dv < Lv + + # XAI temperature handling + if xai_temperature_len > 0: + offs_qidx = cur_seq_prefix_len + cur_block_m * BLOCK_M + offs_m + xai_temperature_reg = tl.where( + offs_qidx < xai_temperature_len, + 1.0, + xai_temperature_len / (offs_qidx + 1.0), + ) + + # Load Q + offs_q = ( + (cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] + ) + q = tl.load(Q + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + offs_qpe = ( + (cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_dpe[None, :] + ) + qpe = tl.load(Q + offs_qpe, mask=mask_m[:, None], other=0.0) + + # Initialize accumulators + offs_n = tl.arange(0, BLOCK_N) + acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) + deno = tl.zeros([BLOCK_M], dtype=tl.float32) + e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + + # Unified loop: process all KV tokens (prefix + extend) + for start_n in range(0, cur_seq_kv_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + mask_n = (start_n + offs_n) < cur_seq_kv_len + + # Compute mask + final_mask = mask_m[:, None] & mask_n[None, :] + + # Apply custom mask if provided + if USE_CUSTOM_MASK: + custom_mask = tl.load( + mask_ptr + + cur_seq_mask_start_idx + + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_kv_len + + start_n + + offs_n[None, :], + mask=(mask_m[:, None] & mask_n[None, :]), + other=0, + ) + final_mask &= custom_mask + + # Apply causal mask for extend part + if IS_CAUSAL and not USE_CUSTOM_MASK: + # Determine if current KV block is in extend region + # Only apply causal mask when both Q and K are in extend region + q_idx = cur_block_m * BLOCK_M + offs_m[:, None] + k_idx_in_total = start_n + offs_n[None, :] + + # Causal mask: q_idx >= (k_idx - prefix_len) when k_idx >= prefix_len + # For prefix region (k_idx < prefix_len), no causal mask + k_is_extend = k_idx_in_total >= cur_seq_prefix_len + k_idx_in_extend = k_idx_in_total - cur_seq_prefix_len + causal_mask = tl.where( + k_is_extend, + q_idx >= k_idx_in_extend, + True, # No causal mask for prefix + ) + final_mask &= causal_mask + + if SLIDING_WINDOW_SIZE > 0: + # Sliding window mask with correct absolute positions + # Q absolute position: window_start + prefix_len + q_position_in_extend + q_abs_pos = ( + 
cur_window_start + + cur_seq_prefix_len + + cur_block_m * BLOCK_M + + offs_m[:, None] + ) + + # K absolute position: window_start + k_index_in_unified_array + k_abs_pos = cur_window_start + start_n + offs_n[None, :] + + # Sliding window: query can attend to keys within window_size + window_mask = q_abs_pos <= (k_abs_pos + SLIDING_WINDOW_SIZE) + final_mask &= window_mask + + # Check if we can skip this tile + SKIP_TILE = False + if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0: + SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0 + + if not SKIP_TILE: + # Load KV indices + offs_kv_loc = tl.load( + kv_indices + cur_seq_kv_start_idx + start_n + offs_n, + mask=mask_n, + other=0, + ) + + # Load K + offs_buf_k = ( + offs_kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[:, None] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(mask_n[None, :]) & (mask_d[:, None]), + other=0.0, + ) + + # Compute QK + qk = tl.dot(q.to(k.dtype), k) + if BLOCK_DPE > 0: + offs_kpe = ( + offs_kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_kpe, + mask=mask_n[None, :], + other=0.0, + ) + qk += tl.dot(qpe.to(kpe.dtype), kpe) + + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + if xai_temperature_len > 0: + qk *= xai_temperature_reg[:, None] + + qk = tl.where(final_mask, qk, float("-inf")) + + # Online softmax + row_max = tl.max(qk, 1) + row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max) + n_e_max = tl.maximum(row_max_fixed, e_max) + + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + deno = deno * re_scale + tl.sum(p, 1) + + # Load V + offs_buf_v = ( + offs_kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], + other=0.0, + ) + p = p.to(v.dtype) + acc = acc * re_scale[:, None] + tl.dot(p, v) + + e_max = n_e_max + + # Handle sink tokens + if HAS_SINK: + cur_sink = tl.load(sink_ptr + cur_head) + deno += tl.exp(cur_sink - e_max) + + # Store output + offs_o = ( + (cur_seq_q_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_dv[None, :] + ) + tl.store( + O + offs_o, + acc / deno[:, None], + mask=mask_m[:, None] & mask_dv[None, :], + ) + + +def extend_attention_fwd_unified( + q, + o, + k_buffer, + v_buffer, + qo_indptr, + kv_indptr, + kv_indices, + prefix_lens, + max_len_extend, + custom_mask=None, + mask_indptr=None, + sm_scale=None, + logit_cap=0.0, + is_causal=True, + sliding_window_size=-1, + sinks=None, + window_start_pos=None, + xai_temperature_len=-1, +): + """ + Unified 1-stage extend attention for deterministic inference. 
+ + Args: + q: Query tensor [num_tokens, num_heads, head_dim] + o: Output tensor [num_tokens, num_heads, head_dim] + k_buffer: Key cache buffer + v_buffer: Value cache buffer + qo_indptr: Query offsets [batch_size + 1] + kv_indptr: KV offsets [batch_size + 1] (includes both prefix and extend) + kv_indices: Unified KV indices (both prefix and extend) + prefix_lens: Prefix length for each sequence [batch_size] + max_len_extend: Maximum extend length + custom_mask: Custom attention mask (for speculative decoding tree attention) + mask_indptr: Mask offsets [batch_size + 1] + sm_scale: Softmax scale + logit_cap: Logit capping value + is_causal: Whether to apply causal mask + sliding_window_size: Sliding window size (-1 for no sliding window) + sinks: Sink tokens + window_start_pos: Absolute position of first key in sliding window [batch_size] + (None if sliding window not used) + xai_temperature_len: XAI temperature length + """ + Lq, Lv = q.shape[-1], v_buffer.shape[-1] + + # Get block sizes and configuration + BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = ( + _get_block_sizes_for_extend_attention(Lq, Lv) + ) + + sm_scale = sm_scale or 1.0 / (Lq**0.5) + batch_size, head_num = qo_indptr.shape[0] - 1, q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[1] + + USE_CUSTOM_MASK = custom_mask is not None + HAS_SINK = sinks is not None + + # For sliding window attention, window_start_pos tracks the absolute position + # of the first key in each sequence's window + if sliding_window_size > 0 and window_start_pos is None: + # If not provided, assume window starts at position 0 + window_start_pos = torch.zeros(batch_size, dtype=torch.int32, device=q.device) + + grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) + num_stages = 1 + + extra_kargs = {} + if _is_hip: + extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} + + _fwd_kernel_unified[grid]( + q, + o, + k_buffer, + v_buffer, + qo_indptr, + kv_indptr, + kv_indices, + prefix_lens, + custom_mask, + mask_indptr, + sinks, + window_start_pos, + sm_scale, + kv_group_num, + q.stride(0), + q.stride(1), + o.stride(0), + o.stride(1), + k_buffer.stride(0), + k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), + SLIDING_WINDOW_SIZE=sliding_window_size, + logit_cap=logit_cap, + xai_temperature_len=xai_temperature_len, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + Lq=Lq, + Lv=Lv, + IS_CAUSAL=is_causal, + USE_CUSTOM_MASK=USE_CUSTOM_MASK, + HAS_SINK=HAS_SINK, + num_warps=num_warps, + num_stages=num_stages, + **extra_kargs, + ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3ffb7935f..a05d67b7d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1431,8 +1431,8 @@ class ServerArgs: f"but you explicitly specified '{self.attention_backend}'." ) - # Currently, only FA3 supports radix cache. Support for other backends is in progress - if self.attention_backend != "fa3": + # Currently, only FA3 and Triton supports radix cache. Support for other backends is in progress + if self.attention_backend not in ["fa3", "triton"]: self.disable_radix_cache = True logger.warning( f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future." 
diff --git a/python/sglang/test/test_deterministic.py b/python/sglang/test/test_deterministic.py index 58186a3b6..e33a1ed47 100644 --- a/python/sglang/test/test_deterministic.py +++ b/python/sglang/test/test_deterministic.py @@ -424,4 +424,7 @@ if __name__ == "__main__": BenchArgs.add_cli_args(parser) args = parser.parse_args() + if args.sampling_seed is None: + args.sampling_seed = 42 + test_deterministic(args) diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index b15684f9a..16c107006 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -10,7 +10,9 @@ from sglang.srt.layers.attention.triton_ops.decode_attention import ( decode_attention_fwd_normal, ) from sglang.srt.layers.attention.triton_ops.extend_attention import ( + build_unified_kv_indices, extend_attention_fwd, + extend_attention_fwd_unified, redundant_attention, ) from sglang.srt.layers.attention.triton_ops.prefill_attention import ( @@ -571,6 +573,204 @@ class TestTritonAttention(CustomTestCase): for B, H_Q, H_KV, D, D_V in configs: self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V) + def _test_extend_attention_unified_vs_regular_once(self, B, N_CTX, H_Q, H_KV, D): + """Test that unified kernel produces same results as 2-stage kernel.""" + dtype = torch.bfloat16 + + b_seq_len_prefix = torch.randint( + 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" + ) + b_seq_len_extend = torch.randint( + 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" + ) + b_seq_len = b_seq_len_prefix + b_seq_len_extend + + b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) + b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + + # Setup prefix KV indices + kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0) + kv_indices = torch.zeros( + (b_seq_len_prefix.sum().item(),), dtype=torch.int64, device="cuda" + ) + + for i in range(B): + kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange( + b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i] + ) + + total_token_num = torch.sum(b_seq_len).item() + extend_token_num = torch.sum(b_seq_len_extend).item() + k_buffer = torch.empty( + (total_token_num, H_KV, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + v_buffer = torch.empty( + (total_token_num, H_KV, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + + k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") + v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") + q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + + for i in range(B): + extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] + extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] + extend_start = b_start_loc_extend[i] + extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] + k_extend[extend_start:extend_end] = k_buffer[ + extend_start_in_buffer:extend_end_in_buffer + ] + v_extend[extend_start:extend_end] = v_buffer[ + extend_start_in_buffer:extend_end_in_buffer + ] + q_extend[extend_start:extend_end] = torch.empty( + (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + + # Setup for extend attention + max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() + qo_indptr = 
torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) + + # Run 2-stage kernel + o_regular = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + extend_attention_fwd( + q_extend, + k_extend, + v_extend, + o_regular, + k_buffer, + v_buffer, + qo_indptr, + kv_indptr, + kv_indices, + custom_mask=None, + is_causal=True, + mask_indptr=None, + max_len_extend=max_len_extend, + ) + + # Build unified KV indices + extend_kv_indices = torch.arange( + total_token_num - extend_token_num, + total_token_num, + dtype=torch.int64, + device="cuda", + ) + extend_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") + extend_start_loc[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + + unified_kv_indptr, unified_kv_indices, prefix_lens = build_unified_kv_indices( + kv_indptr, + kv_indices, + extend_start_loc, + b_seq_len_extend, + extend_kv_indices, + B, + ) + + # Run unified kernel + o_unified = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + extend_attention_fwd_unified( + q_extend, + o_unified, + k_buffer, + v_buffer, + qo_indptr, + unified_kv_indptr, + unified_kv_indices, + prefix_lens, + max_len_extend=max_len_extend, + custom_mask=None, + mask_indptr=None, + sm_scale=None, + logit_cap=0.0, + is_causal=True, + ) + + # Compare results + self.assertTrue( + torch.allclose(o_regular, o_unified, rtol=0.15, atol=0.15), + f"Unified kernel output differs from 2-stage kernel. " + f"Max diff: {(o_regular - o_unified).abs().max()}", + ) + + def test_extend_attention_unified_vs_regular(self): + """Test unified kernel matches 2-stage kernel across different configs.""" + configs = [ + (4, 512, 32, 8, 128), # Standard config + (2, 2048, 32, 8, 128), # Long sequence (test 2048 specifically) + (8, 256, 64, 8, 80), # Non-standard head dim + ] + + for B, N_CTX, H_Q, H_KV, D in configs: + with self.subTest(B=B, N_CTX=N_CTX, H_Q=H_Q, H_KV=H_KV, D=D): + self._test_extend_attention_unified_vs_regular_once( + B, N_CTX, H_Q, H_KV, D + ) + + def test_build_unified_kv_indices(self): + """Test build_unified_kv_indices correctness.""" + B = 4 + dtype = torch.int64 + device = "cuda" + + # Setup test data + prefix_lens = torch.tensor([10, 20, 15, 25], dtype=torch.int32, device=device) + extend_lens = torch.tensor([5, 3, 7, 4], dtype=torch.int32, device=device) + + # Build prefix indices + prefix_kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device) + prefix_kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0) + prefix_kv_indices = torch.arange( + prefix_lens.sum().item(), dtype=dtype, device=device + ) + + # Build extend indices + extend_start_loc = torch.zeros((B,), dtype=torch.int32, device=device) + extend_start_loc[1:] = torch.cumsum(extend_lens[:-1], dim=0) + extend_kv_indices = torch.arange( + prefix_lens.sum().item(), + prefix_lens.sum().item() + extend_lens.sum().item(), + dtype=dtype, + device=device, + ) + + # Build unified indices + unified_kv_indptr, unified_kv_indices, returned_prefix_lens = ( + build_unified_kv_indices( + prefix_kv_indptr, + prefix_kv_indices, + extend_start_loc, + extend_lens, + extend_kv_indices, + B, + ) + ) + + # Verify unified_kv_indptr + expected_lens = prefix_lens + extend_lens + expected_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device) + expected_indptr[1:] = torch.cumsum(expected_lens, dim=0) + self.assertTrue(torch.equal(unified_kv_indptr, expected_indptr)) + + # Verify prefix_lens + self.assertTrue(torch.equal(returned_prefix_lens, prefix_lens)) + + # 
Verify unified_kv_indices structure + for i in range(B): + start_idx = int(unified_kv_indptr[i]) + end_idx = int(unified_kv_indptr[i + 1]) + prefix_len = int(prefix_lens[i]) + extend_len = int(extend_lens[i]) + + # Check that prefix and extend are concatenated correctly + unified_seq = unified_kv_indices[start_idx:end_idx] + self.assertEqual(len(unified_seq), prefix_len + extend_len) + if __name__ == "__main__": unittest.main()
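
A minimal sketch of the tile configuration the new _get_block_sizes_for_extend_attention helper picks (it centralizes the selection that extend_attention_fwd previously inlined and is shared with extend_attention_fwd_unified). Assumes a CUDA build of sglang with this patch applied; the printed values assume a Hopper-class GPU (compute capability 9.x):

from sglang.srt.layers.attention.triton_ops.extend_attention import (
    _get_block_sizes_for_extend_attention,
)

# Lq = Lv = 128: BLOCK_DMODEL is the next power of two of the head dim and
# BLOCK_DPE is 0 (no separate rope/positional block is split off for this shape).
BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (
    _get_block_sizes_for_extend_attention(128, 128)
)
print(BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps)
# Expected on compute capability >= 9: 128 0 128 128 64 8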
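
A minimal sketch of how build_unified_kv_indices concatenates per-sequence prefix and extend slots, mirroring the shapes used in test_build_unified_kv_indices. Assumes a CUDA device (the index copy runs as a Triton kernel) and sglang with this patch applied:

import torch

from sglang.srt.layers.attention.triton_ops.extend_attention import (
    build_unified_kv_indices,
)

B = 2
prefix_lens = torch.tensor([3, 2], dtype=torch.int32, device="cuda")
extend_lens = torch.tensor([2, 1], dtype=torch.int32, device="cuda")

# Prefix KV slots laid out back to back: seq 0 owns [0, 1, 2], seq 1 owns [3, 4].
prefix_kv_indptr = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
prefix_kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)
prefix_kv_indices = torch.arange(int(prefix_lens.sum()), dtype=torch.int64, device="cuda")

# Extend (out_cache_loc) slots follow the prefix slots in this toy layout: [5, 6, 7].
extend_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda")
extend_start_loc[1:] = torch.cumsum(extend_lens[:-1], dim=0)
extend_kv_indices = torch.arange(5, 8, dtype=torch.int64, device="cuda")

unified_kv_indptr, unified_kv_indices, prefix_lens_out = build_unified_kv_indices(
    prefix_kv_indptr,
    prefix_kv_indices,
    extend_start_loc,
    extend_lens,
    extend_kv_indices,
    B,
)
# unified_kv_indptr  -> [0, 5, 8]
# unified_kv_indices -> [0, 1, 2, 5, 6, 3, 4, 7]  (prefix then extend, per sequence)
# prefix_lens_out    -> [3, 2]
print(unified_kv_indptr, unified_kv_indices, prefix_lens_out)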
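
The masking rule inside _fwd_kernel_unified can be stated outside Triton: every query row sees all prefix keys, extend keys follow the usual causal rule against the query's position within the extend, and a positive sliding window additionally requires q_abs <= k_abs + window_size using the absolute positions derived from window_start_pos. A small host-side reference of that rule (a sketch for reasoning about the kernel, not code taken from the patch):

import torch


def unified_mask(q_len: int, prefix_len: int, window_size: int = -1, window_start: int = 0):
    kv_len = prefix_len + q_len
    q_idx = torch.arange(q_len).unsqueeze(1)   # [q_len, 1], position within the extend
    k_idx = torch.arange(kv_len).unsqueeze(0)  # [1, kv_len], position within unified KV
    k_is_extend = k_idx >= prefix_len
    # Prefix keys are always visible; extend keys are causally masked.
    mask = torch.where(k_is_extend, q_idx >= (k_idx - prefix_len), torch.tensor(True))
    if window_size > 0:
        # Absolute positions: queries sit after the (windowed) prefix.
        q_abs = window_start + prefix_len + q_idx
        k_abs = window_start + k_idx
        mask &= q_abs <= (k_abs + window_size)
    return mask


# Query row 0 of a 2-token extend over a 3-token prefix sees keys [True, True, True, True, False].
print(unified_mask(q_len=2, prefix_len=3))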