From 3b25dc127ae7df0680cfe0fd7800ba5aa280c64c Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Tue, 16 Sep 2025 02:53:21 +0800
Subject: [PATCH] [1/2] Speed up trtllm_mla attention backend (>10% e2e) (#10473)

---
 sgl-kernel/csrc/common_extension.cc         |   3 +
 sgl-kernel/csrc/elementwise/concat_mla.cu   | 135 ++++++++++++++++++++
 sgl-kernel/include/sgl_kernel_ops.h         |   1 +
 sgl-kernel/python/sgl_kernel/__init__.py    |   1 +
 sgl-kernel/python/sgl_kernel/elementwise.py |  20 +++
 test/srt/models/test_generation_models.py   |   3 -
 6 files changed, 160 insertions(+), 3 deletions(-)

diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index ad67248a9..9c89bcad7 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -104,6 +104,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
   m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);
 
+  m.def("concat_mla_absorb_q(Tensor a, Tensor b, Tensor! out) -> ()");
+  m.impl("concat_mla_absorb_q", torch::kCUDA, &concat_mla_absorb_q);
+
   /*
    * From csrc/gemm
    */
diff --git a/sgl-kernel/csrc/elementwise/concat_mla.cu b/sgl-kernel/csrc/elementwise/concat_mla.cu
index b6c236333..13ff16e22 100644
--- a/sgl-kernel/csrc/elementwise/concat_mla.cu
+++ b/sgl-kernel/csrc/elementwise/concat_mla.cu
@@ -115,3 +115,138 @@ void concat_mla_k(at::Tensor k, at::Tensor k_nope, at::Tensor k_rope) {
   cudaError_t err = cudaGetLastError();
   TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: ", cudaGetErrorString(err));
 }
