From be2d985df83a230affa1cf0e4e59142d1beeb847 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Fri, 13 Jun 2025 16:01:23 -0700
Subject: [PATCH] Minor style change of triton backend (#7165)

---
 .../srt/layers/attention/triton_backend.py | 226 +++++++++---------
 1 file changed, 113 insertions(+), 113 deletions(-)

diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 4b81176aa..38c239177 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -20,117 +20,6 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
-@triton.jit
-def get_num_kv_splits_triton(
-    num_kv_splits_ptr,
-    seq_lens_ptr,
-    num_seq,
-    num_group,
-    num_head,
-    num_kv_head,
-    max_kv_splits,
-    device_core_count,
-    MAX_NUM_SEQ: tl.constexpr,
-):
-    # TODO: this method is tunable, we need more online serving data to tune it
-    offs_seq = tl.arange(0, MAX_NUM_SEQ)
-    mask_seq = offs_seq < num_seq
-
-    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
-    max_seq_len = tl.max(seq_lens)
-    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
-    min_seq_len = tl.min(seq_lens)
-    if max_seq_len * 8 < min_seq_len * 10:
-        min_seq_len = max_seq_len
-    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
-    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
-
-    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
-    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
-    ext_device_core_count = tl.cast(
-        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
-    )
-    block_h, num_kv_group = 16, num_head // num_kv_head
-    if num_kv_group == 1:
-        token_grid = num_seq * num_group * num_head
-    else:
-        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
-        block_h = tl.minimum(block_h, num_kv_group)
-        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
-    max_kv_splits_2 = tl.minimum(
-        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
-    )
-    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
-
-    num_kv_splits = tl.maximum(
-        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
-    )
-
-    offs_token = offs_seq * num_group
-    mask_token = offs_token < num_seq * num_group
-    for i in range(0, num_group):
-        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
-
-
-def update_sliding_window_buffer(
-    window_kv_indptr,
-    req_to_token,
-    sliding_window_size,
-    seq_lens,
-    req_pool_indices,
-    bs,
-    device,
-):
-    window_kv_lens = torch.minimum(
-        seq_lens,
-        torch.tensor(sliding_window_size + 1),
-    )
-    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
-    window_kv_indptr = window_kv_indptr[: bs + 1]
-    window_kv_indices = torch.empty(
-        window_kv_indptr[-1], dtype=torch.int32, device=device
-    )
-    window_kv_start_idx = seq_lens - window_kv_lens
-    create_flashinfer_kv_indices_triton[(bs,)](
-        req_to_token,
-        req_pool_indices,
-        window_kv_lens,
-        window_kv_indptr,
-        window_kv_start_idx,
-        window_kv_indices,
-        req_to_token.stride(0),
-    )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
-
-
-def update_sliding_window_buffer_cuda_graph(
-    window_kv_indptr,
-    window_kv_indices,
-    req_to_token,
-    sliding_window_size,
-    seq_lens,
-    req_pool_indices,
-    bs,
-):
-    window_kv_lens = torch.minimum(
-        seq_lens,
-        torch.tensor(sliding_window_size + 1),
-    )
-    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
-    window_kv_indptr = window_kv_indptr[: bs + 1]
-    window_kv_start_idx = seq_lens - window_kv_lens
-    create_flashinfer_kv_indices_triton[(bs,)](
-        req_to_token,
-        req_pool_indices,
-        window_kv_lens,
-        window_kv_indptr,
-        window_kv_start_idx,
-        window_kv_indices,
-        req_to_token.stride(0),
-    )
-    return window_kv_indptr, window_kv_lens
-
-
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor
@@ -165,8 +54,8 @@ class TritonAttnBackend(AttentionBackend):
 
         super().__init__()
 
-        self.decode_attention_fwd = decode_attention_fwd
-        self.extend_attention_fwd = extend_attention_fwd
+        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
 
         self.skip_prefill = skip_prefill
@@ -973,3 +862,114 @@ class TritonMultiStepDraftBackend:
         )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+
+@triton.jit
+def get_num_kv_splits_triton(
+    num_kv_splits_ptr,
+    seq_lens_ptr,
+    num_seq,
+    num_group,
+    num_head,
+    num_kv_head,
+    max_kv_splits,
+    device_core_count,
+    MAX_NUM_SEQ: tl.constexpr,
+):
+    # TODO: this method is tunable, we need more online serving data to tune it
+    offs_seq = tl.arange(0, MAX_NUM_SEQ)
+    mask_seq = offs_seq < num_seq
+
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
+    max_seq_len = tl.max(seq_lens)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
+    min_seq_len = tl.min(seq_lens)
+    if max_seq_len * 8 < min_seq_len * 10:
+        min_seq_len = max_seq_len
+    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
+    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+    ext_device_core_count = tl.cast(
+        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
+    )
+    block_h, num_kv_group = 16, num_head // num_kv_head
+    if num_kv_group == 1:
+        token_grid = num_seq * num_group * num_head
+    else:
+        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+        block_h = tl.minimum(block_h, num_kv_group)
+        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(
+        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
+    )
+    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+    num_kv_splits = tl.maximum(
+        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+    )
+
+    offs_token = offs_seq * num_group
+    mask_token = offs_token < num_seq * num_group
+    for i in range(0, num_group):
+        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+def update_sliding_window_buffer(
+    window_kv_indptr,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+    device,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_indices = torch.empty(
+        window_kv_indptr[-1], dtype=torch.int32, device=device
+    )
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_indices, window_kv_lens
+
+
+def update_sliding_window_buffer_cuda_graph(
+    window_kv_indptr,
+    window_kv_indices,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_lens
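
Note: despite the "minor style change" subject, the patch does two things. It moves
get_num_kv_splits_triton, update_sliding_window_buffer, and
update_sliding_window_buffer_cuda_graph verbatim from the top of the file to the
bottom, and it wraps the two attention entry points in torch.compiler.disable so
that torch.compile (Dynamo) no longer traces into the Triton launch wrappers. A
minimal sketch of that pattern, using a stand-in function rather than SGLang's
real kernels:

    import torch

    def _decode_attention_fwd(q, k, v, out):
        # Stand-in for SGLang's Triton launcher; the real one lives in
        # sglang/srt/layers/attention/triton_ops/decode_attention.py.
        out.copy_(q)

    # torch.compiler.disable returns a callable whose body Dynamo will not
    # trace; compiled callers execute this call eagerly, which sidesteps
    # graph breaks inside the kernel-launch logic.
    decode_attention_fwd = torch.compiler.disable(_decode_attention_fwd)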
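
For intuition, the split heuristic in get_num_kv_splits_triton can be read as
host-side arithmetic: one chunk size bounds the splits by the ratio of longest to
shortest sequence, a second bounds them by how many program instances the device
can keep busy, and each sequence takes the larger of the two split counts. A
plain-Python re-derivation of that arithmetic (a sketch for reading, not the
kernel itself):

    import math

    def num_kv_splits_host(seq_lens, num_head, num_kv_head, max_kv_splits,
                           device_core_count, num_group=1):
        # Host-side re-derivation of the Triton heuristic, for intuition only.
        max_len, min_len = max(seq_lens), min(seq_lens)
        if max_len * 8 < min_len * 10:  # lengths within ~25% of each other
            min_len = max_len
        chunk_1 = math.ceil(max_len / min(math.ceil(max_len / min_len),
                                          max_kv_splits))

        # Scale the core count by log2(seqlen / 64) so splits grow with length.
        ext_cores = int(device_core_count * max(math.log2(max_len / 64.0), 1.0))
        kv_group = num_head // num_kv_head
        block_h = 16 if kv_group == 1 else min(16, kv_group)
        grid = len(seq_lens) * num_group * (
            num_head if kv_group == 1 else math.ceil(num_head / block_h)
        )
        chunk_2 = math.ceil(max_len / min(math.ceil(ext_cores / grid),
                                          max_kv_splits))

        # Each sequence takes the larger split count implied by the two chunks.
        return [max(math.ceil(l / chunk_1), math.ceil(l / chunk_2))
                for l in seq_lens]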
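
The sliding-window helpers are ordinary PyTorch around a single Triton gather
(create_flashinfer_kv_indices_triton): clamp each sequence length to
sliding_window_size + 1, build the indptr by cumulative sum, and record where
the window starts within each sequence. A standalone sketch of that indptr math
with made-up sizes:

    import torch

    bs, sliding_window_size = 3, 4
    seq_lens = torch.tensor([2, 7, 5])

    # Each request keeps at most sliding_window_size + 1 tokens of KV cache.
    window_kv_lens = torch.minimum(seq_lens,
                                   torch.tensor(sliding_window_size + 1))
    window_kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
    window_kv_indptr[1:] = torch.cumsum(window_kv_lens, dim=0)
    # Tokens before the window start are skipped when gathering KV indices.
    window_kv_start_idx = seq_lens - window_kv_lens

    print(window_kv_indptr)     # tensor([ 0,  2,  7, 12])
    print(window_kv_start_idx)  # tensor([0, 2, 0])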