[sgl-kernel] Optimize concat_mla_k kernel (#10543)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: PGFLMG <1106310035@qq.com>
This commit is contained in:
Yuan Luo
2025-09-28 23:04:22 +08:00
committed by GitHub
parent 2a9d995c09
commit 42245551ef
3 changed files with 310 additions and 42 deletions

View File

@@ -3,6 +3,7 @@
#include <cuda_runtime.h>
#include "pytorch_extension_utils.h"
#include "utils.cuh"
constexpr int NUM_LOCAL_HEADS = 128;
constexpr int QK_NOPE_HEAD_DIM = 128;
@@ -12,20 +13,10 @@ constexpr int K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM;
constexpr int HEAD_CHUNK_SIZE = 16;
constexpr int NUM_HEAD_CHUNKS = NUM_LOCAL_HEADS / HEAD_CHUNK_SIZE;
// Returns the calling thread's lane index (0-31) within its warp by reading
// the PTX special register %laneid.
// Fix: inside GCC-style inline asm, '%' introduces an operand substitution,
// so the special register must be escaped as "%%laneid"; a bare "%laneid"
// is rejected by the assembler.
__forceinline__ __device__ int get_lane_id() {
  int lane_id;
  asm("mov.s32 %0, %%laneid;" : "=r"(lane_id));
  return lane_id;
}
// Ceiling division: smallest q such that q * b >= a, for a >= 0 and b > 0.
// Used to size the launch grid so the last partial warp-group is covered.
// constexpr (which implies inline) makes the definition ODR-safe if this
// helper is ever shared through a header, and usable in constant expressions.
constexpr int ceil_div(int a, int b) {
  return (a + b - 1) / b;
}
// NOTE(review): this span is rendered diff residue — it interleaves the
// pre-change ("removed") and post-change ("added") lines of
// concat_mla_k_kernel without +/- markers (note the raw "@@" hunk header
// below), so it is NOT compilable as shown. Comments mark which lines
// appear to belong to which version; confirm against the actual commit.
//
// Purpose (from the visible logic): one warp handles one chunk of
// HEAD_CHUNK_SIZE heads of one token, copying the "nope" part of K
// (QK_NOPE_HEAD_DIM bf16 per head) and broadcasting the shared "rope"
// part (QK_ROPE_HEAD_DIM bf16 per token) into the concatenated K layout.
__global__ void concat_mla_k_kernel(
// Old signature (pre-change): unqualified pointers.
nv_bfloat16* k,
nv_bfloat16* k_nope,
nv_bfloat16* k_rope,
// New signature (post-change): __restrict__-qualified, const sources.
nv_bfloat16* __restrict__ k,
const nv_bfloat16* __restrict__ k_nope,
const nv_bfloat16* __restrict__ k_rope,
const int num_tokens,
const int k_stride_0,
const int k_stride_1,
// Diff hunk marker from the commit page — the lines deriving flat_warp_id
// (used below) fall in the elided region and are not visible here.
@@ -36,43 +27,50 @@ __global__ void concat_mla_k_kernel(
const int token_id = flat_warp_id / NUM_HEAD_CHUNKS;
const int head_chunk_id = flat_warp_id % NUM_HEAD_CHUNKS;
const int lane_id = get_lane_id();
// Old guard (single-line) vs. new braced guard — same tail-warp bounds check.
if (token_id >= num_tokens) return;
if (token_id >= num_tokens) {
return;
}
// New version: per-lane vector widths chosen so one warp covers one head row.
using NopeVec = int2; // 8B/thread x 32 threads = 256B/row
using RopeVec = int; // 4B/thread x 32 threads = 128B/row
static_assert(sizeof(NopeVec) * 32 == QK_NOPE_HEAD_DIM * sizeof(nv_bfloat16), "nope vec mismatch");
static_assert(sizeof(RopeVec) * 32 == QK_ROPE_HEAD_DIM * sizeof(nv_bfloat16), "rope vec mismatch");
// Old version: buffer all HEAD_CHUNK_SIZE rows in registers, then store.
using KNopeBufType = int2;
static_assert(sizeof(KNopeBufType) == QK_NOPE_HEAD_DIM * sizeof(k[0]) / 32);
KNopeBufType k_nope_buf[HEAD_CHUNK_SIZE];
const int head_row0 = head_chunk_id * HEAD_CHUNK_SIZE;
using KRopeBufType = int;
static_assert(sizeof(KRopeBufType) == QK_ROPE_HEAD_DIM * sizeof(k[0]) / 32);
KRopeBufType k_rope_buf;
// New version: precomputed per-lane base pointers for the first head row.
const int2* __restrict__ nope_src =
reinterpret_cast<const int2*>(k_nope + token_id * k_nope_stride_0 + head_row0 * k_nope_stride_1) + lane_id;
// Old version: plain load of the token's rope fragment.
{
const int* base_addr = reinterpret_cast<int*>(k_rope + token_id * k_rope_stride_0);
k_rope_buf = *(base_addr + lane_id);
}
int2* __restrict__ nope_dst = reinterpret_cast<int2*>(k + token_id * k_stride_0 + head_row0 * k_stride_1) + lane_id;
int* __restrict__ rope_dst =
reinterpret_cast<int*>(k + token_id * k_stride_0 + head_row0 * k_stride_1 + QK_NOPE_HEAD_DIM) + lane_id;
// Strides converted from bf16 elements to vector units (>>2: int2 covers
// 4 bf16; >>1: int covers 2 bf16) — assumes strides are suitably divisible.
const int nope_src_stride_v = (k_nope_stride_1 >> 2); // int2 covers 4 bf16
const int nope_dst_stride_v = (k_stride_1 >> 2);
const int rope_dst_stride_v = (k_stride_1 >> 1); // int covers 2 bf16
// New version: rope fragment loaded once via the no-allocate read path and
// re-stored for every head row (it is shared across heads of a token).
const int* rope_base = reinterpret_cast<const int*>(k_rope + token_id * k_rope_stride_0);
const RopeVec rope_val = ld_na_global_v1(rope_base + lane_id);
// Software pipeline: prefetch + load row 0 before entering the loop.
prefetch_L2(nope_src);
NopeVec cur = ld_na_global_v2(nope_src);
#pragma unroll
for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
// Old version's load loop body (gather all nope rows into registers).
const int head_id = head_chunk_id * HEAD_CHUNK_SIZE + i;
const int2* base_addr = reinterpret_cast<int2*>(k_nope + token_id * k_nope_stride_0 + head_id * k_nope_stride_1);
k_nope_buf[i] = *(base_addr + lane_id);
}
#pragma unroll
for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
// Old version's store loop (write buffered nope row, then rope fragment).
const int head_id = head_chunk_id * HEAD_CHUNK_SIZE + i;
{
int2* base_addr = reinterpret_cast<int2*>(k + token_id * k_stride_0 + head_id * k_stride_1);
*(base_addr + lane_id) = k_nope_buf[i];
}
{
int* base_addr = reinterpret_cast<int*>(k + token_id * k_stride_0 + head_id * k_stride_1 + QK_NOPE_HEAD_DIM);
*(base_addr + lane_id) = k_rope_buf;
// New version's pipelined body: issue next row's prefetch+load before
// storing the current row, overlapping load latency with stores.
NopeVec next;
if (i + 1 < HEAD_CHUNK_SIZE) {
const int2* next_src = nope_src + nope_src_stride_v;
prefetch_L2(next_src);
next = ld_na_global_v2(next_src);
}
st_na_global_v2(nope_dst, cur);
st_na_global_v1(rope_dst, rope_val);
// Advance per-lane pointers to the next head row.
nope_src += nope_src_stride_v;
nope_dst += nope_dst_stride_v;
rope_dst += rope_dst_stride_v;
cur = next;
}
}

View File

@@ -0,0 +1,72 @@
// Adapted from https://github.com/deepseek-ai/DeepEP/blob/main/csrc/kernels/utils.cuh
#pragma once
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdint>
// Returns the calling thread's lane index (0-31) within its warp by reading
// the PTX special register %laneid.
// Fix: inside GCC-style inline asm, '%' introduces an operand substitution,
// so the special register must be escaped as "%%laneid"; a bare "%laneid"
// is rejected by the assembler.
__forceinline__ __device__ int get_lane_id() {
  int lane_id;
  asm("mov.s32 %0, %%laneid;" : "=r"(lane_id));
  return lane_id;
}
// Ceiling division: smallest q such that q * b >= a, for a >= 0 and b > 0.
// Fix: this lives in a header (#pragma once above) — a non-inline function
// definition here violates the ODR as soon as two translation units include
// it. constexpr implies inline, keeping the definition ODR-safe and making
// the helper usable in constant expressions.
constexpr int ceil_div(int a, int b) {
  return (a + b - 1) / b;
}
// 32-bit global store with the L1::no_allocate cache hint: the written line
// is not kept resident in L1 — intended for streaming stores whose data will
// not be re-read by this SM. NOTE(review): the parameter is const-qualified
// even though the PTX writes through it (the asm bypasses the C++ type
// system); consider dropping the const.
__device__ __forceinline__ void st_na_global_v1(const int* ptr, int v) {
asm volatile("st.global.L1::no_allocate.s32 [%0], %1;" ::"l"(ptr), "r"(v) : "memory");
}
// 64-bit (vector int2) global store with the L1::no_allocate cache hint —
// same streaming-store semantics as st_na_global_v1, 8 bytes per call.
// NOTE(review): const-qualified destination pointer; the asm writes through
// it regardless.
__device__ __forceinline__ void st_na_global_v2(const int2* ptr, const int2& v) {
asm volatile("st.global.L1::no_allocate.v2.s32 [%0], {%1, %2};" ::"l"(ptr), "r"(v.x), "r"(v.y) : "memory");
}
// 128-bit (vector int4) global store with the L1::no_allocate cache hint —
// widest streaming store of the family, 16 bytes per call.
// NOTE(review): const-qualified destination pointer; the asm writes through
// it regardless.
__device__ __forceinline__ void st_na_global_v4(const int4* ptr, const int4& v) {
asm volatile(
"st.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(v.x), "r"(v.y), "r"(v.z), "r"(v.w)
: "memory");
}
// 32-bit global load via the non-coherent (read-only, ld.global.nc) path
// with the L1::no_allocate hint, so the line does not displace L1 residents.
// When USE_L2_HINT is defined, adds the L2::128B prefetch hint to pull the
// surrounding 128B into L2. Caller must guarantee the pointed-to data is not
// written concurrently (nc path requirement).
__device__ __forceinline__ int ld_na_global_v1(const int* ptr) {
int r;
#ifdef USE_L2_HINT
asm volatile("ld.global.nc.L1::no_allocate.L2::128B.s32 %0, [%1];" : "=r"(r) : "l"(ptr));
#else
asm volatile("ld.global.nc.L1::no_allocate.s32 %0, [%1];" : "=r"(r) : "l"(ptr));
#endif
return r;
}
// 64-bit (vector int2) non-coherent global load with the L1::no_allocate
// hint; optional L2::128B prefetch hint under USE_L2_HINT. Same read-only
// requirement as ld_na_global_v1.
__device__ __forceinline__ int2 ld_na_global_v2(const int2* ptr) {
int2 r;
#ifdef USE_L2_HINT
asm volatile("ld.global.nc.L1::no_allocate.L2::128B.v2.s32 {%0, %1}, [%2];" : "=r"(r.x), "=r"(r.y) : "l"(ptr));
#else
asm volatile("ld.global.nc.L1::no_allocate.v2.s32 {%0, %1}, [%2];" : "=r"(r.x), "=r"(r.y) : "l"(ptr));
#endif
return r;
}
// 128-bit (vector int4) non-coherent global load with the L1::no_allocate
// hint; optional L2::128B prefetch hint under USE_L2_HINT. Same read-only
// requirement as ld_na_global_v1.
__device__ __forceinline__ int4 ld_na_global_v4(const int4* ptr) {
int4 r;
#ifdef USE_L2_HINT
asm volatile("ld.global.nc.L1::no_allocate.L2::128B.v4.s32 {%0, %1, %2, %3}, [%4];"
: "=r"(r.x), "=r"(r.y), "=r"(r.z), "=r"(r.w)
: "l"(ptr));
#else
asm volatile("ld.global.nc.L1::no_allocate.v4.s32 {%0, %1, %2, %3}, [%4];"
: "=r"(r.x), "=r"(r.y), "=r"(r.z), "=r"(r.w)
: "l"(ptr));
#endif
return r;
}
// Issues a prefetch of the global line containing p into L2. Compiled to a
// no-op unless ENABLE_L2_PREFETCH is defined, so call sites can prefetch
// unconditionally and let the build flag decide.
__device__ __forceinline__ void prefetch_L2(const void* p) {
#if defined(ENABLE_L2_PREFETCH)
asm volatile("prefetch.global.L2 [%0];" ::"l"(p));
#endif
}