Files
sglang/benchmark/kernels/elementwise/benchmark_concat_mla.py
Yuan Luo 42245551ef [sgl-kernel] Optimize concat_mla_k kernel (#10543)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: PGFLMG <1106310035@qq.com>
2025-09-28 23:04:22 +08:00

199 lines
5.7 KiB
Python

import torch
import triton
import triton.language as tl
from sgl_kernel import concat_mla_k as concat_mla_k_cuda
# Torch device reported by the Triton runtime (the active CUDA device).
DEVICE = triton.runtime.driver.active.get_active_torch_device()
# MLA K-cache geometry: per-head "nope" part plus a single per-token "rope"
# part that is shared (broadcast) across all local heads.
num_local_heads = 128
qk_nope_head_dim = 128
qk_rope_head_dim = 64
def create_data(num_tokens):
    """Build benchmark tensors: destination ``k`` plus strided ``k_nope``/``k_rope``.

    Both inputs are sliced out of larger backing tensors, so they are
    deliberately non-contiguous views (presumably to mirror how the real
    kernel receives them — TODO confirm against the kernel's callers).
    Returns a dict keyed by the benchmark functions' parameter names.
    """
    nope_backing = torch.randn(
        (num_tokens, num_local_heads, qk_nope_head_dim + 128),
        dtype=torch.bfloat16,
        device="cuda",
    )
    rope_backing = torch.randn(
        (num_tokens, 1, 128 + qk_rope_head_dim),
        dtype=torch.bfloat16,
        device="cuda",
    )
    # Output buffer: full concatenated head dim, contiguous.
    dest = torch.empty(
        (num_tokens, num_local_heads, qk_nope_head_dim + qk_rope_head_dim),
        dtype=torch.bfloat16,
        device="cuda",
    )
    return dict(
        k=dest,
        k_nope=nope_backing[:, :, :qk_nope_head_dim],
        k_rope=rope_backing[:, :, -qk_rope_head_dim:],
    )
def fn_torch(k, k_nope, k_rope):
    """Eager reference: copy the nope slice, then broadcast the rope slice.

    ``k_rope`` has a singleton head dimension, so the second assignment
    broadcasts the same per-token rope values to every head.
    """
    nope_view = k[..., :qk_nope_head_dim]
    rope_view = k[..., qk_nope_head_dim:]
    nope_view[...] = k_nope
    rope_view[...] = k_rope
def fn_hack_non_strided(k, k_nope, k_rope):
    """Bandwidth-only 'hack' baseline that ignores k's real interleaved layout.

    It packs every ``k_nope`` element into the front of ``k``'s flat storage
    and broadcasts each ``k_rope`` element across a row of the remainder, so
    the element placement does NOT match ``fn_torch`` — it only measures the
    cost of moving the same number of bytes without strided addressing.
    """
    flat = k.flatten()
    nope_numel = k_nope.numel()
    flat[:nope_numel] = k_nope.flatten()
    # View the tail as one row per rope element, columns spanning the heads.
    tail = flat[nope_numel:].view(k_rope.numel(), -1)
    # Bug fix: the original rebound the local name (`k2 = ...`) instead of
    # writing through the view, so the rope bytes were never copied at all
    # and this variant benchmarked only the nope copy.
    tail[:] = k_rope.flatten()[:, None]
@torch.compile(dynamic=True)
def fn_torch_compiled(k, k_nope, k_rope):
    """``fn_torch`` traced through torch.compile with dynamic shapes enabled."""
    return fn_torch(k, k_nope, k_rope)
def fn_cuda(k, k_nope, k_rope):
    """Dispatch to the fused sgl-kernel CUDA op; fills ``k`` in place."""
    concat_mla_k_cuda(k, k_nope, k_rope)
@triton.jit
def fn_triton_kernel(
    k_ptr,
    k_nope_ptr,
    k_rope_ptr,
    num_tokens,
    QK_NOPE_HEAD_DIM: tl.constexpr,
    QK_ROPE_HEAD_DIM: tl.constexpr,
    NUM_LOCAL_HEADS: tl.constexpr,
    K_NOPE_STRIDE_0: tl.constexpr,
    K_NOPE_STRIDE_1: tl.constexpr,
    K_STRIDE_0: tl.constexpr,
    K_STRIDE_1: tl.constexpr,
    K_ROPE_STRIDE_0: tl.constexpr,
    BLOCK_ROWS: tl.constexpr,
):
    """Copy k_nope and (broadcast) k_rope into k, BLOCK_ROWS tokens per program.

    Strides are passed explicitly because the inputs are non-contiguous views;
    the last dimension of every tensor is addressed with implicit stride 1.
    """
    pid = tl.program_id(axis=0)
    # Tokens owned by this program; the mask handles the ragged final block.
    token_id = pid * BLOCK_ROWS + tl.arange(0, BLOCK_ROWS)
    token_mask = token_id < num_tokens
    head_id = tl.arange(0, NUM_LOCAL_HEADS)
    # nope: per-token, per-head copy into k[..., :QK_NOPE_HEAD_DIM]
    nope_sub_id = tl.arange(0, QK_NOPE_HEAD_DIM)
    offs_nope = (
        token_id[:, None, None] * K_NOPE_STRIDE_0
        + head_id[None, :, None] * K_NOPE_STRIDE_1
        + nope_sub_id[None, None, :]
    )
    offs_k = (
        token_id[:, None, None] * K_STRIDE_0
        + head_id[None, :, None] * K_STRIDE_1
        + nope_sub_id[None, None, :]
    )
    vals_nope = tl.load(k_nope_ptr + offs_nope, mask=token_mask[:, None, None])
    tl.store(k_ptr + offs_k, vals_nope, mask=token_mask[:, None, None])
    # rope: offs_rope has no head term, so the (BLOCK_ROWS, 1, ROPE) values
    # loaded here are broadcast to every head's k[..., QK_NOPE_HEAD_DIM:] slice.
    rope_sub_id = tl.arange(0, QK_ROPE_HEAD_DIM)
    offs_rope = token_id[:, None, None] * K_ROPE_STRIDE_0 + rope_sub_id[None, None, :]
    offs_k = (
        token_id[:, None, None] * K_STRIDE_0
        + head_id[None, :, None] * K_STRIDE_1
        + rope_sub_id[None, None, :]
        + QK_NOPE_HEAD_DIM
    )
    vals_rope = tl.load(k_rope_ptr + offs_rope, mask=token_mask[:, None, None])
    tl.store(k_ptr + offs_k, vals_rope, mask=token_mask[:, None, None])
def fn_triton(k, k_nope, k_rope):
    """Launch fn_triton_kernel, one program per BLOCK_ROWS tokens."""
    assert k.device == DEVICE and k_nope.device == DEVICE and k_rope.device == DEVICE
    num_tokens = k.shape[0]

    def grid(meta):
        return (triton.cdiv(num_tokens, meta["BLOCK_ROWS"]),)

    fn_triton_kernel[grid](
        k,
        k_nope,
        k_rope,
        num_tokens,
        QK_NOPE_HEAD_DIM=qk_nope_head_dim,
        QK_ROPE_HEAD_DIM=qk_rope_head_dim,
        NUM_LOCAL_HEADS=num_local_heads,
        # Strides are constexpr kernel params because the inputs are
        # non-contiguous views.
        K_NOPE_STRIDE_0=k_nope.stride(0),
        K_NOPE_STRIDE_1=k_nope.stride(1),
        K_STRIDE_0=k.stride(0),
        K_STRIDE_1=k.stride(1),
        K_ROPE_STRIDE_0=k_rope.stride(0),
        BLOCK_ROWS=16,
    )
def execute_and_get_output(f, data):
    """Zero the output, run ``f(**data)``, and return a copy of the result.

    The assert guards against an implementation that silently writes nothing.
    """
    out = data["k"]
    out.zero_()
    f(**data)
    assert out.sum().item() != 0
    return out.clone()
# Correctness gate: before timing anything, the CUDA kernel's output must
# match the eager torch reference exactly on a reproducible input.
torch.manual_seed(0)
data = create_data(num_tokens=32768)
output_ref = execute_and_get_output(fn_torch, data)
output_exp = execute_and_get_output(fn_cuda, data)
# print(output_ref)
# print(output_exp)
if not torch.all(output_ref == output_exp):
    abs_delta = torch.abs(output_ref - output_exp)
    raise AssertionError(
        f"{output_ref=} {output_exp=} "
        f"{abs_delta=} "
        f"{torch.argwhere(abs_delta != 0.0)=} "
    )
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],  # x-axis: number of tokens
        x_vals=[2048, 4096, 8192, 16384, 32768],  # token counts to sweep
        x_log=False,
        line_arg="provider",  # one plotted line per implementation
        line_vals=["torch", "torch_compiled", "triton", "hack_non_strided", "cuda"],
        line_names=["torch", "torch_compiled", "triton", "hack_non_strided", "cuda"],
        # NOTE(review): plot name looks left over from the Triton vector-add
        # tutorial; it is also the saved-plot file name, so kept as-is.
        plot_name="vector-add-performance",
        args={},  # no extra fixed arguments
    )
)
def benchmark(num_tokens, provider):
    """Time one provider at the given token count; returns (median, p20, p80) ms."""
    data = create_data(num_tokens=num_tokens)
    impl = {
        "torch": fn_torch,
        "torch_compiled": fn_torch_compiled,
        "triton": fn_triton,
        "hack_non_strided": fn_hack_non_strided,
        "cuda": fn_cuda,
    }[provider]
    med_ms, p20_ms, p80_ms = triton.testing.do_bench(
        lambda: impl(**data), quantiles=[0.5, 0.2, 0.8]
    )
    return med_ms, p20_ms, p80_ms
# Bracket the benchmark sweep with the CUDA profiler API so an external
# profiler (e.g. nsys with --capture-range=cudaProfilerApi) records only
# this region.
torch.cuda.cudart().cudaProfilerStart()
benchmark.run(print_data=True, show_plots=True)
torch.cuda.cudart().cudaProfilerStop()