From aa46ed34d25730d532fe15068c02ddbe7c83f730 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 13 Jun 2025 16:58:29 +0800 Subject: [PATCH] Remove 200us slow concat kernel (part 1: kernel) (#7145) --- sgl-kernel/benchmark/bench_cutlass_mla.py | 21 ++++++-- .../csrc/attention/cutlass_mla_kernel.cu | 49 +++++++++++-------- sgl-kernel/csrc/common_extension.cc | 2 +- sgl-kernel/include/sgl_kernel_ops.h | 3 +- sgl-kernel/python/sgl_kernel/attention.py | 46 +++++++++-------- sgl-kernel/tests/test_cutlass_mla.py | 6 ++- 6 files changed, 79 insertions(+), 48 deletions(-) diff --git a/sgl-kernel/benchmark/bench_cutlass_mla.py b/sgl-kernel/benchmark/bench_cutlass_mla.py index 9ac97e20d..a5a8fe0c4 100644 --- a/sgl-kernel/benchmark/bench_cutlass_mla.py +++ b/sgl-kernel/benchmark/bench_cutlass_mla.py @@ -38,6 +38,7 @@ configs = list(itertools.product(bs_range, qlen_range)) ) def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits): d = 576 + dn = 64 dv = 512 h_q_map = { @@ -63,7 +64,11 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits): pack_factor = 128 // block_size block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor - q = torch.randn(batch_size, h_q, d, dtype=torch.bfloat16, device="cuda") * 100.0 + qn = ( + torch.randn(h_q, batch_size, d - dn, dtype=torch.bfloat16, device="cuda") + * 100.0 + ) + qr = torch.randn(batch_size, h_q, dn, dtype=torch.bfloat16, device="cuda") * 100.0 block_table = torch.randint( 0, batch_size * block_num, @@ -84,16 +89,22 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits): quantiles = [0.5, 0.2, 0.8] ms, min_ms, max_ms = triton.testing.do_bench( lambda: cutlass_mla_decode( - q, kv_cache, seq_lens, block_table, workspace, num_kv_splits + qn.transpose(0, 1), + qr, + kv_cache, + seq_lens, + block_table, + workspace, + num_kv_splits, ), quantiles=quantiles, ) + q_size = qn.numel() * qn.element_size() + qr.numel() * qr.element_size() + gbps = ( lambda ms: ( - q.numel() * q.element_size() - + q.numel() * q.element_size() * dv / d - + kv_cache.numel() * kv_cache.element_size() + q_size + q_size * dv / d + kv_cache.numel() * kv_cache.element_size() ) * 1e-9 / (ms * 1e-3) diff --git a/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu index 55f604257..9ba335e23 100644 --- a/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu +++ b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include #include "cutlass_sm100_mla/device/sm100_mla.hpp" #include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp" @@ -30,7 +31,8 @@ limitations under the License. #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 void cutlass_mla_decode( torch::Tensor const& out, - torch::Tensor const& q_nope_and_q_pe, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, torch::Tensor const& page_table, @@ -91,16 +93,17 @@ struct MlaSm100 { template typename T::Fmha::Arguments args_from_options( at::Tensor const& out, - at::Tensor const& q_nope_and_q_pe, + at::Tensor const& q_nope, + at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, at::Tensor const& page_table, int64_t num_kv_splits) { cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = q_nope_and_q_pe.device().index(); + hw_info.device_id = q_nope.device().index(); hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - int batches = q_nope_and_q_pe.sizes()[0]; + int batches = q_nope.sizes()[0]; int page_count_per_seq = page_table.sizes()[1]; int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; int page_size = kv_c_and_k_pe_cache.sizes()[1]; @@ -122,8 +125,11 @@ typename T::Fmha::Arguments args_from_options( using StrideO = typename T::StrideO; using StrideLSE = typename T::StrideLSE; - StrideQ stride_Q = cute::make_tuple( - static_cast(0 + D_latent + D_rope), _1{}, static_cast(H * (0 + D_latent + D_rope))); + StrideQ stride_Q_nope = cute::make_tuple( + static_cast(q_nope.stride(1)), _1{}, static_cast(q_nope.stride(0))); + StrideQ stride_Q_pe = cute::make_tuple( + static_cast(q_pe.stride(1)), _1{}, static_cast(q_pe.stride(0))); + StrideK stride_C = cute::make_tuple( static_cast(0 + D_latent + D_rope), _1{}, static_cast(page_size * (D_latent + D_rope))); StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq); @@ -133,15 +139,16 @@ typename T::Fmha::Arguments args_from_options( using Element = typename T::Element; using ElementOut = typename T::ElementOut; using ElementAcc = typename T::ElementAcc; - auto Q_ptr = static_cast(q_nope_and_q_pe.data_ptr()); + auto Q_nope_ptr = static_cast(q_nope.data_ptr()); + auto Q_pe_ptr = static_cast(q_pe.data_ptr()); auto C_ptr = static_cast(kv_c_and_k_pe_cache.data_ptr()); typename T::Fmha::Arguments arguments{ problem_shape, {scale, - Q_ptr, - stride_Q, - Q_ptr + D_latent, - stride_Q, + Q_nope_ptr, + stride_Q_nope, + Q_pe_ptr, + stride_Q_pe, C_ptr, stride_C, C_ptr + D_latent, @@ -170,7 +177,8 @@ typename T::Fmha::Arguments args_from_options( template void runMla( at::Tensor const& out, - at::Tensor const& q_nope_and_q_pe, + at::Tensor const& q_nope, + at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, at::Tensor const& page_table, @@ -179,7 +187,7 @@ void runMla( cudaStream_t stream) { using MlaSm100Type = MlaSm100; typename MlaSm100Type::Fmha fmha; - auto arguments = args_from_options(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits); + auto arguments = args_from_options(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits); CUTLASS_CHECK(fmha.can_implement(arguments)); @@ -201,15 +209,16 @@ void runMla( void cutlass_mla_decode( torch::Tensor const& out, - torch::Tensor const& q_nope_and_q_pe, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, torch::Tensor const& page_table, torch::Tensor const& workspace, int64_t num_kv_splits) { - auto in_dtype = q_nope_and_q_pe.dtype(); - at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()}; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device()); + auto in_dtype = q_nope.dtype(); + at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device()); const int page_size = kv_c_and_k_pe_cache.sizes()[1]; // NOTE(alcanderian): IsPersistent has bug with manual split_kv. @@ -219,13 +228,13 @@ void cutlass_mla_decode( DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] { if (in_dtype == at::ScalarType::Half) { runMla>( - out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream); + out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::BFloat16) { runMla>( - out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream); + out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { runMla>( - out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream); + out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream); } else { TORCH_CHECK(false, "Unsupported input data type of MLA"); } diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 1886d668e..d8e9fb336 100755 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -59,7 +59,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()"); m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2); m.def( - "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor " + "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor " "page_table, Tensor! workspace, int num_kv_splits) -> ()"); m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size); diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index a20c26724..1fdfbeae1 100755 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -105,7 +105,8 @@ void merge_state_v2( at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged); void cutlass_mla_decode( torch::Tensor const& out, - torch::Tensor const& q_nope_and_q_pe, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, torch::Tensor const& page_table, diff --git a/sgl-kernel/python/sgl_kernel/attention.py b/sgl-kernel/python/sgl_kernel/attention.py index 2ece6abdd..bb3c2af1c 100644 --- a/sgl-kernel/python/sgl_kernel/attention.py +++ b/sgl-kernel/python/sgl_kernel/attention.py @@ -52,34 +52,42 @@ def merge_state_v2( def cutlass_mla_decode( - q_nope_and_q_pe: torch.Tensor, + q_nope: torch.Tensor, + q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, seq_lens: torch.Tensor, page_table: torch.Tensor, workspace: torch.Tensor, num_kv_splits: int = -1, ) -> torch.Tensor: - assert ( - q_nope_and_q_pe.ndim == 3 - ), f"q_nope_and_q_pe must be a 3D tensor, but got {q_nope_and_q_pe.ndim}" + assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}" + assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}" assert ( kv_c_and_k_pe_cache.ndim == 3 ), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}" - B_q, H, D_q = q_nope_and_q_pe.shape + + B_q, H, D_q_nope = q_nope.shape + B_q_2, H_2, D_q_pe = q_pe.shape + assert (B_q == B_q_2) and (H == H_2) + _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape D_latent = 512 D_rope = 64 - assert D_q == D_ckv and D_q == D_latent + D_rope, ( - f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, " - f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}" - ) + assert D_q_nope == D_latent + assert D_q_pe == D_rope + assert D_ckv == D_latent + D_rope + MAX_HEADS = 128 assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}" if H < MAX_HEADS: - q_nope_and_q_pe_padded = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_q)) - q_nope_and_q_pe_padded[:, :H] = q_nope_and_q_pe - q_nope_and_q_pe = q_nope_and_q_pe_padded + q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope)) + q_nope_padded[:, :H] = q_nope + q_nope = q_nope_padded + + q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe)) + q_pe_padded[:, :H] = q_pe + q_pe = q_pe_padded assert len(page_table.shape) == 2 B_block_table, block_num = page_table.shape @@ -88,14 +96,11 @@ def cutlass_mla_decode( assert block_num % (128 / PAGE_SIZE) == 0 # TODO(kaixih@nvidia): support fp8 - assert q_nope_and_q_pe.dtype in ( + assert q_nope.dtype in ( torch.float16, torch.bfloat16, - ), f"q_nope_and_q_pe.dtype needs to be fp16 or bf16 but got {q_nope_and_q_pe.dtype}." - assert kv_c_and_k_pe_cache.dtype == q_nope_and_q_pe.dtype, ( - f"kv_c_and_k_pe_cache.dtype needs to be the same as q_nope_and_q_pe.dtype, " - f"but got {kv_c_and_k_pe_cache.dtype}." - ) + ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}." + assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype assert ( seq_lens.dtype == torch.int32 ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}." @@ -103,11 +108,12 @@ def cutlass_mla_decode( page_table.dtype == torch.int32 ), f"page_table.dtype needs to be int32 but got {page_table.dtype}." - out = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_latent)) + out = q_nope.new_empty((B_q, MAX_HEADS, D_latent)) torch.ops.sgl_kernel.cutlass_mla_decode.default( out, - q_nope_and_q_pe, + q_nope, + q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, diff --git a/sgl-kernel/tests/test_cutlass_mla.py b/sgl-kernel/tests/test_cutlass_mla.py index 22f850af7..37eadbf02 100644 --- a/sgl-kernel/tests/test_cutlass_mla.py +++ b/sgl-kernel/tests/test_cutlass_mla.py @@ -86,10 +86,14 @@ def test_cutlass_mla_decode( ) workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8) + q_nope = torch.empty((h_q, bs, dv)).transpose(0, 1) + q_nope.copy_(q[:, :, :dv]) + q_pe = q[:, :, dv:].clone() + out_ref = q.new_zeros(bs, h_q, dv) ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) out = cutlass_mla_decode( - q, kv_cache, seq_lens, block_table, workspace, num_kv_splits + q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits ) torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)