diff --git a/benchmark/kernels/elementwise/benchmark_concat_mla.py b/benchmark/kernels/elementwise/benchmark_concat_mla.py
new file mode 100644
index 000000000..c4d7bb1c8
--- /dev/null
+++ b/benchmark/kernels/elementwise/benchmark_concat_mla.py
@@ -0,0 +1,198 @@
+import torch
+import triton
+import triton.language as tl
+from sgl_kernel import concat_mla_k as concat_mla_k_cuda
+
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
+num_local_heads = 128
+qk_nope_head_dim = 128
+qk_rope_head_dim = 64
+
+
+def create_data(num_tokens):
+    k_nope_container = torch.randn(
+        (num_tokens, num_local_heads, qk_nope_head_dim + 128),
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    k_nope = k_nope_container[:, :, :qk_nope_head_dim]
+
+    k_rope_container = torch.randn(
+        (num_tokens, 1, 128 + qk_rope_head_dim), dtype=torch.bfloat16, device="cuda"
+    )
+    k_rope = k_rope_container[:, :, -qk_rope_head_dim:]
+
+    k = torch.empty(
+        (num_tokens, num_local_heads, qk_nope_head_dim + qk_rope_head_dim),
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    return dict(k=k, k_nope=k_nope, k_rope=k_rope)
+
+
+def fn_torch(k, k_nope, k_rope):
+    k[..., :qk_nope_head_dim] = k_nope
+    k[..., qk_nope_head_dim:] = k_rope
+
+
+def fn_hack_non_strided(k, k_nope, k_rope):
+    k_flatten_view = k.flatten()
+    k_flatten_view[: k_nope.numel()] = k_nope.flatten()
+
+    k2 = k_flatten_view[k_nope.numel() :].view(k_rope.numel(), -1)
+    k2[:] = k_rope.flatten()[:, None]
+
+
+@torch.compile(dynamic=True)
+def fn_torch_compiled(k, k_nope, k_rope):
+    return fn_torch(k, k_nope, k_rope)
+
+
+def fn_cuda(k, k_nope, k_rope):
+    concat_mla_k_cuda(k, k_nope, k_rope)
+
+
+@triton.jit
+def fn_triton_kernel(
+    k_ptr,
+    k_nope_ptr,
+    k_rope_ptr,
+    num_tokens,
+    QK_NOPE_HEAD_DIM: tl.constexpr,
+    QK_ROPE_HEAD_DIM: tl.constexpr,
+    NUM_LOCAL_HEADS: tl.constexpr,
+    K_NOPE_STRIDE_0: tl.constexpr,
+    K_NOPE_STRIDE_1: tl.constexpr,
+    K_STRIDE_0: tl.constexpr,
+    K_STRIDE_1: tl.constexpr,
+    K_ROPE_STRIDE_0: tl.constexpr,
+    BLOCK_ROWS: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    token_id = pid * BLOCK_ROWS + tl.arange(0, BLOCK_ROWS)
+    token_mask = token_id < num_tokens
+
+    head_id = tl.arange(0, NUM_LOCAL_HEADS)
+
+    # nope
+    nope_sub_id = tl.arange(0, QK_NOPE_HEAD_DIM)
+    offs_nope = (
+        token_id[:, None, None] * K_NOPE_STRIDE_0
+        + head_id[None, :, None] * K_NOPE_STRIDE_1
+        + nope_sub_id[None, None, :]
+    )
+    offs_k = (
+        token_id[:, None, None] * K_STRIDE_0
+        + head_id[None, :, None] * K_STRIDE_1
+        + nope_sub_id[None, None, :]
+    )
+    vals_nope = tl.load(k_nope_ptr + offs_nope, mask=token_mask[:, None, None])
+    tl.store(k_ptr + offs_k, vals_nope, mask=token_mask[:, None, None])
+
+    # rope
+    rope_sub_id = tl.arange(0, QK_ROPE_HEAD_DIM)
+    offs_rope = token_id[:, None, None] * K_ROPE_STRIDE_0 + rope_sub_id[None, None, :]
+    offs_k = (
+        token_id[:, None, None] * K_STRIDE_0
+        + head_id[None, :, None] * K_STRIDE_1
+        + rope_sub_id[None, None, :]
+        + QK_NOPE_HEAD_DIM
+    )
+    vals_rope = tl.load(k_rope_ptr + offs_rope, mask=token_mask[:, None, None])
+    tl.store(k_ptr + offs_k, vals_rope, mask=token_mask[:, None, None])
+
+
+def fn_triton(k, k_nope, k_rope):
+    assert k.device == DEVICE and k_nope.device == DEVICE and k_rope.device == DEVICE
+    num_tokens, _, _ = k.shape
+    grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_ROWS"]),)
+    fn_triton_kernel[grid](
+        k,
+        k_nope,
+        k_rope,
+        num_tokens,
+        QK_NOPE_HEAD_DIM=qk_nope_head_dim,
+        QK_ROPE_HEAD_DIM=qk_rope_head_dim,
+        NUM_LOCAL_HEADS=num_local_heads,
+        K_NOPE_STRIDE_0=k_nope.stride(0),
+        K_NOPE_STRIDE_1=k_nope.stride(1),
+        K_STRIDE_0=k.stride(0),
+        K_STRIDE_1=k.stride(1),
+        K_ROPE_STRIDE_0=k_rope.stride(0),
+        BLOCK_ROWS=16,
+    )
+
+
+def execute_and_get_output(f, data):
+    data["k"].zero_()
+    f(**data)
+    assert data["k"].sum().item() != 0
+    return data["k"].clone()
+
+
+torch.manual_seed(0)
+data = create_data(num_tokens=32768)
+output_ref = execute_and_get_output(fn_torch, data)
+output_exp = execute_and_get_output(fn_cuda, data)
+# print(output_ref)
+# print(output_exp)
+if not torch.all(output_ref == output_exp):
+    abs_delta = torch.abs(output_ref - output_exp)
+    raise AssertionError(
+        f"{output_ref=} {output_exp=} "
+        f"{abs_delta=} "
+        f"{torch.argwhere(abs_delta != 0.0)=} "
+    )
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["num_tokens"],  # Argument names to use as an x-axis for the plot.
+        x_vals=[
+            2048,
+            4096,
+            8192,
+            16384,
+            32768,
+        ],  # Different possible values for `x_name`.
+        x_log=False,  # x axis is logarithmic.
+        line_arg="provider",  # Argument name whose value corresponds to a different line in the plot.
+        line_vals=[
+            "torch",
+            "torch_compiled",
+            "triton",
+            "hack_non_strided",
+            "cuda",
+        ],  # Possible values for `line_arg`.
+        line_names=[
+            "torch",
+            "torch_compiled",
+            "triton",
+            "hack_non_strided",
+            "cuda",
+        ],  # Label name for the lines.
+        plot_name="vector-add-performance",  # Name for the plot. Used also as a file name for saving the plot.
+        args={},  # Values for function arguments not in `x_names` and `y_name`.
+    )
+)
+def benchmark(num_tokens, provider):
+    data = create_data(num_tokens=num_tokens)
+    quantiles = [0.5, 0.2, 0.8]
+    fn = {
+        "torch": fn_torch,
+        "torch_compiled": fn_torch_compiled,
+        "triton": fn_triton,
+        "hack_non_strided": fn_hack_non_strided,
+        "cuda": fn_cuda,
+    }[provider]
+    ms, min_ms, max_ms = triton.testing.do_bench(
+        lambda: fn(**data), quantiles=quantiles
+    )
+    return ms, min_ms, max_ms
+
+
+torch.cuda.cudart().cudaProfilerStart()
+benchmark.run(print_data=True, show_plots=True)
+torch.cuda.cudart().cudaProfilerStop()
diff --git a/sgl-kernel/csrc/elementwise/concat_mla.cu b/sgl-kernel/csrc/elementwise/concat_mla.cu
index 13ff16e22..0335dc724 100644
--- a/sgl-kernel/csrc/elementwise/concat_mla.cu
+++ b/sgl-kernel/csrc/elementwise/concat_mla.cu
@@ -3,6 +3,7 @@
 #include <cuda_bf16.h>
 
 #include "pytorch_extension_utils.h"
+#include "utils.cuh"
 
 constexpr int NUM_LOCAL_HEADS = 128;
 constexpr int QK_NOPE_HEAD_DIM = 128;
@@ -12,20 +13,10 @@
 constexpr int K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM;
 constexpr int HEAD_CHUNK_SIZE = 16;
 constexpr int NUM_HEAD_CHUNKS = NUM_LOCAL_HEADS / HEAD_CHUNK_SIZE;
-__forceinline__ __device__ int get_lane_id() {
-  int lane_id;
-  asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
-  return lane_id;
-}
-
-int ceil_div(int a, int b) {
-  return (a + b - 1) / b;
-}
-
 __global__ void concat_mla_k_kernel(
-    nv_bfloat16* k,
-    nv_bfloat16* k_nope,
-    nv_bfloat16* k_rope,
+    nv_bfloat16* __restrict__ k,
+    const nv_bfloat16* __restrict__ k_nope,
+    const nv_bfloat16* __restrict__ k_rope,
     const int num_tokens,
     const int k_stride_0,
     const int k_stride_1,
@@ -36,43 +27,50 @@
   const int token_id = flat_warp_id / NUM_HEAD_CHUNKS;
   const int head_chunk_id = flat_warp_id % NUM_HEAD_CHUNKS;
   const int lane_id = get_lane_id();
+  if (token_id >= num_tokens) return;
 
-  if (token_id >= num_tokens) {
-    return;
-  }
+  using NopeVec = int2;  // 8B/thread,32 thread = 256B/row
+  using RopeVec = int;   // 4B/thread,32 thread = 128B/row
+  static_assert(sizeof(NopeVec) * 32 == QK_NOPE_HEAD_DIM * sizeof(nv_bfloat16), "nope vec mismatch");
+  static_assert(sizeof(RopeVec) * 32 == QK_ROPE_HEAD_DIM * sizeof(nv_bfloat16), "rope vec mismatch");
 
-  using KNopeBufType = int2;
-  static_assert(sizeof(KNopeBufType) == QK_NOPE_HEAD_DIM * sizeof(k[0]) / 32);
-  KNopeBufType k_nope_buf[HEAD_CHUNK_SIZE];
+  const int head_row0 = head_chunk_id * HEAD_CHUNK_SIZE;
 
-  using KRopeBufType = int;
-  static_assert(sizeof(KRopeBufType) == QK_ROPE_HEAD_DIM * sizeof(k[0]) / 32);
-  KRopeBufType k_rope_buf;
+  const int2* __restrict__ nope_src =
+      reinterpret_cast<const int2*>(k_nope + token_id * k_nope_stride_0 + head_row0 * k_nope_stride_1) + lane_id;
 
-  {
-    const int* base_addr = reinterpret_cast<const int*>(k_rope + token_id * k_rope_stride_0);
-    k_rope_buf = *(base_addr + lane_id);
-  }
+  int2* __restrict__ nope_dst = reinterpret_cast<int2*>(k + token_id * k_stride_0 + head_row0 * k_stride_1) + lane_id;
+
+  int* __restrict__ rope_dst =
+      reinterpret_cast<int*>(k + token_id * k_stride_0 + head_row0 * k_stride_1 + QK_NOPE_HEAD_DIM) + lane_id;
+
+  const int nope_src_stride_v = (k_nope_stride_1 >> 2);  // int2 covers 4 bf16
+  const int nope_dst_stride_v = (k_stride_1 >> 2);
+  const int rope_dst_stride_v = (k_stride_1 >> 1);  // int covers 2 bf16
+
+  const int* rope_base = reinterpret_cast<const int*>(k_rope + token_id * k_rope_stride_0);
+  const RopeVec rope_val = ld_na_global_v1(rope_base + lane_id);
+
+  prefetch_L2(nope_src);
+  NopeVec cur = ld_na_global_v2(nope_src);
 
 #pragma unroll
   for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
-    const int head_id = head_chunk_id * HEAD_CHUNK_SIZE + i;
-    const int2* base_addr = reinterpret_cast<const int2*>(k_nope + token_id * k_nope_stride_0 + head_id * k_nope_stride_1);
-    k_nope_buf[i] = *(base_addr + lane_id);
-  }
-
-#pragma unroll
-  for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
-    const int head_id = head_chunk_id * HEAD_CHUNK_SIZE + i;
-
-    {
-      int2* base_addr = reinterpret_cast<int2*>(k + token_id * k_stride_0 + head_id * k_stride_1);
-      *(base_addr + lane_id) = k_nope_buf[i];
-    }
-    {
-      int* base_addr = reinterpret_cast<int*>(k + token_id * k_stride_0 + head_id * k_stride_1 + QK_NOPE_HEAD_DIM);
-      *(base_addr + lane_id) = k_rope_buf;
+    NopeVec next;
+    if (i + 1 < HEAD_CHUNK_SIZE) {
+      const int2* next_src = nope_src + nope_src_stride_v;
+      prefetch_L2(next_src);
+      next = ld_na_global_v2(next_src);
     }
+
+    st_na_global_v2(nope_dst, cur);
+    st_na_global_v1(rope_dst, rope_val);
+
+    nope_src += nope_src_stride_v;
+    nope_dst += nope_dst_stride_v;
+    rope_dst += rope_dst_stride_v;
+
+    cur = next;
   }
 }
 
