adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
299
sgl-kernel/benchmark/bench_lightning_attention_decode.py
Normal file
299
sgl-kernel/benchmark/bench_lightning_attention_decode.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import lightning_attention_decode
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
return 2 ** (int(math.ceil(math.log(n, 2))))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _decode_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
KV,
|
||||
Out,
|
||||
S,
|
||||
b: tl.constexpr,
|
||||
h: tl.constexpr,
|
||||
n: tl.constexpr,
|
||||
d: tl.constexpr,
|
||||
d_original: tl.constexpr,
|
||||
e: tl.constexpr,
|
||||
e_original: tl.constexpr,
|
||||
):
|
||||
off_bh = tl.program_id(0)
|
||||
off_h = off_bh % h
|
||||
|
||||
qk_offset = off_bh * n * d
|
||||
v_offset = off_bh * n * e
|
||||
o_offset = off_bh * n * e
|
||||
kv_offset = off_bh * d * e
|
||||
|
||||
s = tl.load(S + off_h)
|
||||
ratio = tl.exp(-s)
|
||||
|
||||
d_idx = tl.arange(0, d)
|
||||
e_idx = tl.arange(0, e)
|
||||
|
||||
# Create masks for original dimensions
|
||||
d_mask = d_idx < d_original
|
||||
e_mask = e_idx < e_original
|
||||
|
||||
# Load with masking
|
||||
q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
|
||||
k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
|
||||
v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
|
||||
|
||||
# Load KV with 2D masking
|
||||
kv = tl.load(
|
||||
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
|
||||
mask=(d_mask[:, None] & e_mask[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
# Compute outer product using element-wise operations
|
||||
k_v_prod = k[:, None] * v[None, :]
|
||||
kv = ratio * kv + k_v_prod
|
||||
|
||||
# Store KV with 2D masking
|
||||
tl.store(
|
||||
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
|
||||
kv.to(KV.dtype.element_ty),
|
||||
mask=(d_mask[:, None] & e_mask[None, :]),
|
||||
)
|
||||
|
||||
# Compute matrix-vector multiplication using element-wise operations and reduction
|
||||
o = tl.sum(q[:, None] * kv, axis=0)
|
||||
|
||||
# Store output with masking
|
||||
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
|
||||
|
||||
|
||||
def triton_lightning_attn_decode(q, k, v, kv, s):
|
||||
"""Triton implementation of Lightning Attention decode operation"""
|
||||
b, h, n, d = q.shape
|
||||
e = v.shape[-1]
|
||||
assert n == 1, "Sequence length must be 1 in decode mode"
|
||||
|
||||
# Get padded dimensions (power of 2)
|
||||
d_padded = next_power_of_2(d)
|
||||
e_padded = next_power_of_2(e)
|
||||
|
||||
# Create output tensor (padded)
|
||||
o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
|
||||
|
||||
# Create padded tensors without actually padding the data
|
||||
q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
|
||||
k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
|
||||
v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
|
||||
kv_padded = torch.empty(
|
||||
b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
|
||||
)
|
||||
|
||||
# Copy data to padded tensors
|
||||
q_padded[..., :d] = q
|
||||
k_padded[..., :d] = k
|
||||
v_padded[..., :e] = v
|
||||
kv_padded[..., :d, :e] = kv
|
||||
|
||||
# Launch kernel
|
||||
grid = (b * h, 1)
|
||||
_decode_kernel[grid](
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
kv_padded,
|
||||
o_padded,
|
||||
s,
|
||||
b=b,
|
||||
h=h,
|
||||
n=n,
|
||||
d=d_padded,
|
||||
d_original=d,
|
||||
e=e_padded,
|
||||
e_original=e,
|
||||
)
|
||||
|
||||
# Get unpadded outputs
|
||||
o = o_padded[..., :e]
|
||||
kv_out = kv_padded[..., :d, :e]
|
||||
|
||||
return o, kv_out
|
||||
|
||||
|
||||
def lightning_attention_decode_naive(q, k, v, past_kv, slope):
|
||||
"""Naive implementation of lightning attention decode"""
|
||||
original_dtype = q.dtype
|
||||
ratio = torch.exp(-slope) # [h, 1, 1]
|
||||
|
||||
kv = past_kv
|
||||
b, h, n, d = q.shape
|
||||
|
||||
output = []
|
||||
for i in range(n):
|
||||
kv = ratio * kv.to(torch.float32) + torch.einsum(
|
||||
"... n d, ... n e -> ... d e",
|
||||
k[:, :, i : i + 1],
|
||||
v[:, :, i : i + 1],
|
||||
)
|
||||
qkv = torch.einsum(
|
||||
"... n e, ... e d -> ... n d",
|
||||
q[:, :, i : i + 1].to(torch.float32),
|
||||
kv.to(torch.float32),
|
||||
)
|
||||
output.append(qkv)
|
||||
output = torch.cat(output, dim=-2)
|
||||
|
||||
return output.to(original_dtype), kv
|
||||
|
||||
|
||||
def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv):
|
||||
return lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
|
||||
|
||||
|
||||
def calculate_diff(batch_size):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
num_heads = 64
|
||||
head_dim = 96
|
||||
seq_len = 1
|
||||
|
||||
q = torch.randn(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
k = torch.randn(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
v = torch.randn(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
|
||||
slope = torch.randn(num_heads, 1, 1, device=device)
|
||||
|
||||
output_naive, new_kv_naive = lightning_attention_decode_naive(
|
||||
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
||||
)
|
||||
|
||||
output_kernel = torch.empty_like(output_naive)
|
||||
new_kv_kernel = torch.empty_like(new_kv_naive)
|
||||
lightning_attention_decode_kernel(
|
||||
q.clone(),
|
||||
k.clone(),
|
||||
v.clone(),
|
||||
past_kv.clone(),
|
||||
slope.clone(),
|
||||
output_kernel,
|
||||
new_kv_kernel,
|
||||
)
|
||||
|
||||
output_triton, new_kv_triton = triton_lightning_attn_decode(
|
||||
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
||||
)
|
||||
|
||||
if (
|
||||
torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(output_naive, output_triton, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(new_kv_naive, new_kv_kernel, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(new_kv_naive, new_kv_triton, atol=1e-2, rtol=1e-2)
|
||||
):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [i for i in range(1, 65)] # 1 to 128
|
||||
configs = [(bs,) for bs in batch_size_range]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["naive", "kernel", "triton"],
|
||||
line_names=["PyTorch Naive", "SGL Kernel", "Triton"],
|
||||
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="lightning-attention-decode-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
num_heads = 64
|
||||
head_dim = 96
|
||||
seq_len = 1
|
||||
|
||||
q = torch.randn(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
k = torch.randn(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
v = torch.randn(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
|
||||
slope = torch.randn(num_heads, 1, 1, device=device)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "naive":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: lightning_attention_decode_naive(
|
||||
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "kernel":
|
||||
output = torch.empty(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: lightning_attention_decode_kernel(
|
||||
q.clone(),
|
||||
k.clone(),
|
||||
v.clone(),
|
||||
past_kv.clone(),
|
||||
slope.clone(),
|
||||
output,
|
||||
new_kv,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: triton_lightning_attn_decode(
|
||||
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/lightning_attention_decode_sgl/",
|
||||
help="Path to save lightning attention decode benchmark results",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run correctness test
|
||||
calculate_diff(batch_size=4)
|
||||
|
||||
# Run performance benchmark
|
||||
benchmark.run(print_data=True)
|
||||
Reference in New Issue
Block a user