From 9e93ef3f8e82b54a3a33687b8434b8ec8050c79b Mon Sep 17 00:00:00 2001
From: JieXin Liang
Date: Thu, 20 Mar 2025 17:01:52 +0800
Subject: [PATCH] [fix] fix illegal mem access and clean up triton attention backend (#4571)

---
 .../srt/layers/attention/base_attn_backend.py |   1 -
 .../layers/attention/flashinfer_backend.py    |   2 -
 .../attention/flashinfer_mla_backend.py       |   2 -
 .../srt/layers/attention/triton_backend.py    | 200 +++++++++---------
 .../attention/triton_ops/decode_attention.py  |  25 ++-
 .../srt/model_executor/cuda_graph_runner.py   |  10 -
 test/srt/test_triton_attention_kernels.py     |   9 +-
 7 files changed, 124 insertions(+), 125 deletions(-)

diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py
index 669d3b3a2..8364a82ca 100644
--- a/python/sglang/srt/layers/attention/base_attn_backend.py
+++ b/python/sglang/srt/layers/attention/base_attn_backend.py
@@ -39,7 +39,6 @@ class AttentionBackend(ABC):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
-        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 913a693f4..fba806010 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -349,7 +349,6 @@ class FlashInferAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
-        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
@@ -1063,7 +1062,6 @@ class FlashInferMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
-                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index 6a5518a67..9af027bd1 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -279,7 +279,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
-        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
@@ -792,7 +791,6 @@ class FlashInferMLAMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
-                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 00bd68d60..29547ed43 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union
 
 import torch
@@ -22,20 +23,21 @@ if TYPE_CHECKING:
 def get_num_kv_splits_triton(
     num_kv_splits_ptr,
     seq_lens_ptr,
-    bs,
+    num_seq,
+    num_group,
     num_head,
     num_kv_head,
     max_kv_splits,
     device_core_count,
-    MAX_BS: tl.constexpr,
+    MAX_NUM_SEQ: tl.constexpr,
 ):
-    # TODO: this method is tunable
-    offs_b = tl.arange(0, MAX_BS)
-    mask_b = offs_b < bs
+    # TODO: this method is tunable; we need more online serving data to tune it
+    offs_seq = tl.arange(0, MAX_NUM_SEQ)
+    mask_seq = offs_seq < num_seq
 
-    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=0)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
     max_seq_len = tl.max(seq_lens)
-    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=max_seq_len)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
     min_seq_len = tl.min(seq_lens)
     if max_seq_len * 8 < min_seq_len * 10:
         min_seq_len = max_seq_len
@@ -43,24 +45,43 @@ def get_num_kv_splits_triton(
     kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
 
     # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
-    ext_seq_len = tl.cast(tl.cdiv(max_seq_len, 256), tl.float32)
-    ext_device_core_count = device_core_count * tl.maximum(
-        tl.cast(tl.ceil(tl.log2(ext_seq_len)), tl.int32), 1
+    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+    ext_device_core_count = tl.cast(
+        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
     )
     block_h, num_kv_group = 16, num_head // num_kv_head
     if num_kv_group == 1:
-        bh_grid = bs * num_head
+        token_grid = num_seq * num_group * num_head
     else:
         # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
         block_h = tl.minimum(block_h, num_kv_group)
-        bh_grid = bs * tl.cdiv(num_head, block_h)
+        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
-    max_kv_splits_2 = tl.minimum(tl.cdiv(ext_device_core_count, bh_grid), max_kv_splits)
+    max_kv_splits_2 = tl.minimum(
+        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
+    )
     kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
 
     num_kv_splits = tl.maximum(
         tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
     )
-    tl.store(num_kv_splits_ptr + offs_b, num_kv_splits, mask=mask_b)
+
+    offs_token = offs_seq * num_group
+    mask_token = offs_token < num_seq * num_group
+    for i in range(0, num_group):
+        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+@dataclass
+class ForwardMetadata:
+    attn_logits: torch.Tensor
+    attn_lse: torch.Tensor
+    max_extend_len: int
+    num_kv_splits: torch.Tensor
+    kv_indptr: torch.Tensor
+    kv_indices: torch.Tensor
+    qo_indptr: torch.Tensor
+    custom_mask: torch.Tensor
+    mask_indptr: torch.Tensor
 
 
 class TritonAttnBackend(AttentionBackend):
@@ -110,6 +131,9 @@ class TritonAttnBackend(AttentionBackend):
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
+        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
 
         self.static_kv_splits = get_bool_env_var(
             "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
@@ -117,7 +141,7 @@ class TritonAttnBackend(AttentionBackend):
         self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
 
-        self.forward_metadata = None
+        self.forward_metadata: ForwardMetadata = None
 
         self.max_context_len = model_runner.model_config.context_len
 
@@ -128,23 +152,33 @@ class TritonAttnBackend(AttentionBackend):
         self,
         num_kv_splits: torch.Tensor,
         seq_lens: torch.Tensor,
-        bs: int,
-        num_kv_head: int,
     ):
-        MAX_SCHEDULE_BS = 4096
-        if self.static_kv_splits or self.device_core_count <= 0 or bs > MAX_SCHEDULE_BS:
+        num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
+        num_group = num_token // num_seq
+
+        assert (
+            num_group * num_seq == num_token
+        ), f"num_seq({num_seq}), num_token({num_token}), something went wrong!"
+
+        if self.static_kv_splits or self.device_core_count <= 0:
             num_kv_splits.fill_(self.max_kv_splits)
             return
 
+        if num_seq < 256:
+            SCHEDULE_SEQ = 256
+        else:
+            SCHEDULE_SEQ = triton.next_power_of_2(num_seq)
+
         get_num_kv_splits_triton[(1,)](
             num_kv_splits,
             seq_lens,
-            bs,
+            num_seq,
+            num_group,
             self.num_head,
-            num_kv_head,
+            self.num_kv_head,
             self.max_kv_splits,
             self.device_core_count,
-            MAX_BS=MAX_SCHEDULE_BS,
+            MAX_NUM_SEQ=SCHEDULE_SEQ,
         )
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
@@ -174,36 +208,19 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                 bs = kv_indptr.shape[0] - 1
 
-            attn_logits = [
-                torch.empty(
-                    (
-                        bs,
-                        self.num_head,
-                        self.max_kv_splits,
-                        self.v_head_dim,
-                    ),
-                    dtype=torch.float32,
-                    device=self.device,
-                ),
-                torch.empty(
-                    (
-                        bs,
-                        self.num_head,
-                        self.max_kv_splits,
-                    ),
-                    dtype=torch.float32,
-                    device=self.device,
-                ),
-            ]
+            attn_logits = torch.empty(
+                (bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+                dtype=torch.float32,
+                device=self.device,
+            )
+            attn_lse = torch.empty(
+                (bs, self.num_head, self.max_kv_splits),
+                dtype=torch.float32,
+                device=self.device,
+            )
 
             num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
-            num_kv_heads = self.num_head
-            if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
-                if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
-                    num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
-            self.get_num_kv_splits(
-                num_kv_splits, forward_batch.seq_lens, bs, num_kv_heads
-            )
+            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)
 
             qo_indptr = None
             custom_mask = None
@@ -244,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = self.num_draft_tokens
             num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -254,9 +272,13 @@ class TritonAttnBackend(AttentionBackend):
                 )
             )
             mask_indptr = None
