diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py
index 35fe4ed92..3fc79fe0d 100644
--- a/python/sglang/srt/layers/attention_backend.py
+++ b/python/sglang/srt/layers/attention_backend.py
@@ -36,14 +36,41 @@ class AttentionBackend(ABC):
     def init_forward_metadata(
         self, batch: ScheduleBatch, input_metadata: InputMetadata
     ):
-        pass
+        """Init the metadata for a forward pass."""
+        raise NotImplementedError()
 
-    def forward(self, q, k, v, layer, input_metadata: InputMetadata):
+    def init_cuda_graph_state(self, max_bs: int):
+        """Init the global shared states for cuda graph."""
+        raise NotImplementedError()
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        """Init the metadata for a forward pass for capturing a cuda graph."""
+        raise NotImplementedError()
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        """Init the metadata for a forward pass for replaying a cuda graph."""
+        raise NotImplementedError()
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        raise NotImplementedError()
+
+    def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        """Run forward on an attention layer."""
         if input_metadata.forward_mode.is_decode():
             return self.forward_decode(q, k, v, layer, input_metadata)
         else:
             return self.forward_extend(q, k, v, layer, input_metadata)
 
+    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        raise NotImplementedError()
+
+    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        raise NotImplementedError()
+
 
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
@@ -153,7 +180,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 self.cuda_graph_kv_indices.clone(),
             ]
 
-    def capture_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
         if self.model_runner.sliding_window_size is None:
             decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self.workspace_buffer,
@@ -194,7 +223,9 @@ class FlashInferAttnBackend(AttentionBackend):
 
         self.forward_metadata = (False, None, decode_wrapper)
 
-    def replay_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
         update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
@@ -204,6 +235,9 @@ class FlashInferAttnBackend(AttentionBackend):
             self.cuda_graph_metadata[bs],
         )
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 0
+
     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
         if not isinstance(self.prefill_wrapper_paged, list):
             prefill_wrapper_paged = self.prefill_wrapper_paged
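Note: `AttentionBackend` now spells out the full contract a backend must implement to participate in CUDA graph capture and replay. A minimal sketch of a subclass against this interface (the method names come from the hunks above; the no-op bodies are illustrative only, not part of the patch):

```python
import torch.nn as nn

from sglang.srt.layers.attention_backend import AttentionBackend


class NoopAttnBackend(AttentionBackend):
    """Hypothetical backend illustrating the hooks a subclass must provide."""

    def init_forward_metadata(self, batch, input_metadata):
        # Called once per forward pass, before any layer runs.
        self.forward_metadata = None

    def init_cuda_graph_state(self, max_bs: int):
        # Allocate fixed-size buffers once; a captured graph cannot reallocate.
        pass

    def init_forward_metadata_capture_cuda_graph(self, bs, req_pool_indices, seq_lens):
        # Point forward_metadata at the preallocated buffers before capture.
        pass

    def init_forward_metadata_replay_cuda_graph(self, bs, req_pool_indices, seq_lens):
        # Refresh the buffer contents in place before every replay.
        pass

    def get_cuda_graph_seq_len_fill_value(self):
        # Padding value for seq_lens when a batch is smaller than the graph.
        return 0

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata):
        raise NotImplementedError()

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata):
        raise NotImplementedError()
```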
@@ -290,6 +324,7 @@ class TritonAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.triton_attention.decode_attention import (
+            REDUCE_TORCH_TYPE,
             decode_attention_fwd,
         )
         from sglang.srt.layers.triton_attention.extend_attention import (
@@ -300,29 +335,78 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
+        self.REDUCE_TORCH_TYPE = REDUCE_TORCH_TYPE
+
+        self.num_head = model_runner.model_config.num_attention_heads
 
         self.forward_metadata = None
 
+        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+
     def init_forward_metadata(
         self, batch: ScheduleBatch, input_metadata: InputMetadata
     ):
         """Init auxiliary variables for triton attention backend."""
 
         if input_metadata.forward_mode.is_decode():
-            max_seq_len = torch.max(input_metadata.seq_lens).item()
             start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
             start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
 
             total_num_tokens = torch.sum(input_metadata.seq_lens).item()
+            attn_logits = torch.empty(
+                (self.num_head, total_num_tokens),
+                dtype=self.REDUCE_TORCH_TYPE,
+                device="cuda",
+            )
+
+            max_seq_len = torch.max(input_metadata.seq_lens).item()
             max_extend_len = None
         else:
-            start_loc = max_seq_len = total_num_tokens = None
+            start_loc = attn_logits = max_seq_len = None
             prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
             max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
 
-        self.forward_metadata = start_loc, max_seq_len, max_extend_len, total_num_tokens
+        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+
+        self.cuda_graph_start_loc = torch.zeros(
+            (max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.cuda_graph_attn_logits = torch.empty(
+            (self.num_head, self.cuda_graph_max_total_num_tokens),
+            dtype=self.REDUCE_TORCH_TYPE,
+            device="cuda",
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        self.forward_metadata = (
+            self.cuda_graph_start_loc,
+            self.cuda_graph_attn_logits,
+            self.cuda_graph_max_seq_len,
+            None,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        self.cuda_graph_start_loc.zero_()
+        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
+
+        self.forward_metadata = (
+            self.cuda_graph_start_loc,
+            self.cuda_graph_attn_logits,
+            self.cuda_graph_max_seq_len,
+            None,
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
 
     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
@@ -332,8 +416,7 @@ class TritonAttnBackend(AttentionBackend):
             layer.layer_id, input_metadata.out_cache_loc, k, v
         )
 
-        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
-
+        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -350,16 +433,16 @@ class TritonAttnBackend(AttentionBackend):
             layer.scaling,
             layer.logit_cap,
         )
-
         return o
 
     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
             o = torch.empty_like(q)
 
-        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
+        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
 
         input_metadata.token_to_kv_pool.set_kv_buffer(
             layer.layer_id, input_metadata.out_cache_loc, k, v
@@ -374,10 +457,9 @@ class TritonAttnBackend(AttentionBackend):
             input_metadata.req_pool_indices,
             start_loc,
             input_metadata.seq_lens,
+            attn_logits,
             max_seq_len,
-            total_num_tokens,
             layer.scaling,
             layer.logit_cap,
         )
-
         return o
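Note: the key enabler for graph capture here is that the stage-one score buffer (`attn_logits`) is no longer allocated per call inside the decode kernel wrapper. Eager mode allocates it per batch in `init_forward_metadata`, while `init_cuda_graph_state` sizes one shared buffer for the worst case, `max_bs * context_len`. A rough sizing sketch with assumed numbers (the dtype behind `REDUCE_TORCH_TYPE` is assumed to be float32 for this estimate):

```python
import torch

# Illustrative values, not taken from the diff.
max_bs = 8          # largest captured batch size
context_len = 4096  # model_runner.model_config.context_len
num_head = 32       # attention heads on this tensor-parallel rank

# Worst case: every slot in the largest captured batch sits at full context.
max_total_num_tokens = max_bs * context_len

attn_logits = torch.empty(
    (num_head, max_total_num_tokens), dtype=torch.float32, device="cuda"
)
print(f"{attn_logits.numel() * attn_logits.element_size() / 2**20:.0f} MiB")  # ~4 MiB
```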
diff --git a/python/sglang/srt/layers/flashinfer_utils.py b/python/sglang/srt/layers/flashinfer_utils.py
index c473d6e45..291091b10 100644
--- a/python/sglang/srt/layers/flashinfer_utils.py
+++ b/python/sglang/srt/layers/flashinfer_utils.py
@@ -66,18 +66,18 @@ class FlashinferUpdater:
         self.head_dim = model_runner.model_config.head_dim
         self.batch_size = len(req_pool_indices)
 
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
+        self.decode_wrapper = (
+            decode_wrapper or self.model_runner.attn_backend.decode_wrapper
+        )
+        self.prefill_wrapper_ragged = (
+            self.model_runner.attn_backend.prefill_wrapper_ragged
+        )
+        self.prefill_wrapper_paged = (
+            self.model_runner.attn_backend.prefill_wrapper_paged
         )
 
-        (
-            self.decode_wrapper,
-            self.prefill_wrapper_ragged,
-            self.prefill_wrapper_paged,
-        ) = (
-            decode_wrapper or self.model_runner.attn_backend.decode_wrapper,
-            self.model_runner.attn_backend.prefill_wrapper_ragged,
-            self.model_runner.attn_backend.prefill_wrapper_paged,
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
         )
 
     def _init_indices_no_sliding_window(self):
diff --git a/python/sglang/srt/layers/triton_attention/decode_attention.py b/python/sglang/srt/layers/triton_attention/decode_attention.py
index adfa0d936..5d8eb9ae4 100644
--- a/python/sglang/srt/layers/triton_attention/decode_attention.py
+++ b/python/sglang/srt/layers/triton_attention/decode_attention.py
@@ -114,7 +114,7 @@ def _fwd_kernel_stage1(
 
 @triton.jit
 def _fwd_kernel_stage2(
-    Logics,
+    logits,
     V_Buffer,
     Out,
     Req_to_tokens,
@@ -162,7 +162,7 @@ def _fwd_kernel_stage2(
         )
 
         qk = tl.load(
-            Logics
+            logits
             + cur_head * stride_logic_h
            + (cur_batch_start_loc + start_n + offs_n),
             mask=start_n + offs_n < cur_batch_seq_len,
@@ -238,7 +238,7 @@ def _decode_att_m_fwd(
 
 
 def _decode_softmax_reducev_fwd(
-    logics,
+    logits,
     v_buffer,
     o,
     req_to_tokens,
@@ -247,9 +247,9 @@ def _decode_softmax_reducev_fwd(
     b_seq_len,
 ):
     BLOCK = 64
-    batch, head = b_seq_len.shape[0], logics.shape[0]
+    batch, head = b_seq_len.shape[0], logits.shape[0]
     grid = (batch, head, 1)
-    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    kv_group_num = logits.shape[0] // v_buffer.shape[1]
 
     num_warps = 1
 
@@ -257,14 +257,14 @@ def _decode_softmax_reducev_fwd(
     BLOCK_DMODEL = triton.next_power_of_2(Lv)
 
     _fwd_kernel_stage2[grid](
-        logics,
+        logits,
         v_buffer,
         o,
         req_to_tokens,
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        logics.stride(0),
+        logits.stride(0),
         v_buffer.stride(0),
         v_buffer.stride(1),
         o.stride(0),
@@ -387,7 +387,7 @@ def _fwd_grouped_kernel_stage1(
 
 @triton.jit
 def _fwd_grouped_kernel_stage2(
-    Logics,
+    logits,
     V_Buffer,
     Out,
     Req_to_tokens,
@@ -443,7 +443,7 @@ def _fwd_grouped_kernel_stage2(
         )
 
         qk = tl.load(
-            Logics + offs_qk,
+            logits + offs_qk,
             mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
             other=float("-inf"),
         )
@@ -531,7 +531,7 @@ def _decode_grouped_att_m_fwd(
 
 
 def _decode_grouped_softmax_reducev_fwd(
-    logics,
+    logits,
     v_buffer,
     o,
     req_to_tokens,
@@ -540,8 +540,8 @@ def _decode_grouped_softmax_reducev_fwd(
     b_seq_len,
 ):
     BLOCK = 128
-    batch, head_num = b_seq_len.shape[0], logics.shape[0]
-    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    batch, head_num = b_seq_len.shape[0], logits.shape[0]
+    kv_group_num = logits.shape[0] // v_buffer.shape[1]
     BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
     grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
 
@@ -551,14 +551,14 @@ def _decode_grouped_softmax_reducev_fwd(
     BLOCK_DMODEL = triton.next_power_of_2(Lv)
 
     _fwd_grouped_kernel_stage2[grid](
-        logics,
+        logits,
         v_buffer,
         o,
         req_to_tokens,
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        logics.stride(0),
+        logits.stride(0),
         v_buffer.stride(0),
         v_buffer.stride(1),
         o.stride(0),
@@ -584,17 +584,11 @@ def decode_attention_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
+    attn_logits,
     max_len_in_batch,
-    total_num_tokens,
     sm_scale,
     logit_cap=0.0,
-    att_m=None,
 ):
-    if att_m is None:
-        att_m = torch.empty(
-            (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
-        )
-
     kv_group_num = q.shape[1] // v_buffer.shape[1]
 
     if kv_group_num == 1:
@@ -602,7 +596,7 @@ def decode_attention_fwd(
         _decode_att_m_fwd(
             q,
             k_buffer,
-            att_m,
+            attn_logits,
             req_to_token,
             b_req_idx,
             b_start_loc,
@@ -612,7 +606,7 @@ def decode_attention_fwd(
             logit_cap,
         )
         _decode_softmax_reducev_fwd(
-            att_m,
+            attn_logits,
             v_buffer,
             o,
             req_to_token,
@@ -625,7 +619,7 @@ def decode_attention_fwd(
         _decode_grouped_att_m_fwd(
             q,
             k_buffer,
-            att_m,
+            attn_logits,
             req_to_token,
             b_req_idx,
             b_start_loc,
@@ -635,7 +629,7 @@ def decode_attention_fwd(
             logit_cap,
         )
         _decode_grouped_softmax_reducev_fwd(
-            att_m,
+            attn_logits,
             v_buffer,
             o,
             req_to_token,
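Note: with the `att_m=None` fallback gone, callers of `decode_attention_fwd` must now hand in the `attn_logits` scratch buffer themselves, which is what makes the two-stage decode kernel graph-capturable. A sketch of the eager calling convention, mirroring `init_forward_metadata` above (shapes, head count, and dtype are assumptions for illustration):

```python
import torch

# Per-batch metadata, as built in TritonAttnBackend.init_forward_metadata.
seq_lens = torch.tensor([5, 9, 3], dtype=torch.int32, device="cuda")

start_loc = torch.zeros_like(seq_lens)
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)  # [0, 5, 14]
total_num_tokens = int(seq_lens.sum().item())       # 17
max_seq_len = int(seq_lens.max().item())            # 9

num_head = 32  # illustrative
attn_logits = torch.empty(
    (num_head, total_num_tokens), dtype=torch.float32, device="cuda"
)

# New signature: the caller-owned buffer replaces total_num_tokens/att_m.
# decode_attention_fwd(q, k_buffer, v_buffer, o, req_to_token, b_req_idx,
#                      start_loc, seq_lens, attn_logits, max_seq_len,
#                      sm_scale, logit_cap)
```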
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index ecaeb404c..27699f65d 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +19,12 @@ limitations under the License.
 
 import bisect
 from contextlib import contextmanager
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
 
 import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
@@ -35,6 +36,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
@@ -111,7 +115,7 @@ class CudaGraphRunner:
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
         self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
@@ -121,6 +125,9 @@ class CudaGraphRunner:
 
         # Attention backend
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        self.seq_len_fill_value = (
+            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
+        )
 
         # Sampling info
         vocab_size = model_runner.model_config.vocab_size
@@ -176,7 +183,7 @@ class CudaGraphRunner:
         out_cache_loc = self.out_cache_loc[:bs]
 
         # Attention backend
-        self.model_runner.attn_backend.capture_cuda_graph_init(
+        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs, req_pool_indices, seq_lens
         )
 
@@ -227,7 +234,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.zero_()
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
@@ -239,7 +246,7 @@ class CudaGraphRunner:
             self.out_cache_loc[:raw_bs] = batch.out_cache_loc
 
         # Attention backend
-        self.model_runner.attn_backend.replay_cuda_graph_init(
+        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
            bs, self.req_pool_indices, self.seq_lens
         )
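Note: `seq_lens` is now initialized to ones and padded with a backend-specific fill value: the triton kernels touch the KV cache for every slot in the captured batch, so padded slots must look like valid length-1 sequences, while flashinfer tolerates zero. A sketch of the padding arithmetic during replay (the batch values are illustrative):

```python
import bisect

import torch

capture_bs = [1, 2, 4, 8]  # batch sizes with captured graphs
raw_bs, fill_value = 3, 1  # triton backend's fill value, per the diff

# Replay with the smallest captured graph that fits the real batch.
bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]  # -> 4

seq_lens = torch.full((bs,), fill_value, dtype=torch.int32, device="cuda")
seq_lens[:raw_bs] = torch.tensor([5, 9, 3], dtype=torch.int32, device="cuda")

# start_loc is rebuilt in place on every replay, as in
# init_forward_metadata_replay_cuda_graph:
start_loc = torch.zeros((bs,), dtype=torch.int32, device="cuda")
start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)  # [0, 5, 14, 17]
# The padded fourth slot decodes a dummy length-1 sequence instead of a
# zero-length one, which would give the stage-one kernel nothing to read.
```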
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 80c741652..5868b0074 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -445,12 +445,6 @@ class ModelRunner:
         if self.server_args.disable_cuda_graph:
             return
 
-        if self.server_args.attention_backend != "flashinfer":
-            logger.warning(
-                f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}"
-            )
-            return
-
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
 
diff --git a/test/srt/test_serving_throughput.py b/test/srt/test_serving_throughput.py
index 81aff3ed2..16da1d963 100644
--- a/test/srt/test_serving_throughput.py
+++ b/test/srt/test_serving_throughput.py
@@ -96,6 +96,16 @@ class TestServingThroughput(unittest.TestCase):
         if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
             assert res["output_throughput"] > 2400
 
+    def test_default_with_triton_attention_backend(self):
+        res = self.run_test(
+            disable_radix_cache=ServerArgs.disable_radix_cache,
+            attention_backend="triton",
+            chunked_prefill_size=-1,
+        )
+
+        if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
+            assert res["output_throughput"] > 2400
+
 
 if __name__ == "__main__":
     unittest.main()
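Note: with the flashinfer-only guard removed from `ModelRunner`, the triton backend now goes through `CudaGraphRunner` as well, and the new test covers it end to end. A sketch for running just that case locally (the module path is assumed from the file location above and requires the repo root on `sys.path`):

```python
import unittest

from test.srt.test_serving_throughput import TestServingThroughput

# Run only the new triton-backend throughput case.
suite = unittest.TestSuite()
suite.addTest(TestServingThroughput("test_default_with_triton_attention_backend"))
unittest.TextTestRunner(verbosity=2).run(suite)
```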