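# Microbenchmark for concatenating the MLA "nope" and "rope" key parts into a
# single k tensor, comparing eager PyTorch, torch.compile, Triton, a flattened
# copy hack, and the sgl_kernel CUDA op (concat_mla_k).
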
import torch
import triton
import triton.language as tl
from sgl_kernel import concat_mla_k as concat_mla_k_cuda

DEVICE = triton.runtime.driver.active.get_active_torch_device()

num_local_heads = 128
qk_nope_head_dim = 128
qk_rope_head_dim = 64
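

# create_data builds the benchmark inputs. k_nope and k_rope are non-contiguous
# slices of larger containers, so every implementation has to handle strided
# inputs; k_rope has a single head that is broadcast across all num_local_heads
# heads of k.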
def create_data(num_tokens):
    k_nope_container = torch.randn(
        (num_tokens, num_local_heads, qk_nope_head_dim + 128),
        dtype=torch.bfloat16,
        device="cuda",
    )
    k_nope = k_nope_container[:, :, :qk_nope_head_dim]

    k_rope_container = torch.randn(
        (num_tokens, 1, 128 + qk_rope_head_dim), dtype=torch.bfloat16, device="cuda"
    )
    k_rope = k_rope_container[:, :, -qk_rope_head_dim:]

    k = torch.empty(
        (num_tokens, num_local_heads, qk_nope_head_dim + qk_rope_head_dim),
        dtype=torch.bfloat16,
        device="cuda",
    )
    return dict(k=k, k_nope=k_nope, k_rope=k_rope)
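

# Eager PyTorch reference: two broadcasting slice assignments into k.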
def fn_torch(k, k_nope, k_rope):
    k[..., :qk_nope_head_dim] = k_nope
    k[..., qk_nope_head_dim:] = k_rope
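

# Copies through a flattened view of k so the stores are contiguous rather than
# strided. The resulting layout does NOT match fn_torch (this variant is excluded
# from the correctness check below); presumably it only serves as a rough lower
# bound on the copy time.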
def fn_hack_non_strided(k, k_nope, k_rope):
    k_flatten_view = k.flatten()
    k_flatten_view[: k_nope.numel()] = k_nope.flatten()

    k2 = k_flatten_view[k_nope.numel() :].view(k_rope.numel(), -1)
    # Write into the view so the rope values actually land in the tail of k,
    # broadcast along the second dimension of the reshaped tail.
    k2[:] = k_rope.flatten()[:, None]


@torch.compile(dynamic=True)
def fn_torch_compiled(k, k_nope, k_rope):
    return fn_torch(k, k_nope, k_rope)


def fn_cuda(k, k_nope, k_rope):
    concat_mla_k_cuda(k, k_nope, k_rope)
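

# Triton version: each program handles BLOCK_ROWS tokens. For every (token, head)
# it copies the nope slice from k_nope and the single-head rope slice from k_rope
# (broadcast across heads) into the corresponding row of k.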
@triton.jit
def fn_triton_kernel(
    k_ptr,
    k_nope_ptr,
    k_rope_ptr,
    num_tokens,
    QK_NOPE_HEAD_DIM: tl.constexpr,
    QK_ROPE_HEAD_DIM: tl.constexpr,
    NUM_LOCAL_HEADS: tl.constexpr,
    K_NOPE_STRIDE_0: tl.constexpr,
    K_NOPE_STRIDE_1: tl.constexpr,
    K_STRIDE_0: tl.constexpr,
    K_STRIDE_1: tl.constexpr,
    K_ROPE_STRIDE_0: tl.constexpr,
    BLOCK_ROWS: tl.constexpr,
):
    pid = tl.program_id(axis=0)

    token_id = pid * BLOCK_ROWS + tl.arange(0, BLOCK_ROWS)
    token_mask = token_id < num_tokens

    head_id = tl.arange(0, NUM_LOCAL_HEADS)

    # nope
    nope_sub_id = tl.arange(0, QK_NOPE_HEAD_DIM)
    offs_nope = (
        token_id[:, None, None] * K_NOPE_STRIDE_0
        + head_id[None, :, None] * K_NOPE_STRIDE_1
        + nope_sub_id[None, None, :]
    )
    offs_k = (
        token_id[:, None, None] * K_STRIDE_0
        + head_id[None, :, None] * K_STRIDE_1
        + nope_sub_id[None, None, :]
    )
    vals_nope = tl.load(k_nope_ptr + offs_nope, mask=token_mask[:, None, None])
    tl.store(k_ptr + offs_k, vals_nope, mask=token_mask[:, None, None])

    # rope
    rope_sub_id = tl.arange(0, QK_ROPE_HEAD_DIM)
    offs_rope = token_id[:, None, None] * K_ROPE_STRIDE_0 + rope_sub_id[None, None, :]
    offs_k = (
        token_id[:, None, None] * K_STRIDE_0
        + head_id[None, :, None] * K_STRIDE_1
        + rope_sub_id[None, None, :]
        + QK_NOPE_HEAD_DIM
    )
    vals_rope = tl.load(k_rope_ptr + offs_rope, mask=token_mask[:, None, None])
    tl.store(k_ptr + offs_k, vals_rope, mask=token_mask[:, None, None])
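

# The strides are passed as tl.constexpr, so Triton specializes (and recompiles)
# the kernel for each distinct stride combination.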
def fn_triton(k, k_nope, k_rope):
    assert k.device == DEVICE and k_nope.device == DEVICE and k_rope.device == DEVICE
    num_tokens, _, _ = k.shape
    grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_ROWS"]),)
    fn_triton_kernel[grid](
        k,
        k_nope,
        k_rope,
        num_tokens,
        QK_NOPE_HEAD_DIM=qk_nope_head_dim,
        QK_ROPE_HEAD_DIM=qk_rope_head_dim,
        NUM_LOCAL_HEADS=num_local_heads,
        K_NOPE_STRIDE_0=k_nope.stride(0),
        K_NOPE_STRIDE_1=k_nope.stride(1),
        K_STRIDE_0=k.stride(0),
        K_STRIDE_1=k.stride(1),
        K_ROPE_STRIDE_0=k_rope.stride(0),
        BLOCK_ROWS=16,
    )
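

# Zero the output, run one implementation, sanity-check that it wrote something,
# and return a copy of the result for comparison.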
def execute_and_get_output(f, data):
    data["k"].zero_()
    f(**data)
    assert data["k"].sum().item() != 0
    return data["k"].clone()
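

# Correctness check: the CUDA kernel must match the eager PyTorch reference
# exactly (these are pure copies, so bit-for-bit equality is expected).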
torch.manual_seed(0)
data = create_data(num_tokens=32768)
output_ref = execute_and_get_output(fn_torch, data)
output_exp = execute_and_get_output(fn_cuda, data)
# print(output_ref)
# print(output_exp)
if not torch.all(output_ref == output_exp):
    abs_delta = torch.abs(output_ref - output_exp)
    raise AssertionError(
        f"{output_ref=} {output_exp=} "
        f"{abs_delta=} "
        f"{torch.argwhere(abs_delta != 0.0)=} "
    )
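

# Benchmark sweep over token counts; do_bench reports the median and the
# 20th/80th-percentile run times for each provider.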
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],  # Argument names to use as an x-axis for the plot.
        x_vals=[
            2048,
            4096,
            8192,
            16384,
            32768,
        ],  # Different possible values for `x_name`.
        x_log=False,  # x axis is linear.
        line_arg="provider",  # Argument name whose value corresponds to a different line in the plot.
        line_vals=[
            "torch",
            "torch_compiled",
            "triton",
            "hack_non_strided",
            "cuda",
        ],  # Possible values for `line_arg`.
        line_names=[
            "torch",
            "torch_compiled",
            "triton",
            "hack_non_strided",
            "cuda",
        ],  # Label name for the lines.
        plot_name="concat-mla-k-performance",  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    )
)
def benchmark(num_tokens, provider):
    data = create_data(num_tokens=num_tokens)
    quantiles = [0.5, 0.2, 0.8]
    fn = {
        "torch": fn_torch,
        "torch_compiled": fn_torch_compiled,
        "triton": fn_triton,
        "hack_non_strided": fn_hack_non_strided,
        "cuda": fn_cuda,
    }[provider]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(**data), quantiles=quantiles
    )
    return ms, min_ms, max_ms
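

# cudaProfilerStart/Stop bracket the run so an external profiler (e.g. nsys with
# --capture-range=cudaProfilerApi) captures only the benchmark region.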
torch.cuda.cudart().cudaProfilerStart()
benchmark.run(print_data=True, show_plots=True)
torch.cuda.cudart().cudaProfilerStop()