From 0096798ed60b9eadce468c2d206cd2982e97b978 Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Tue, 9 Sep 2025 00:00:33 +0800
Subject: [PATCH] [1/2] Speed up prefill mla attention (#10156)

---
 sgl-kernel/CMakeLists.txt                   |   1 +
 sgl-kernel/csrc/common_extension.cc         |   2 +
 sgl-kernel/csrc/elementwise/concat_mla.cu   | 117 ++++++++++++++++++++
 sgl-kernel/include/sgl_kernel_ops.h         |   1 +
 sgl-kernel/python/sgl_kernel/__init__.py    |   1 +
 sgl-kernel/python/sgl_kernel/elementwise.py |   8 ++
 6 files changed, 130 insertions(+)
 create mode 100644 sgl-kernel/csrc/elementwise/concat_mla.cu

diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 58ac06c08..3ae1b00d5 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -259,6 +259,7 @@ set(SOURCES
     "csrc/elementwise/activation.cu"
     "csrc/elementwise/cast.cu"
     "csrc/elementwise/copy.cu"
+    "csrc/elementwise/concat_mla.cu"
     "csrc/elementwise/fused_add_rms_norm_kernel.cu"
     "csrc/elementwise/rope.cu"
     "csrc/common_extension.cc"
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index 5a87dd483..c603e4bb6 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -436,6 +436,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
   m.impl("copy_to_gpu_no_ce", torch::kCUDA, &copy_to_gpu_no_ce);
 
+  m.def("concat_mla_k(Tensor! 
k, Tensor k_nope, Tensor k_rope) -> ()");
+  m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);
 }
 
 REGISTER_EXTENSION(common_ops)

