diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py
index 8364a82ca..669d3b3a2 100644
--- a/python/sglang/srt/layers/attention/base_attn_backend.py
+++ b/python/sglang/srt/layers/attention/base_attn_backend.py
@@ -39,6 +39,7 @@ class AttentionBackend(ABC):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index fba806010..913a693f4 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -349,6 +349,7 @@ class FlashInferAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
@@ -1062,6 +1063,7 @@ class FlashInferMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
+                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index 9af027bd1..6a5518a67 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -279,6 +279,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
@@ -791,6 +792,7 @@ class FlashInferMLAMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
+                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
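The hunks above are a cross-backend interface update: init_forward_metadata_replay_cuda_graph now takes the KV-head count as its second positional argument. Backends that do not consume it, such as the FlashInfer multi-step draft backends, pass -1 as a placeholder, mirroring the existing seq_lens_sum=-1 convention. A minimal call-site sketch (illustrative only; names follow the base-class signature):

attn_backend.init_forward_metadata_replay_cuda_graph(
    bs,
    num_kv_heads,  # new second argument; -1 where a backend ignores it
    req_pool_indices,
    seq_lens,
    seq_lens_sum,
    # ... remaining arguments unchanged
)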
diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index f5cb29a0f..00bd68d60 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -4,11 +4,13 @@ from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 import triton
+import triton.language as tl
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import get_bool_env_var, get_device_core_count
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -16,6 +18,51 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
+@triton.jit
+def get_num_kv_splits_triton(
+    num_kv_splits_ptr,
+    seq_lens_ptr,
+    bs,
+    num_head,
+    num_kv_head,
+    max_kv_splits,
+    device_core_count,
+    MAX_BS: tl.constexpr,
+):
+    # TODO: this method is tunable
+    offs_b = tl.arange(0, MAX_BS)
+    mask_b = offs_b < bs
+
+    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=0)
+    max_seq_len = tl.max(seq_lens)
+    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=max_seq_len)
+    min_seq_len = tl.min(seq_lens)
+    if max_seq_len * 8 < min_seq_len * 10:
+        min_seq_len = max_seq_len
+    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+    # NOTE: this is a hack to let num_kv_splits grow gradually with seq_len
+    ext_seq_len = tl.cast(tl.cdiv(max_seq_len, 256), tl.float32)
+    ext_device_core_count = device_core_count * tl.maximum(
+        tl.cast(tl.ceil(tl.log2(ext_seq_len)), tl.int32), 1
+    )
+    block_h, num_kv_group = 16, num_head // num_kv_head
+    if num_kv_group == 1:
+        bh_grid = bs * num_head
+    else:
+        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+        block_h = tl.minimum(block_h, num_kv_group)
+        bh_grid = bs * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(tl.cdiv(ext_device_core_count, bh_grid), max_kv_splits)
+    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+    num_kv_splits = tl.maximum(
+        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+    )
+    tl.store(num_kv_splits_ptr + offs_b, num_kv_splits, mask=mask_b)
+
+
 class TritonAttnBackend(AttentionBackend):
     def __init__(
         self,
@@ -64,7 +111,10 @@ class TritonAttnBackend(AttentionBackend):
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
 
-        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.static_kv_splits = get_bool_env_var(
+            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
+        )
+        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
 
         self.forward_metadata = None
@@ -72,6 +122,30 @@ class TritonAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
 
         self.device = model_runner.device
+        self.device_core_count = get_device_core_count(model_runner.gpu_id)
+
+    def get_num_kv_splits(
+        self,
+        num_kv_splits: torch.Tensor,
+        seq_lens: torch.Tensor,
+        bs: int,
+        num_kv_head: int,
+    ):
+        MAX_SCHEDULE_BS = 4096
+        if self.static_kv_splits or self.device_core_count <= 0 or bs > MAX_SCHEDULE_BS:
+            num_kv_splits.fill_(self.max_kv_splits)
+            return
+
+        get_num_kv_splits_triton[(1,)](
+            num_kv_splits,
+            seq_lens,
+            bs,
+            self.num_head,
+            num_kv_head,
+            self.max_kv_splits,
+            self.device_core_count,
+            MAX_BS=MAX_SCHEDULE_BS,
+        )
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
@@ -100,15 +174,35 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                 bs = kv_indptr.shape[0] - 1
 
-            attn_logits = torch.empty(
-                (
-                    bs,
-                    self.num_head,
-                    self.num_kv_splits,
-                    self.v_head_dim + 1,
+            attn_logits = [
+                torch.empty(
+                    (
+                        bs,
+                        self.num_head,
+                        self.max_kv_splits,
+                        self.v_head_dim,
+                    ),
+                    dtype=torch.float32,
+                    device=self.device,
                 ),
-                dtype=torch.float32,
-                device=self.device,
+                torch.empty(
+                    (
+                        bs,
+                        self.num_head,
+                        self.max_kv_splits,
+                    ),
+                    dtype=torch.float32,
+                    device=self.device,
+                ),
+            ]
+            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
+
+            num_kv_heads = self.num_head
+            if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
+                if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
+                    num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
+            self.get_num_kv_splits(
+                num_kv_splits, forward_batch.seq_lens, bs, num_kv_heads
             )
 
             qo_indptr = None
@@ -148,6 +242,7 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
             mask_indptr = mask_indptr[: bs + 1]
             max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
             attn_logits = None
         elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     forward_batch.req_pool_indices,
                     forward_batch.seq_lens,
                     self.req_to_token,
                 )
             )
             mask_indptr = None
             max_extend_len = torch.max(spec_info.accept_length).item()
+            num_kv_splits = None
             attn_logits = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
@@ -188,10 +284,12 @@
             mask_indptr = None
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
+            num_kv_splits = None
 
         self.forward_metadata = (
             attn_logits,
             max_extend_len,
+            num_kv_splits,
             kv_indptr,
             kv_indices,
             qo_indptr,
@@ -202,10 +300,20 @@
     def init_cuda_graph_state(
         self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
     ):
-        self.cuda_graph_attn_logits = torch.zeros(
-            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
-            dtype=torch.float32,
-            device=self.device,
+        self.cuda_graph_attn_logits = [
+            torch.zeros(
+                (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+                dtype=torch.float32,
+                device=self.device,
+            ),
+            torch.zeros(
+                (max_bs, self.num_head, self.max_kv_splits),
+                dtype=torch.float32,
+                device=self.device,
+            ),
+        ]
+        self.cuda_graph_num_kv_splits = torch.full(
+            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
         )
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
@@ -255,6 +363,7 @@
             attn_logits = self.cuda_graph_attn_logits
             max_extend_len = None
+            num_kv_splits = self.cuda_graph_num_kv_splits
             qo_indptr = None
             custom_mask = None
             mask_indptr = None
@@ -285,6 +394,7 @@
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
             max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
             attn_logits = None
         else:
             raise ValueError(
@@ -294,6 +404,7 @@
         self.forward_metadata = (
             attn_logits,
             max_extend_len,
+            num_kv_splits,
             kv_indptr,
             kv_indices,
             qo_indptr,
@@ -304,6 +415,7 @@
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_head: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
@@ -317,6 +429,7 @@
         # Update kv_indptr, kv_indices
         kv_indptr = self.kv_indptr
         kv_indices = self.cuda_graph_kv_indices
+        num_kv_splits = self.cuda_graph_num_kv_splits
        if spec_info is None:
             kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
             kv_indptr = kv_indptr[: bs + 1]
@@ -332,6 +445,7 @@
             else:
                 kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+            self.get_num_kv_splits(num_kv_splits, seq_lens, bs, num_kv_head)
         elif forward_mode.is_target_verify():
             # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
             bs = len(req_pool_indices)
@@ -391,6 +505,7 @@
         (
             _,
             max_extend_len,
+            _,
             kv_indptr,
             kv_indices,
             qo_indptr,
@@ -435,7 +550,9 @@
         else:
             o = torch.empty_like(q)
 
-        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata
+        attn_logits, _, num_kv_splits, kv_indptr, kv_indices, _, _, _ = (
+            self.forward_metadata
+        )
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -450,7 +567,8 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr,
             kv_indices,
             attn_logits,
-            self.num_kv_splits,
+            num_kv_splits,
+            self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
         )
@@ -493,6 +611,9 @@ class TritonMultiStepDraftBackend:
             )
         )
         self.max_context_len = self.attn_backends[0].max_context_len
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
         self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
@@ -579,9 +700,15 @@
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
+        num_kv_heads = self.num_head
+        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
+            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
+                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
+
        def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
+                num_kv_heads,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
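The scheduling heuristic in get_num_kv_splits_triton above is easier to follow in plain Python. The sketch below is for readability only and is not part of the patch; seq_lens is a Python list standing in for the GPU tensor, and cdiv spells out the tl.cdiv ceiling division:

import math

def cdiv(a, b):
    return -(-a // b)

def num_kv_splits_sketch(seq_lens, num_head, num_kv_head, max_kv_splits, device_core_count):
    bs = len(seq_lens)
    max_len, min_len = max(seq_lens), min(seq_lens)
    if max_len * 8 < min_len * 10:  # max/min < 1.25: treat the batch as uniform
        min_len = max_len

    # Bound 1: longer requests may take proportionally more splits than short ones.
    splits_1 = min(cdiv(max_len, min_len), max_kv_splits)
    chunk_1 = cdiv(max_len, splits_1)

    # Bound 2: enough (batch, head-block, split) programs to occupy the SMs,
    # with the occupancy target growing with log2 of the sequence length.
    ext_len = cdiv(max_len, 256)
    ext_cores = device_core_count * max(math.ceil(math.log2(max(ext_len, 1))), 1)
    num_kv_group = num_head // num_kv_head
    if num_kv_group == 1:
        bh_grid = bs * num_head
    else:  # grouped kernel: heads are processed in blocks of min(16, num_kv_group)
        bh_grid = bs * cdiv(num_head, min(16, num_kv_group))
    splits_2 = min(cdiv(ext_cores, bh_grid), max_kv_splits)
    chunk_2 = cdiv(max_len, splits_2)

    # Per request, take the larger split count implied by the two chunk sizes.
    return [max(cdiv(s, chunk_1), cdiv(s, chunk_2)) for s in seq_lens]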
diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index a9ab44546..4ae5cc205 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -37,6 +37,9 @@ logger.warning(
 )
 
 
+_MIN_BLOCK_KV = 32
+
+
 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid
@@ -52,6 +55,8 @@ def _fwd_kernel_stage1(
     kv_indptr,
     kv_indices,
     Att_Out,
+    Att_Lse,
+    num_kv_splits,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -65,7 +70,7 @@ def _fwd_kernel_stage1(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    NUM_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
@@ -83,11 +88,13 @@ def _fwd_kernel_stage1(
 
     cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
+    kv_splits = tl.load(num_kv_splits + cur_batch)
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
-    q = tl.load(Q + off_q, mask=mask_d, other=0.0)
 
-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
     split_kv_start = kv_len_per_split * split_kv_id
     split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
 
@@ -96,6 +103,7 @@ def _fwd_kernel_stage1(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
 
     if split_kv_end > split_kv_start:
+        q = tl.load(Q + off_q, mask=mask_d, other=0.0)
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(
@@ -158,11 +166,10 @@ def _fwd_kernel_stage1(
             cur_batch * stride_mid_ob
             + cur_head * stride_mid_oh
             + split_kv_id * stride_mid_os
-            + Lv
-        )
+        ) // Lv
 
         tl.store(
-            Att_Out + offs_mid_o_1,
+            Att_Lse + offs_mid_o_1,
             e_max + tl.log(e_sum),
         )
@@ -172,9 +179,11 @@ def _decode_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
+    att_lse,
     kv_indptr,
     kv_indices,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap,
 ):
@@ -182,13 +191,13 @@ def _decode_att_m_fwd(
     BLOCK = 64
     # [TODO] work around SGPR limit on MI3xx
     if _is_hip:
         BLOCK = 8
-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
 
     batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
-    grid = (batch, head_num, NUM_KV_SPLITS)
+    grid = (batch, head_num, MAX_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     if kv_group_num == 1:
@@ -209,6 +218,8 @@ def _decode_att_m_fwd(
         kv_indptr,
         kv_indices,
         att_out,
+        att_lse,
+        num_kv_splits,
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -222,7 +233,7 @@ def _decode_att_m_fwd(
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=2,
@@ -240,6 +251,8 @@ def _fwd_grouped_kernel_stage1(
     kv_indptr,
     kv_indices,
     Att_Out,
+    Att_Lse,
+    num_kv_splits,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -256,7 +269,7 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
-    NUM_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
@@ -281,9 +294,9 @@ def _fwd_grouped_kernel_stage1(
 
     cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
+    kv_splits = tl.load(num_kv_splits + cur_batch)
 
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
-    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
 
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
@@ -291,11 +304,10 @@ def _fwd_grouped_kernel_stage1(
         off_qpe = (
             cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
         )
-        qpe = tl.load(
-            Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
-        )
 
-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
     split_kv_start = kv_len_per_split * split_kv_id
     split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
 
@@ -304,6 +316,11 @@ def _fwd_grouped_kernel_stage1(
     acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
 
     if split_kv_end > split_kv_start:
+        q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(
+                Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
+            )
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(
@@ -380,11 +397,10 @@ def _fwd_grouped_kernel_stage1(
             cur_batch * stride_mid_ob
             + cur_head * stride_mid_oh
             + split_kv_id * stride_mid_os
-            + Lv
-        )
+        ) // Lv
 
         tl.store(
-            Att_Out + offs_mid_o_1,
+            Att_Lse + offs_mid_o_1,
             e_max + tl.log(e_sum),
             mask=mask_h,
         )
@@ -395,9 +411,11 @@ def _decode_grouped_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
+    att_lse,
     kv_indptr,
     kv_indices,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap,
 ):
@@ -424,11 +442,11 @@ def _decode_grouped_att_m_fwd(
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     BLOCK_H = 16
-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
-        NUM_KV_SPLITS,
+        MAX_KV_SPLITS,
     )
 
     extra_kargs = {}
@@ -447,6 +465,8 @@ def _decode_grouped_att_m_fwd(
         kv_indptr,
         kv_indices,
         att_out,
+        att_lse,
+        num_kv_splits,
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -463,7 +483,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
         num_warps=4,
         num_stages=num_stages,
@@ -476,14 +496,17 @@
 @triton.jit
 def _fwd_kernel_stage2(
     Mid_O,
+    Mid_O_1,
     O,
     kv_indptr,
+    num_kv_splits,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
     stride_obs,
     stride_oh,
-    NUM_KV_SPLITS: tl.constexpr,
+    MAX_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
 ):
@@ -493,6 +516,7 @@ def _fwd_kernel_stage2(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
         kv_indptr + cur_batch
     )
+    kv_splits = tl.load(num_kv_splits + cur_batch)
 
     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
@@ -502,10 +526,12 @@ def _fwd_kernel_stage2(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
 
     offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
-    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
+    offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
 
-    for split_kv_id in range(0, NUM_KV_SPLITS):
-        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    for split_kv_id in range(0, MAX_KV_SPLITS):
         split_kv_start = kv_len_per_split * split_kv_id
         split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
 
@@ -513,7 +539,7 @@ def _fwd_kernel_stage2(
             tv = tl.load(
                 Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
             )
-            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
+            tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv)
             n_e_max = tl.maximum(tlogic, e_max)
 
             old_scale = tl.exp(e_max - n_e_max)
@@ -533,17 +559,19 @@
 
 def _decode_softmax_reducev_fwd(
     logits,
+    lse,
     q,
     o,
     v_buffer,
     kv_indptr,
     num_kv_splits,
+    max_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
     Lv = v_buffer.shape[-1]
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
 
     extra_kargs = {}
     if _is_hip:
@@ -554,14 +582,17 @@ def _decode_softmax_reducev_fwd(
     grid = (batch, head_num)
     _fwd_kernel_stage2[grid](
         logits,
+        lse,
         o,
         kv_indptr,
+        num_kv_splits,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
         o.stride(0),
         o.stride(1),
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MAX_KV_SPLITS=MAX_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
         num_warps=4,
@@ -579,6 +610,7 @@ def decode_attention_fwd_normal(
     kv_indices,
     attn_logits,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
@@ -586,14 +618,25 @@ def decode_attention_fwd_normal(
     _decode_att_m_fwd(
         q,
         k_buffer,
         v_buffer,
-        attn_logits,
+        attn_logits[0],
+        attn_logits[1],
         kv_indptr,
         kv_indices,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
+    _decode_softmax_reducev_fwd(
+        attn_logits[0],
+        attn_logits[1],
+        q,
+        o,
+        v_buffer,
+        kv_indptr,
+        num_kv_splits,
+        max_kv_splits,
+    )
 
 
 def decode_attention_fwd_grouped(
@@ -605,6 +648,7 @@ def decode_attention_fwd_grouped(
     kv_indices,
     attn_logits,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
@@ -612,14 +656,25 @@ def decode_attention_fwd_grouped(
     _decode_grouped_att_m_fwd(
         q,
         k_buffer,
         v_buffer,
-        attn_logits,
+        attn_logits[0],
+        attn_logits[1],
         kv_indptr,
         kv_indices,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
+    _decode_softmax_reducev_fwd(
+        attn_logits[0],
+        attn_logits[1],
+        q,
+        o,
+        v_buffer,
+        kv_indptr,
+        num_kv_splits,
+        max_kv_splits,
+    )
 
 
 def decode_attention_fwd(
@@ -631,12 +686,13 @@ def decode_attention_fwd(
     kv_indices,
     attn_logits,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
-    assert num_kv_splits == attn_logits.shape[2]
+    assert max_kv_splits == attn_logits[0].shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
-    assert q.shape[0] <= attn_logits.shape[0]
+    assert q.shape[0] <= attn_logits[0].shape[0]
 
     kv_group_num = q.shape[1] // v_buffer.shape[1]
 
@@ -651,6 +707,7 @@ def decode_attention_fwd(
             kv_indices,
             attn_logits,
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
             logit_cap,
         )
@@ -665,6 +722,7 @@ def decode_attention_fwd(
             kv_indices,
             attn_logits,
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
             logit_cap,
         )
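Worked example (hypothetical numbers) of the _MIN_BLOCK_KV rounding now shared by both stage-1 kernels and _fwd_kernel_stage2: per-request chunk lengths are rounded up to a multiple of 32, so the launch grid stays sized by MAX_KV_SPLITS while trailing split ids simply fall through the split_kv_end > split_kv_start guard, which is also why the q/qpe loads moved inside that branch.

import math

_MIN_BLOCK_KV = 32
seq_len, kv_splits, MAX_KV_SPLITS = 1000, 6, 8  # assumed values for illustration

kv_len_per_split = math.ceil(math.ceil(seq_len / kv_splits) / _MIN_BLOCK_KV) * _MIN_BLOCK_KV
assert kv_len_per_split == 192  # ceil(1000 / 6) = 167, rounded up to 6 * 32

for split_kv_id in range(MAX_KV_SPLITS):
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = min(split_kv_start + kv_len_per_split, seq_len)
    # ids 0-5 cover [0, 1000); ids 6 and 7 get start >= seq_len, fail the
    # guard, and exit without loading q or touching the KV cache
    print(split_kv_id, split_kv_start, split_kv_end, split_kv_end > split_kv_start)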
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 95a4dd6af..336cf8b4c 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -26,6 +26,7 @@ import tqdm
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -195,6 +196,9 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -503,9 +507,15 @@ class CudaGraphRunner:
         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
 
+        num_kv_heads = self.num_head
+        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
+            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
+                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
+            num_kv_heads,
             self.req_pool_indices,
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
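The KV-head probe added above now appears three times in this patch (TritonAttnBackend.init_forward_metadata, TritonMultiStepDraftBackend, and CudaGraphRunner). If a follow-up wanted to deduplicate it, a small helper along the following lines would do; _infer_num_kv_heads is hypothetical and not part of the patch:

def _infer_num_kv_heads(token_to_kv_pool, default: int) -> int:
    """Hypothetical helper mirroring the repeated hasattr/isinstance probe.

    MHA pools expose k_buffer as a per-layer list of tensors whose second
    dimension is the KV-head count; MLA pools do not, so the attention-head
    count (`default`) is used as the fallback there.
    """
    k_buffer = getattr(token_to_kv_pool, "k_buffer", None)
    if isinstance(k_buffer, list):
        return k_buffer[0].shape[1]
    return default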
diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py
index 14372593d..af1ced319 100644
--- a/test/srt/test_triton_attention_kernels.py
+++ b/test/srt/test_triton_attention_kernels.py
@@ -228,7 +228,8 @@ class TestTritonAttention(unittest.TestCase):
         seq_len = 10  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
-        num_kv_splits = 8
+        max_kv_splits = 8
+        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
 
         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
@@ -247,7 +248,12 @@ class TestTritonAttention(unittest.TestCase):
         kv_indices = torch.arange(total_tokens, device="cuda")
 
         attn_logits = torch.empty(
-            (B, H_Q, num_kv_splits, D + 1),
+            (B, H_Q, max_kv_splits, D),
+            dtype=torch.float32,
+            device="cuda",
+        )
+        attn_lse = torch.empty(
+            (B, H_Q, max_kv_splits),
             dtype=torch.float32,
             device="cuda",
         )
@@ -259,8 +265,9 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            attn_logits,
+            (attn_logits, attn_lse),
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
         )
 
@@ -284,7 +291,8 @@ class TestTritonAttention(unittest.TestCase):
         seq_len = S  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
-        num_kv_splits = 8
+        max_kv_splits = 8
+        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
 
         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
@@ -304,7 +312,12 @@ class TestTritonAttention(unittest.TestCase):
         kv_indices = torch.arange(total_tokens, device="cuda")
 
         attn_logits = torch.empty(
-            (B, H_Q, num_kv_splits, D_V + 1),
+            (B, H_Q, max_kv_splits, D_V),
+            dtype=torch.float32,
+            device="cuda",
+        )
+        attn_lse = torch.empty(
+            (B, H_Q, max_kv_splits),
             dtype=torch.float32,
             device="cuda",
         )
@@ -316,13 +329,19 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            attn_logits,
+            (attn_logits, attn_lse),
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
         )
 
         attn_logits1 = torch.empty(
-            (B, H_Q, num_kv_splits, D_V + 1),
+            (B, H_Q, max_kv_splits, D_V),
+            dtype=torch.float32,
+            device="cuda",
+        )
+        attn_lse1 = torch.empty(
+            (B, H_Q, max_kv_splits),
             dtype=torch.float32,
             device="cuda",
         )
@@ -334,8 +353,9 @@ class TestTritonAttention(unittest.TestCase):
             o_grouped,
             kv_indptr,
             kv_indices,
-            attn_logits1,
+            (attn_logits1, attn_lse1),
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
         )
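Background on the (attn_logits, attn_lse) pair the tests allocate: stage 1 now writes, per KV split, an attention output normalized over that split's keys into attn_logits and the log-sum-exp of that split's logits into attn_lse, and stage 2 merges them exactly. A PyTorch reference of the merge for a single (batch, head) pair, illustrative rather than part of the patch:

import torch

def merge_kv_splits_reference(mid_o: torch.Tensor, mid_lse: torch.Tensor) -> torch.Tensor:
    """mid_o: [kv_splits, head_dim], each row already softmax-normalized over
    its own KV chunk; mid_lse: [kv_splits], log-sum-exp of that chunk's
    logits. Numerically equivalent to the running update in _fwd_kernel_stage2
    over the first kv_splits (dynamic) entries of the max_kv_splits buffers.
    """
    m = mid_lse.max()
    w = torch.exp(mid_lse - m)  # relative weight of each split
    return (w[:, None] * mid_o).sum(dim=0) / w.sum()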