diff --git a/python/sglang/srt/layers/attention/cutlass_mla_backend.py b/python/sglang/srt/layers/attention/cutlass_mla_backend.py
index a49ed1ab5..416eff724 100644
--- a/python/sglang/srt/layers/attention/cutlass_mla_backend.py
+++ b/python/sglang/srt/layers/attention/cutlass_mla_backend.py
@@ -108,8 +108,8 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
                     PAGE_SIZE,
                 )
                 workspace_size = cutlass_mla_get_workspace_size(
-                    max_seqlen_pad * PAGE_SIZE, bs
+                    max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
                 )
                 workspace = torch.empty(
                     workspace_size, device="cuda", dtype=torch.uint8
                 )
@@ -138,7 +138,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         cuda_graph_kv_indices = block_kv_indices
         workspace_size = cutlass_mla_get_workspace_size(
-            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
+            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs, num_kv_splits=1
         )
         self.cuda_graph_mla_workspace = torch.empty(
             workspace_size, device="cuda", dtype=torch.uint8
         )
@@ -280,6 +280,8 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             seq_lens=forward_batch.seq_lens.to(torch.int32),
             page_table=self.forward_metadata.block_kv_indices,
             workspace=self.forward_metadata.workspace,
+            sm_scale=layer.scaling,
+            num_kv_splits=1,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
diff --git a/sgl-kernel/benchmark/bench_cutlass_mla.py b/sgl-kernel/benchmark/bench_cutlass_mla.py
index a5a8fe0c4..785e51033 100644
--- a/sgl-kernel/benchmark/bench_cutlass_mla.py
+++ b/sgl-kernel/benchmark/bench_cutlass_mla.py
@@ -95,6 +95,7 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
                 seq_lens,
                 block_table,
                 workspace,
+                1.44,
                 num_kv_splits,
             ),
             quantiles=quantiles,
diff --git a/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
index 9ba335e23..7c060274b 100644
--- a/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
+++ b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
@@ -36,7 +36,9 @@ void cutlass_mla_decode(
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
-    torch::Tensor const& workspace) {
+    torch::Tensor const& workspace,
+    double sm_scale,
+    int64_t num_kv_splits) {
   TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
 }
 int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
@@ -98,6 +99,7 @@ typename T::Fmha::Arguments args_from_options(
     at::Tensor const& kv_c_and_k_pe_cache,
     at::Tensor const& seq_lens,
     at::Tensor const& page_table,
+    double sm_scale,
     int64_t num_kv_splits) {
   cutlass::KernelHardwareInfo hw_info;
   hw_info.device_id = q_nope.device().index();
@@ -115,10 +117,7 @@ typename T::Fmha::Arguments args_from_options(
 
   auto [H, K, D, B] = problem_shape;
   auto [D_latent, D_rope] = D;
-  // the scale is based on the non-absorbed sizes, change as appropriate
-  // we can't determine this parameter from the info we have, it's an input
-  int D_non_latent = 128;
-  float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope));
+  float scale = float(sm_scale);
 
   using StrideQ = typename T::StrideQ;
   using StrideK = typename T::StrideK;
@@ -183,11 +182,12 @@ void runMla(
     at::Tensor const& seq_lens,
     at::Tensor const& page_table,
     at::Tensor const& workspace,
+    double sm_scale,
     int64_t num_kv_splits,
     cudaStream_t stream) {
   using MlaSm100Type = MlaSm100<T, PersistenceOption>;
   typename MlaSm100Type::Fmha fmha;
-  auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
+  auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
 
   CUTLASS_CHECK(fmha.can_implement(arguments));
 
@@ -215,6 +215,7 @@ void cutlass_mla_decode(
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
     torch::Tensor const& workspace,
+    double sm_scale,
     int64_t num_kv_splits) {
   auto in_dtype = q_nope.dtype();
   at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
@@ -228,13 +229,13 @@ void cutlass_mla_decode(
   DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
     if (in_dtype == at::ScalarType::Half) {
       runMla<cutlass::half_t, IsPersistent<NotManualSplitKV>>(
-          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
     } else if (in_dtype == at::ScalarType::BFloat16) {
       runMla<cutlass::bfloat16_t, IsPersistent<NotManualSplitKV>>(
-          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
     } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
       runMla<cutlass::float_e4m3_t, IsPersistent<NotManualSplitKV>>(
-          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
     } else {
       TORCH_CHECK(false, "Unsupported input data type of MLA");
     }
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index 9f3c2be9c..68424f07c 100755
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -60,7 +60,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
   m.def(
       "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
-      "page_table, Tensor! workspace, int num_kv_splits) -> ()");
+      "page_table, Tensor! workspace, float sm_scale, int num_kv_splits) -> ()");
   m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
 
   m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
old mode 100755
new mode 100644
index 1cc88afa0..bb267735b
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -111,9 +111,13 @@ void cutlass_mla_decode(
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
     torch::Tensor const& workspace,
-    int64_t num_kv_splits = -1);
+    double sm_scale,
+    int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
 int64_t cutlass_mla_get_workspace_size(
-    int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0, int64_t num_kv_splits = -1);
+    int64_t max_seq_len,
+    int64_t num_batches,
+    int64_t sm_count = 0,
+    int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
 /*
  * From csrc/elementwise
  */
diff --git a/sgl-kernel/python/sgl_kernel/attention.py b/sgl-kernel/python/sgl_kernel/attention.py
index bb3c2af1c..f15b4fa24 100644
--- a/sgl-kernel/python/sgl_kernel/attention.py
+++ b/sgl-kernel/python/sgl_kernel/attention.py
@@ -58,7 +58,8 @@ def cutlass_mla_decode(
     seq_lens: torch.Tensor,
     page_table: torch.Tensor,
     workspace: torch.Tensor,
-    num_kv_splits: int = -1,
+    sm_scale: float,
+    num_kv_splits: int = 1,  # Set to 1 to avoid cuda_graph issue by default.
 ) -> torch.Tensor:
     assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
     assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
@@ -118,13 +119,17 @@ def cutlass_mla_decode(
         seq_lens,
         page_table,
         workspace,
+        sm_scale,
         num_kv_splits,
     )
     return out[:, :H].contiguous()
 
 
 def cutlass_mla_get_workspace_size(
-    max_seq_len: int, num_batches: int, sm_count: int = 0, num_kv_splits: int = -1
+    max_seq_len: int,
+    num_batches: int,
+    sm_count: int = 0,
+    num_kv_splits: int = 1,  # Set to 1 to avoid cuda_graph issue by default.
 ) -> int:
     assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
     assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"
diff --git a/sgl-kernel/tests/test_cutlass_mla.py b/sgl-kernel/tests/test_cutlass_mla.py
index 37eadbf02..0f1829b5d 100644
--- a/sgl-kernel/tests/test_cutlass_mla.py
+++ b/sgl-kernel/tests/test_cutlass_mla.py
@@ -93,7 +93,7 @@ def test_cutlass_mla_decode(
     out_ref = q.new_zeros(bs, h_q, dv)
     ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
     out = cutlass_mla_decode(
-        q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits
+        q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, scale, num_kv_splits
     )
 
     torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
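For context, a minimal sketch of calling the updated Python API after this patch. The shapes, the toy one-page-per-request layout, and the `from sgl_kernel import ...` path are illustrative assumptions rather than part of the diff; the `sm_scale` value reproduces the `1/sqrt(128 + 64)` that `args_from_options` previously hard-coded.

```python
import math

import torch
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

# Hypothetical MLA decode sizes: 128 query heads, 512-dim latent,
# 64-dim rope, one 128-token page per request.
bs, h_q, d_latent, d_rope, page_size = 4, 128, 512, 64, 128
max_seq_len = page_size  # one page per sequence in this toy setup

q_nope = torch.randn(bs, h_q, d_latent, dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn(bs, h_q, d_rope, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.randn(
    bs, page_size, d_latent + d_rope, dtype=torch.bfloat16, device="cuda"
)
seq_lens = torch.full((bs,), page_size, dtype=torch.int32, device="cuda")
page_table = torch.arange(bs, dtype=torch.int32, device="cuda").view(bs, 1)

# num_kv_splits now defaults to 1 (the CUDA-graph-safe setting),
# so it can be omitted in both calls.
workspace_size = cutlass_mla_get_workspace_size(max_seq_len, bs)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

# sm_scale is now an explicit argument instead of being derived
# inside the kernel from the non-absorbed head sizes (128 + 64).
sm_scale = 1.0 / math.sqrt(128 + 64)
out = cutlass_mla_decode(
    q_nope, q_pe, kv_cache, seq_lens, page_table, workspace, sm_scale
)
print(out.shape)  # (bs, h_q, d_latent)
```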