diff --git a/sgl-kernel/csrc/elementwise/utils.cuh b/sgl-kernel/csrc/elementwise/utils.cuh
new file mode 100644
index 000000000..3010a5452
--- /dev/null
+++ b/sgl-kernel/csrc/elementwise/utils.cuh
@@ -0,0 +1,72 @@
+// Adapted from https://github.com/deepseek-ai/DeepEP/blob/main/csrc/kernels/utils.cuh
+
+#pragma once
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <cstdint>
+
+__forceinline__ __device__ int get_lane_id() {
+  int lane_id;
+  asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
+  return lane_id;
+}
+
+inline int ceil_div(int a, int b) {
+  return (a + b - 1) / b;
+}
+
+__device__ __forceinline__ void st_na_global_v1(const int* ptr, int v) {
+  asm volatile("st.global.L1::no_allocate.s32 [%0], %1;" ::"l"(ptr), "r"(v) : "memory");
+}
+
+__device__ __forceinline__ void st_na_global_v2(const int2* ptr, const int2& v) {
+  asm volatile("st.global.L1::no_allocate.v2.s32 [%0], {%1, %2};" ::"l"(ptr), "r"(v.x), "r"(v.y) : "memory");
+}
+
+__device__ __forceinline__ void st_na_global_v4(const int4* ptr, const int4& v) {
+  asm volatile(
+      "st.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(v.x), "r"(v.y), "r"(v.z), "r"(v.w)
+      : "memory");
+}
+
+__device__ __forceinline__ int ld_na_global_v1(const int* ptr) {
+  int r;
+#ifdef USE_L2_HINT
+  asm volatile("ld.global.nc.L1::no_allocate.L2::128B.s32 %0, [%1];" : "=r"(r) : "l"(ptr));
+#else
+  asm volatile("ld.global.nc.L1::no_allocate.s32 %0, [%1];" : "=r"(r) : "l"(ptr));
+#endif
+  return r;
+}
+
+__device__ __forceinline__ int2 ld_na_global_v2(const int2* ptr) {
+  int2 r;
+#ifdef USE_L2_HINT
+  asm volatile("ld.global.nc.L1::no_allocate.L2::128B.v2.s32 {%0, %1}, [%2];" : "=r"(r.x), "=r"(r.y) : "l"(ptr));
+#else
+  asm volatile("ld.global.nc.L1::no_allocate.v2.s32 {%0, %1}, [%2];" : "=r"(r.x), "=r"(r.y) : "l"(ptr));
+#endif
+  return r;
+}
+
+__device__ __forceinline__ int4 ld_na_global_v4(const int4* ptr) {
+  int4 r;
+#ifdef USE_L2_HINT
+  asm volatile("ld.global.nc.L1::no_allocate.L2::128B.v4.s32 {%0, %1, %2, %3}, [%4];"
+               : "=r"(r.x), "=r"(r.y), "=r"(r.z), "=r"(r.w)
+               : "l"(ptr));
+#else
+  asm volatile("ld.global.nc.L1::no_allocate.v4.s32 {%0, %1, %2, %3}, [%4];"
+               : "=r"(r.x), "=r"(r.y), "=r"(r.z), "=r"(r.w)
+               : "l"(ptr));
+#endif
+  return r;
+}
+
+__device__ __forceinline__ void prefetch_L2(const void* p) {
+#if defined(ENABLE_L2_PREFETCH)
+  asm volatile("prefetch.global.L2 [%0];" ::"l"(p));
+#endif
+}