diff --git a/sgl-kernel/csrc/elementwise/concat_mla.cu b/sgl-kernel/csrc/elementwise/concat_mla.cu
new file mode 100644
index 000000000..b6c236333
--- /dev/null
+++ b/sgl-kernel/csrc/elementwise/concat_mla.cu
@@ -0,0 +1,117 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_bf16.h>
+#include <cuda_runtime.h>
+
+#include "pytorch_extension_utils.h"
+
+constexpr int NUM_LOCAL_HEADS = 128;
+constexpr int QK_NOPE_HEAD_DIM = 128;
+constexpr int QK_ROPE_HEAD_DIM = 64;
+constexpr int K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM;
+
+constexpr int HEAD_CHUNK_SIZE = 16;
+constexpr int NUM_HEAD_CHUNKS = NUM_LOCAL_HEADS / HEAD_CHUNK_SIZE;
+
+__forceinline__ __device__ int get_lane_id() {
+  int lane_id;
+  asm("mov.s32 %0, %%laneid;" : "=r"(lane_id));
+  return lane_id;
+}
+
+int ceil_div(int a, int b) {
+  return (a + b - 1) / b;
+}
+
+__global__ void concat_mla_k_kernel(
+    nv_bfloat16* k,
+    nv_bfloat16* k_nope,
+    nv_bfloat16* k_rope,
+    const int num_tokens,
+    const int k_stride_0,
+    const int k_stride_1,
+    const int k_nope_stride_0,
+    const int k_nope_stride_1,
+    const int k_rope_stride_0) {
+  const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
+  const int token_id = flat_warp_id / NUM_HEAD_CHUNKS;
+  const int head_chunk_id = flat_warp_id % NUM_HEAD_CHUNKS;
+  const int lane_id = get_lane_id();
+
+  if (token_id >= num_tokens) {
+    return;
+  }
+
+  using KNopeBufType = int2;
+  static_assert(sizeof(KNopeBufType) == QK_NOPE_HEAD_DIM * sizeof(k[0]) / 32);
+  KNopeBufType k_nope_buf[HEAD_CHUNK_SIZE];
+
+  using KRopeBufType = int;
+  static_assert(sizeof(KRopeBufType) == QK_ROPE_HEAD_DIM * sizeof(k[0]) / 32);
+  KRopeBufType k_rope_buf;
+
+  {
+    const int* base_addr = reinterpret_cast<const int*>(k_rope + token_id * k_rope_stride_0);
+    k_rope_buf = *(base_addr + lane_id);
+  }
+
+#pragma unroll
+  for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
+    const int head_id = head_chunk_id * HEAD_CHUNK_SIZE + i;
+    const int2* base_addr = reinterpret_cast<const int2*>(k_nope + token_id * k_nope_stride_0 + head_id * k_nope_stride_1);
+    k_nope_buf[i] = *(base_addr + lane_id);
+  }
+
+#pragma unroll
+  for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
+    const int head_id = head_chunk_id * HEAD_CHUNK_SIZE + i;
+
+    {
+      int2* base_addr = reinterpret_cast<int2*>(k + token_id * k_stride_0 + head_id * k_stride_1);
+      *(base_addr + lane_id) = k_nope_buf[i];
+    }
+    {
+      int* base_addr = reinterpret_cast<int*>(k + token_id * k_stride_0 + head_id * k_stride_1 + QK_NOPE_HEAD_DIM);
+      *(base_addr + lane_id) = k_rope_buf;
+    }
+  }
+}
+
+inline void check_tensor(const at::Tensor& t, int64_t shape0, int64_t shape1, int64_t shape2, c10::ScalarType dtype) {
+  TORCH_CHECK_EQ(t.dim(), 3);
+  TORCH_CHECK_EQ(t.size(0), shape0);
+  TORCH_CHECK_EQ(t.size(1), shape1);
+  TORCH_CHECK_EQ(t.size(2), shape2);
+  TORCH_CHECK_EQ(t.dtype(), dtype);
+  TORCH_CHECK(t.device().is_cuda());
+  TORCH_CHECK_EQ(((int64_t)t.data_ptr()) % 16, 0);  // alignment
+}
+
+void concat_mla_k(at::Tensor k, at::Tensor k_nope, at::Tensor k_rope) {
+  const int num_tokens = k.size(0);
+
+  check_tensor(k, num_tokens, NUM_LOCAL_HEADS, K_HEAD_DIM, at::kBFloat16);
+  check_tensor(k_nope, num_tokens, NUM_LOCAL_HEADS, QK_NOPE_HEAD_DIM, at::kBFloat16);
+  check_tensor(k_rope, num_tokens, 1, QK_ROPE_HEAD_DIM, at::kBFloat16);
+  TORCH_CHECK_EQ(k.stride(2), 1);
+  TORCH_CHECK_EQ(k_nope.stride(2), 1);
+  TORCH_CHECK_EQ(k_rope.stride(2), 1);
+
+  const auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+  constexpr int num_warps_per_block = 32;
+  const int grid_size = ceil_div(num_tokens * NUM_HEAD_CHUNKS, num_warps_per_block);
+  const int block_size = num_warps_per_block * 32;
+
+  concat_mla_k_kernel<<<grid_size, block_size, 0, stream>>>(
+      reinterpret_cast<nv_bfloat16*>(k.data_ptr()),
+      reinterpret_cast<nv_bfloat16*>(k_nope.data_ptr()),
+      reinterpret_cast<nv_bfloat16*>(k_rope.data_ptr()),
+      num_tokens,
+      k.stride(0),
+      k.stride(1),
+      k_nope.stride(0),
+      k_nope.stride(1),
+      k_rope.stride(0));
+  cudaError_t err = cudaGetLastError();
+  TORCH_CHECK(err == cudaSuccess, 
"CUDA kernel launch failed: ", cudaGetErrorString(err)); +} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 76969a6ee..6315e0418 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -723,3 +723,4 @@ std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, i void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v); void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output); +void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope); diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 25e4eaf3b..8d7053bbd 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -23,6 +23,7 @@ from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_ from sgl_kernel.elementwise import ( FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace, + concat_mla_k, copy_to_gpu_no_ce, downcast_fp8, fused_add_rmsnorm, diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index 863b4d97e..af3adfd4a 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -371,3 +371,11 @@ def downcast_fp8( def copy_to_gpu_no_ce(input: List[int], output: torch.Tensor): torch.ops.sgl_kernel.copy_to_gpu_no_ce(input, output) + + +def concat_mla_k( + k: torch.Tensor, + k_nope: torch.Tensor, + k_rope: torch.Tensor, +): + torch.ops.sgl_kernel.concat_mla_k(k, k_nope, k_rope)