+            # TODO(FIXME): Using `max(spec_info.accept_length_cpu)` here would
+            # trigger an invalid Eagle tree; the CPU copy may not be updated
+            # somewhere.
             max_extend_len = torch.max(spec_info.accept_length).item()
             num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
                 forward_batch.extend_prefix_lens, dim=0
@@ -283,11 +305,13 @@ class TritonAttnBackend(AttentionBackend):
             custom_mask = None
             mask_indptr = None
             attn_logits = None
+            attn_lse = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
             num_kv_splits = None
 
-        self.forward_metadata = (
+        self.forward_metadata = ForwardMetadata(
             attn_logits,
+            attn_lse,
             max_extend_len,
             num_kv_splits,
             kv_indptr,
@@ -300,18 +324,16 @@ class TritonAttnBackend(AttentionBackend):
     def init_cuda_graph_state(
         self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
     ):
-        self.cuda_graph_attn_logits = [
-            torch.zeros(
-                (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
-                dtype=torch.float32,
-                device=self.device,
-            ),
-            torch.zeros(
-                (max_bs, self.num_head, self.max_kv_splits),
-                dtype=torch.float32,
-                device=self.device,
-            ),
-        ]
+        self.cuda_graph_attn_logits = torch.zeros(
+            (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        self.cuda_graph_attn_lse = torch.zeros(
+            (max_bs, self.num_head, self.max_kv_splits),
+            dtype=torch.float32,
+            device=self.device,
+        )
         self.cuda_graph_num_kv_splits = torch.full(
             (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
         )
@@ -362,6 +384,7 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
             attn_logits = self.cuda_graph_attn_logits
+            attn_lse = self.cuda_graph_attn_lse
             max_extend_len = None
             num_kv_splits = self.cuda_graph_num_kv_splits
             qo_indptr = None
@@ -396,13 +419,15 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = self.num_draft_tokens
             num_kv_splits = None
            attn_logits = None
+            attn_lse = None
         else:
             raise ValueError(
                 f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
             )
 
-        self.forward_metadata = (
+        self.forward_metadata = ForwardMetadata(
             attn_logits,
+            attn_lse,
             max_extend_len,
             num_kv_splits,
             kv_indptr,
@@ -415,7 +440,6 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
-        num_kv_head: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
@@ -442,10 +466,12 @@ class TritonAttnBackend(AttentionBackend):
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
+                num_token = bs
             else:
                 kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
-            self.get_num_kv_splits(num_kv_splits, seq_lens, bs, num_kv_head)
+                num_token = spec_info.kv_indptr.shape[0] - 1
+            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
         elif forward_mode.is_target_verify():
             # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
             bs = len(req_pool_indices)
@@ -502,17 +528,6 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        (
-            _,
-            max_extend_len,
-            _,
-            kv_indptr,
-            kv_indices,
-            qo_indptr,
-            custom_mask,
-            mask_indptr,
-        ) = self.forward_metadata
-
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -520,12 +535,12 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            custom_mask,
-            mask_indptr,
-            max_extend_len,
+            self.forward_metadata.qo_indptr,
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.custom_mask,
+            self.forward_metadata.mask_indptr,
+            self.forward_metadata.max_extend_len,
             layer.scaling,
             layer.logit_cap,
         )
@@ -550,10 +565,6 @@ class TritonAttnBackend(AttentionBackend):
         else:
            o = torch.empty_like(q)
 
-        attn_logits, _, num_kv_splits, kv_indptr, kv_indices, _, _, _ = (
-            self.forward_metadata
-        )
-
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 layer, forward_batch.out_cache_loc, k, v
@@ -564,10 +575,11 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            kv_indptr,
-            kv_indices,
-            attn_logits,
-            num_kv_splits,
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.attn_logits,
+            self.forward_metadata.attn_lse,
+            self.forward_metadata.num_kv_splits,
             self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
@@ -700,15 +712,9 @@ class TritonMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
    ):
-        num_kv_heads = self.num_head
-        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
-            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
-                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
-
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
-                num_kv_heads,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index 4ae5cc205..b7dbdb16d 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -609,6 +609,7 @@ def decode_attention_fwd_normal(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
     max_kv_splits,
     sm_scale,
@@ -618,8 +619,8 @@ def decode_attention_fwd_normal(
         q,
         k_buffer,
         v_buffer,
-        attn_logits[0],
-        attn_logits[1],
+        attn_logits,
+        attn_lse,
         kv_indptr,
         kv_indices,
         num_kv_splits,
@@ -628,8 +629,8 @@ def decode_attention_fwd_normal(
         logit_cap,
     )
     _decode_softmax_reducev_fwd(
-        attn_logits[0],
-        attn_logits[1],
+        attn_logits,
+        attn_lse,
         q,
         o,
         v_buffer,
@@ -647,6 +648,7 @@ def decode_attention_fwd_grouped(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
     max_kv_splits,
     sm_scale,
@@ -656,8 +658,8 @@ def decode_attention_fwd_grouped(
         q,
         k_buffer,
         v_buffer,
-        attn_logits[0],
-        attn_logits[1],
+        attn_logits,
+        attn_lse,
         kv_indptr,
         kv_indices,
         num_kv_splits,
@@ -666,8 +668,8 @@ def decode_attention_fwd_grouped(
         logit_cap,
     )
     _decode_softmax_reducev_fwd(
-        attn_logits[0],
-        attn_logits[1],
+        attn_logits,
+        attn_lse,
         q,
         o,
         v_buffer,
@@ -685,14 +687,15 @@ def decode_attention_fwd(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
-    assert max_kv_splits == attn_logits[0].shape[2]
+    assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
-    assert q.shape[0] <= attn_logits[0].shape[0]
+    assert q.shape[0] <= attn_logits.shape[0]
 
     kv_group_num = q.shape[1] // v_buffer.shape[1]
 
@@ -706,6 +709,7 @@ def decode_attention_fwd(
             kv_indptr,
             kv_indices,
             attn_logits,
+            attn_lse,
             num_kv_splits,
             max_kv_splits,
             sm_scale,
@@ -721,6 +725,7 @@ def decode_attention_fwd(
             kv_indptr,
             kv_indices,
             attn_logits,
+            attn_lse,
             num_kv_splits,
             max_kv_splits,
             sm_scale,
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 336cf8b4c..95a4dd6af 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -26,7 +26,6 @@ import tqdm
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
-from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -196,9 +195,6 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // get_attention_tp_size()
-        )
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -507,15 +503,9 @@ class CudaGraphRunner:
         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
 
-        num_kv_heads = self.num_head
-        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
-            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
-                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
-
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
-            num_kv_heads,
             self.req_pool_indices,
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py
index af1ced319..2b90ce81b 100644
--- a/test/srt/test_triton_attention_kernels.py
+++ b/test/srt/test_triton_attention_kernels.py
@@ -265,7 +265,8 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            (attn_logits, attn_lse),
+            attn_logits,
+            attn_lse,
             num_kv_splits,
             max_kv_splits,
             sm_scale,
@@ -329,7 +330,8 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            (attn_logits, attn_lse),
+            attn_logits,
+            attn_lse,
             num_kv_splits,
             max_kv_splits,
             sm_scale,
@@ -353,7 +355,8 @@ class TestTritonAttention(unittest.TestCase):
             o_grouped,
             kv_indptr,
             kv_indices,
-            (attn_logits1, attn_lse1),
+            attn_logits1,
+            attn_lse1,
             num_kv_splits,
             max_kv_splits,
             sm_scale,