+
+// ============================== concat_mla_absorb_q ==============================
+
+// TODO give a name prefix, also maybe refactor code above
+// Last-dimension sizes (in bf16 elements) of the two inputs; an output row holds both back to back.
+constexpr int A_LAST_DIM = 512;
+constexpr int B_LAST_DIM = 64;
+
+// Copies a[i, j, :] into out[i, j, 0:A_LAST_DIM] and b[i, j, :] into
+// out[i, j, A_LAST_DIM:A_LAST_DIM + B_LAST_DIM]. One warp handles one (i, j) row: each lane moves
+// two int4 chunks of `a` and one int chunk of `b`, so every row must be aligned for those access
+// widths (validated by the host wrapper). Strides are in elements and kept in 64 bits so that the
+// row offsets cannot overflow for large tensors.
+__global__ void concat_mla_absorb_q_kernel(
+    const nv_bfloat16* a,
+    const nv_bfloat16* b,
+    nv_bfloat16* out,
+    const int num_items,
+    const int dim_1,
+    const int64_t a_stride_0,
+    const int64_t a_stride_1,
+    const int64_t b_stride_0,
+    const int64_t b_stride_1,
+    const int64_t out_stride_0,
+    const int64_t out_stride_1) {
+  const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
+  const int lane_id = get_lane_id();
+
+  // The last block may contain warps that lie past the end of the work list.
+  if (flat_warp_id >= num_items) {
+    return;
+  }
+
+  const int64_t idx_0 = flat_warp_id / dim_1;
+  const int64_t idx_1 = flat_warp_id % dim_1;
+
+  using ABufType = int4;
+  constexpr int A_NUM_UNROLL = 2;
+  static_assert(sizeof(ABufType) * A_NUM_UNROLL == A_LAST_DIM * sizeof(a[0]) / 32);
+  ABufType a_buf[A_NUM_UNROLL];
+
+  using BBufType = int;
+  constexpr int B_NUM_UNROLL = 1;
+  static_assert(sizeof(BBufType) * B_NUM_UNROLL == B_LAST_DIM * sizeof(b[0]) / 32);
+  BBufType b_buf;
+
+  const nv_bfloat16* a_row = a + idx_0 * a_stride_0 + idx_1 * a_stride_1;
+  const nv_bfloat16* b_row = b + idx_0 * b_stride_0 + idx_1 * b_stride_1;
+  nv_bfloat16* out_row = out + idx_0 * out_stride_0 + idx_1 * out_stride_1;
+
+  // All loads are issued before any store.
+  {
+    const BBufType* base_addr = reinterpret_cast<const BBufType*>(b_row);
+    b_buf = *(base_addr + lane_id);
+  }
+
+#pragma unroll
+  for (int i = 0; i < A_NUM_UNROLL; ++i) {
+    const ABufType* base_addr = reinterpret_cast<const ABufType*>(a_row);
+    a_buf[i] = *(base_addr + i * 32 + lane_id);
+  }
+
+  {
+    BBufType* base_addr = reinterpret_cast<BBufType*>(out_row + A_LAST_DIM);
+    *(base_addr + lane_id) = b_buf;
+  }
+
+#pragma unroll
+  for (int i = 0; i < A_NUM_UNROLL; ++i) {
+    ABufType* base_addr = reinterpret_cast<ABufType*>(out_row);
+    *(base_addr + i * 32 + lane_id) = a_buf[i];
+  }
+}
+
+// Validates one operand of concat_mla_absorb_q: a 3-D CUDA bf16 tensor whose last dimension has
+// `shape2` contiguous elements and whose base pointer is 16-byte aligned. `row_align_bytes` is the
+// width of the vector accesses the kernel performs on this tensor; the strides of the two leading
+// dimensions must keep every row aligned to it. Strides of size-1 dimensions are never multiplied
+// by a non-zero index, so they are not constrained.
+inline void check_tensor_concat_mla_absorb_q(const at::Tensor& t, int64_t shape2, int64_t row_align_bytes = 16) {
+  TORCH_CHECK_EQ(t.dim(), 3);
+  TORCH_CHECK_EQ(t.size(2), shape2);
+  TORCH_CHECK_EQ(t.stride(2), 1);
+  TORCH_CHECK_EQ(t.dtype(), at::kBFloat16);
+  TORCH_CHECK(t.device().is_cuda());
+  TORCH_CHECK_EQ(reinterpret_cast<int64_t>(t.data_ptr()) % 16, 0);  // alignment
+  for (int d = 0; d < 2; ++d) {
+    const int64_t stride_bytes = t.stride(d) * static_cast<int64_t>(sizeof(nv_bfloat16));
+    TORCH_CHECK(
+        t.size(d) <= 1 || stride_bytes % row_align_bytes == 0,
+        "concat_mla_absorb_q: stride of dim ", d, " breaks the required ", row_align_bytes, "-byte row alignment");
+  }
+}
+
+// TODO further optimize it later
+// Writes the concatenation of `a` ([N, H, 512]) and `b` ([N, H, 64]) along the last dimension into
+// the preallocated `out` ([N, H, 576]) on the current CUDA stream. All three tensors must be bf16
+// CUDA tensors with a contiguous last dimension; the leading dimensions may be strided.
+void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out) {
+  check_tensor_concat_mla_absorb_q(a, A_LAST_DIM, sizeof(int4));
+  check_tensor_concat_mla_absorb_q(b, B_LAST_DIM, sizeof(int));
+  check_tensor_concat_mla_absorb_q(out, A_LAST_DIM + B_LAST_DIM, sizeof(int4));
+
+  const auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+  // `b` and `out` must describe the same (N, H) grid of rows as `a`; otherwise the kernel would
+  // read or write out of bounds.
+  TORCH_CHECK_EQ(a.size(0), b.size(0));
+  TORCH_CHECK_EQ(a.size(1), b.size(1));
+  TORCH_CHECK_EQ(a.size(0), out.size(0));
+  TORCH_CHECK_EQ(a.size(1), out.size(1));
+  const int num_items = static_cast<int>(a.size(0) * a.size(1));
+  if (num_items == 0) {
+    return;  // nothing to copy; a zero-sized grid is not a valid launch configuration
+  }
+
+  constexpr int num_warps_per_block = 32;
+  const int grid_size = ceil_div(num_items, num_warps_per_block);
+  const int block_size = num_warps_per_block * 32;
+
+  concat_mla_absorb_q_kernel<<<grid_size, block_size, 0, stream>>>(
+      reinterpret_cast<const nv_bfloat16*>(a.data_ptr()),
+      reinterpret_cast<const nv_bfloat16*>(b.data_ptr()),
+      reinterpret_cast<nv_bfloat16*>(out.data_ptr()),
+      num_items,
+      static_cast<int>(a.size(1)),
+      a.stride(0),
+      a.stride(1),
+      b.stride(0),
+      b.stride(1),
+      out.stride(0),
+      out.stride(1));
+  cudaError_t err = cudaGetLastError();
+  TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: ", cudaGetErrorString(err));
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index e1ac17de7..c166319c0 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -172,6 +172,7 @@ void downcast_fp8(
 
 void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
 void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
+void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out);
 
 #ifdef USE_ROCM
 void gelu_quick(at::Tensor& out, const at::Tensor& input);
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 9c676ea8b..01d0034d8 100755
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -23,6 +23,7 @@ from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_
 from sgl_kernel.elementwise import (
     FusedSetKVBufferArg,
     apply_rope_with_cos_sin_cache_inplace,
+    concat_mla_absorb_q,
     concat_mla_k,
     copy_to_gpu_no_ce,
     downcast_fp8,
diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py
index af3adfd4a..13bb11be3 100644
--- a/sgl-kernel/python/sgl_kernel/elementwise.py
+++ b/sgl-kernel/python/sgl_kernel/elementwise.py
@@ -379,3 +379,23 @@ def concat_mla_k(
     k_rope: torch.Tensor,
 ):
     torch.ops.sgl_kernel.concat_mla_k(k, k_nope, k_rope)
+
+
+def concat_mla_absorb_q(
+    a: torch.Tensor,
+    b: torch.Tensor,
+) -> torch.Tensor:
+    """Concatenate ``a`` and ``b`` along the last dimension with the fused CUDA kernel.
+
+    Allocates the result on ``a``'s device with ``a``'s dtype and lets
+    ``torch.ops.sgl_kernel.concat_mla_absorb_q`` fill it, so that
+    ``out[..., : a.shape[-1]]`` holds ``a`` and the remainder holds ``b``.
+    The kernel only accepts 3-D bfloat16 CUDA tensors whose last dimensions are
+    512 (``a``) and 64 (``b``) and whose leading dimensions match.
+    """
+    *batch_dims, _ = a.shape
+    out = torch.empty(
+        (*batch_dims, a.shape[-1] + b.shape[-1]), device=a.device, dtype=a.dtype
+    )
+    torch.ops.sgl_kernel.concat_mla_absorb_q(a, b, out)
+    return out
diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py
index 039acc18d..ef930f88e 100644
--- a/test/srt/models/test_generation_models.py
+++ b/test/srt/models/test_generation_models.py
@@ -67,11 +67,8 @@ ALL_MODELS = [
     ModelCase("openai-community/gpt2"),
     ModelCase("microsoft/phi-1_5", trust_remote_code=True),
     ModelCase("adept/persimmon-8b-chat"),
-    ModelCase("upstage/SOLAR-10.7B-Instruct-v1.0"),
-    ModelCase("inclusionAI/Ling-lite", trust_remote_code=True),
-    ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True),
     ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
     ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),