From ac2dc35d0e529a278450bceb4d234aae3a1c93d8 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Thu, 23 Jan 2025 15:29:20 +0800 Subject: [PATCH] support lightning_attention_decode in sgl-kernel for MiniMax-Text-01 (#3030) --- .../benchmark_lightning_attention_decode.py | 77 ++++- .../bench_lightning_attention_decode.py | 299 ++++++++++++++++++ sgl-kernel/setup.py | 1 + sgl-kernel/src/sgl-kernel/__init__.py | 2 + .../csrc/lightning_attention_decode_kernel.cu | 119 +++++++ .../src/sgl-kernel/csrc/sgl_kernel_ops.cu | 7 + sgl-kernel/src/sgl-kernel/ops/__init__.py | 7 + .../tests/test_lightning_attention_decode.py | 84 +++++ 8 files changed, 588 insertions(+), 8 deletions(-) create mode 100644 sgl-kernel/benchmark/bench_lightning_attention_decode.py create mode 100644 sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu create mode 100644 sgl-kernel/tests/test_lightning_attention_decode.py diff --git a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py index a2d1e10f6..57fbcfddf 100644 --- a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py +++ b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py @@ -9,6 +9,7 @@ import torch.nn.functional as F import triton import triton.language as tl from einops import rearrange +from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode @triton.jit @@ -332,7 +333,6 @@ def test_lightning_attention_implementations(model_params): model_params["num_attention_heads"], d, d, - dtype=dtype, device=device, ) with torch.no_grad(): @@ -350,7 +350,13 @@ def test_lightning_attention_implementations(model_params): q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + past_kv = past_kv.contiguous() + slope_rate = slope_rate.contiguous() + # Test Triton implementation triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate) triton_output = triton_output.transpose(1, 2).contiguous() triton_output = triton_output.view(batch_size, seq_len, -1) @@ -358,22 +364,50 @@ def test_lightning_attention_implementations(model_params): triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output triton_output = model_attn.out_proj(triton_output) + # Test SGL implementation + sgl_output = torch.empty_like(v) + sgl_new_kv = torch.empty_like(past_kv) + sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv) + + sgl_output = sgl_output.transpose(1, 2).contiguous() + sgl_output = sgl_output.view(batch_size, seq_len, -1) + sgl_output = model_attn.norm(sgl_output) + sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output + sgl_output = model_attn.out_proj(sgl_output) + + # Verify Triton implementation results torch.testing.assert_close( model_output, triton_output, rtol=1e-3, atol=1e-2, - msg="Lightning attention implementations produce different output results", + msg="Triton lightning attention implementation produces different output results", ) torch.testing.assert_close( new_kv, triton_new_kv, rtol=1e-3, atol=1e-2, - msg="Lightning attention implementations produce different kv results", + msg="Triton lightning attention implementation produces different kv results", ) - print("✅ Two implementations match") + # Verify SGL implementation results + torch.testing.assert_close( + model_output, + sgl_output, + rtol=1e-3, + atol=1e-2, + msg="SGL lightning attention implementation produces different output results", + ) + torch.testing.assert_close( + new_kv, + sgl_new_kv, + rtol=1e-3, + atol=1e-2, + msg="SGL lightning attention implementation produces different kv results", + ) + + print("✅ All implementations match") def _build_slope_tensor(n_attention_heads: int): @@ -408,12 +442,13 @@ def get_benchmark(): x_names=["batch_size", "seq_len"], x_vals=[list(_) for _ in configs], line_arg="provider", - line_vals=["Original", "Triton"], + line_vals=["Original", "Triton", "SGL"], line_names=[ "Original PyTorch Implementation", "Triton Implementation", + "SGL Implementation", ], - styles=[("blue", "-"), ("green", "-")], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], ylabel="us", plot_name="lightning-attention-decode-performance", args={}, @@ -446,7 +481,6 @@ def get_benchmark(): params["num_attention_heads"], d, d, - dtype=dtype, device=device, ) @@ -461,7 +495,7 @@ def get_benchmark(): ), quantiles=quantiles, ) - else: + elif provider == "Triton": def run_triton(): qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) @@ -483,6 +517,33 @@ def get_benchmark(): run_triton, quantiles=quantiles, ) + else: # SGL + + def run_sgl(): + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1) + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + + output = torch.empty_like(v) + new_kv = torch.empty_like(past_kv) + sgl_lightning_attention_decode( + q, k, v, past_kv, slope_rate, output, new_kv + ) + + output = output.transpose(1, 2).contiguous() + output = output.view(batch_size, seq_len, -1) + output = model_attn.norm(output) + output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output + return model_attn.out_proj(output) + + ms, min_ms, max_ms = triton.testing.do_bench( + run_sgl, + quantiles=quantiles, + ) return 1000 * ms, 1000 * max_ms, 1000 * min_ms diff --git a/sgl-kernel/benchmark/bench_lightning_attention_decode.py b/sgl-kernel/benchmark/bench_lightning_attention_decode.py new file mode 100644 index 000000000..24872e61a --- /dev/null +++ b/sgl-kernel/benchmark/bench_lightning_attention_decode.py @@ -0,0 +1,299 @@ +import itertools +import math + +import torch +import triton +import triton.language as tl +from sgl_kernel import lightning_attention_decode + + +def next_power_of_2(n): + return 2 ** (int(math.ceil(math.log(n, 2)))) + + +@triton.jit +def _decode_kernel( + Q, + K, + V, + KV, + Out, + S, + b: tl.constexpr, + h: tl.constexpr, + n: tl.constexpr, + d: tl.constexpr, + d_original: tl.constexpr, + e: tl.constexpr, + e_original: tl.constexpr, +): + off_bh = tl.program_id(0) + off_h = off_bh % h + + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + kv_offset = off_bh * d * e + + s = tl.load(S + off_h) + ratio = tl.exp(-s) + + d_idx = tl.arange(0, d) + e_idx = tl.arange(0, e) + + # Create masks for original dimensions + d_mask = d_idx < d_original + e_mask = e_idx < e_original + + # Load with masking + q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0) + k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0) + v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0) + + # Load KV with 2D masking + kv = tl.load( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + mask=(d_mask[:, None] & e_mask[None, :]), + other=0.0, + ) + + # Compute outer product using element-wise operations + k_v_prod = k[:, None] * v[None, :] + kv = ratio * kv + k_v_prod + + # Store KV with 2D masking + tl.store( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + kv.to(KV.dtype.element_ty), + mask=(d_mask[:, None] & e_mask[None, :]), + ) + + # Compute matrix-vector multiplication using element-wise operations and reduction + o = tl.sum(q[:, None] * kv, axis=0) + + # Store output with masking + tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask) + + +def triton_lightning_attn_decode(q, k, v, kv, s): + """Triton implementation of Lightning Attention decode operation""" + b, h, n, d = q.shape + e = v.shape[-1] + assert n == 1, "Sequence length must be 1 in decode mode" + + # Get padded dimensions (power of 2) + d_padded = next_power_of_2(d) + e_padded = next_power_of_2(e) + + # Create output tensor (padded) + o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + + # Create padded tensors without actually padding the data + q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device) + k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device) + v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + kv_padded = torch.empty( + b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device + ) + + # Copy data to padded tensors + q_padded[..., :d] = q + k_padded[..., :d] = k + v_padded[..., :e] = v + kv_padded[..., :d, :e] = kv + + # Launch kernel + grid = (b * h, 1) + _decode_kernel[grid]( + q_padded, + k_padded, + v_padded, + kv_padded, + o_padded, + s, + b=b, + h=h, + n=n, + d=d_padded, + d_original=d, + e=e_padded, + e_original=e, + ) + + # Get unpadded outputs + o = o_padded[..., :e] + kv_out = kv_padded[..., :d, :e] + + return o, kv_out + + +def lightning_attention_decode_naive(q, k, v, past_kv, slope): + """Naive implementation of lightning attention decode""" + original_dtype = q.dtype + ratio = torch.exp(-slope) # [h, 1, 1] + + kv = past_kv + b, h, n, d = q.shape + + output = [] + for i in range(n): + kv = ratio * kv.to(torch.float32) + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + qkv = torch.einsum( + "... n e, ... e d -> ... n d", + q[:, :, i : i + 1].to(torch.float32), + kv.to(torch.float32), + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + + return output.to(original_dtype), kv + + +def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv): + return lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + + +def calculate_diff(batch_size): + dtype = torch.bfloat16 + device = torch.device("cuda") + num_heads = 64 + head_dim = 96 + seq_len = 1 + + q = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + k = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + output_naive, new_kv_naive = lightning_attention_decode_naive( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ) + + output_kernel = torch.empty_like(output_naive) + new_kv_kernel = torch.empty_like(new_kv_naive) + lightning_attention_decode_kernel( + q.clone(), + k.clone(), + v.clone(), + past_kv.clone(), + slope.clone(), + output_kernel, + new_kv_kernel, + ) + + output_triton, new_kv_triton = triton_lightning_attn_decode( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ) + + if ( + torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2) + and torch.allclose(output_naive, output_triton, atol=1e-2, rtol=1e-2) + and torch.allclose(new_kv_naive, new_kv_kernel, atol=1e-2, rtol=1e-2) + and torch.allclose(new_kv_naive, new_kv_triton, atol=1e-2, rtol=1e-2) + ): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [i for i in range(1, 65)] # 1 to 128 +configs = [(bs,) for bs in batch_size_range] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["naive", "kernel", "triton"], + line_names=["PyTorch Naive", "SGL Kernel", "Triton"], + styles=[("blue", "-"), ("red", "-"), ("green", "-")], + ylabel="us", + plot_name="lightning-attention-decode-performance", + args={}, + ) +) +def benchmark(batch_size, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + num_heads = 64 + head_dim = 96 + seq_len = 1 + + q = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + k = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "naive": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: lightning_attention_decode_naive( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ), + quantiles=quantiles, + ) + elif provider == "kernel": + output = torch.empty( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: lightning_attention_decode_kernel( + q.clone(), + k.clone(), + v.clone(), + past_kv.clone(), + slope.clone(), + output, + new_kv, + ), + quantiles=quantiles, + ) + elif provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_lightning_attn_decode( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/lightning_attention_decode_sgl/", + help="Path to save lightning attention decode benchmark results", + ) + args = parser.parse_args() + + # Run correctness test + calculate_diff(batch_size=4) + + # Run performance benchmark + benchmark.run(print_data=True) diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 81cd96e99..9a2324b60 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -100,6 +100,7 @@ ext_modules = [ "src/sgl-kernel/csrc/moe_align_kernel.cu", "src/sgl-kernel/csrc/int8_gemm_kernel.cu", "src/sgl-kernel/csrc/sampling_scaling_penalties.cu", + "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", "src/sgl-kernel/csrc/sgl_kernel_ops.cu", "src/sgl-kernel/csrc/rotary_embedding.cu", "3rdparty/flashinfer/csrc/activation.cu", diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index 86c4f34d3..9eaa64e50 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -10,6 +10,7 @@ from sgl_kernel.ops import ( get_graph_buffer_ipc_meta, init_custom_reduce, int8_scaled_mm, + lightning_attention_decode, moe_align_block_size, register_graph_buffers, rmsnorm, @@ -35,5 +36,6 @@ __all__ = [ "rmsnorm", "rotary_embedding", "sampling_scaling_penalties", + "lightning_attention_decode", "silu_and_mul", ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu new file mode 100644 index 000000000..eb79373b2 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu @@ -0,0 +1,119 @@ +#include +#include +#include +#include +#include + +#include "utils.h" + +#define THREADS_PER_BLOCK 128 + +template +__global__ void lightning_attention_decode_kernel(const T* __restrict__ q, // [b, h, 1, d] + const T* __restrict__ k, // [b, h, 1, d] + const T* __restrict__ v, // [b, h, 1, e] + const float* __restrict__ past_kv, // [b, h, d, e] + const float* __restrict__ slope, // [h, 1, 1] + T* __restrict__ output, // [b, h, 1, e] + float* __restrict__ new_kv, // [b, h, d, e] + const int batch_size, const int num_heads, const int qk_dim, + const int v_dim) { + extern __shared__ char smem[]; + T* q_shared = reinterpret_cast(smem); + T* k_shared = reinterpret_cast(smem + qk_dim * sizeof(T)); + T* v_shared = reinterpret_cast(smem + 2 * qk_dim * sizeof(T)); + float* new_kv_shared = reinterpret_cast(smem + (2 * qk_dim + v_dim) * sizeof(T)); + T* output_shared = + reinterpret_cast(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float)); + + const int32_t tid = threadIdx.x; + const int32_t current_head = blockIdx.x; + const int32_t b = current_head / num_heads; + const int32_t h = current_head % num_heads; + + if (b >= batch_size) return; + + const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim; + const int32_t v_offset = b * num_heads * v_dim + h * v_dim; + const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim; + + for (int d = tid; d < qk_dim; d += blockDim.x) { + q_shared[d] = q[qk_offset + d]; + k_shared[d] = k[qk_offset + d]; + } + for (int e = tid; e < v_dim; e += blockDim.x) { + v_shared[e] = v[v_offset + e]; + } + + __syncthreads(); + + const float ratio = expf(-1.0f * slope[h]); + + for (int d = tid; d < qk_dim; d += blockDim.x) { + T k_val = k_shared[d]; + for (int e = 0; e < v_dim; ++e) { + int past_kv_idx = kv_offset + d * v_dim + e; + T v_val = v_shared[e]; + float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val; + int shared_idx = d * (v_dim + 1) + e; + new_kv_shared[shared_idx] = new_val; + } + } + + __syncthreads(); + + for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) { + int d = idx / v_dim; + int e = idx % v_dim; + int shared_idx = d * (v_dim + 1) + e; + int global_idx = kv_offset + idx; + new_kv[global_idx] = new_kv_shared[shared_idx]; + } + + __syncthreads(); + + for (int e = tid; e < v_dim; e += blockDim.x) { + float sum = 0.0f; + for (int d = 0; d < qk_dim; ++d) { + int shared_idx = d * (v_dim + 1) + e; + sum += q_shared[d] * new_kv_shared[shared_idx]; + } + output_shared[e] = static_cast(sum); + } + + __syncthreads(); + + if (tid == 0) { + for (int e = 0; e < v_dim; ++e) { + output[v_offset + e] = output_shared[e]; + } + } +} + +void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, + const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, + torch::Tensor new_kv) { + TORCH_CHECK(q.is_contiguous(), "q must be contiguous"); + TORCH_CHECK(k.is_contiguous(), "k must be contiguous"); + TORCH_CHECK(v.is_contiguous(), "v must be contiguous"); + TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous"); + + auto batch_size = q.size(0); + auto num_heads = q.size(1); + auto qk_dim = q.size(3); + auto v_dim = v.size(3); + + dim3 block(THREADS_PER_BLOCK); + dim3 grid(batch_size * num_heads); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] { + size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float); + lightning_attention_decode_kernel<<>>( + q.data_ptr(), k.data_ptr(), v.data_ptr(), past_kv.data_ptr(), + slope.data_ptr(), output.data_ptr(), new_kv.data_ptr(), batch_size, num_heads, + qk_dim, v_dim); + })); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu index 12df07471..cd5df0789 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu @@ -26,6 +26,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma const torch::Tensor& scales_b, const torch::Dtype& out_dtype, const c10::optional& bias); +// lightning_attention_decode +void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, + const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, + torch::Tensor new_kv); + // rotary embedding void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); @@ -69,6 +74,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)"); // int8_scaled_mm m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)"); + // lightning_attention_decode + m.def("lightning_attention_decode", &lightning_attention_decode, "Lightning Attention Ddecode (CUDA)"); // rotary embedding m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)"); // rms norm diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index d90f121d4..0aead260b 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -14,6 +14,9 @@ from sgl_kernel.ops._kernels import ( ) from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm +from sgl_kernel.ops._kernels import ( + lightning_attention_decode as _lightning_attention_decode, +) from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm @@ -86,6 +89,10 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): ) +def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): + _lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + + def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox): return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox) diff --git a/sgl-kernel/tests/test_lightning_attention_decode.py b/sgl-kernel/tests/test_lightning_attention_decode.py new file mode 100644 index 000000000..74af78e27 --- /dev/null +++ b/sgl-kernel/tests/test_lightning_attention_decode.py @@ -0,0 +1,84 @@ +import pytest +import torch +from sgl_kernel import lightning_attention_decode + + +def naive_lightning_attention_decode(q, k, v, past_kv, slope): + """Naive implementation of lightning attention decode""" + original_dtype = q.dtype + ratio = torch.exp(-slope) # [h, 1, 1] + + kv = past_kv + b, h, n, d = q.shape + + output = [] + for i in range(n): + kv = ratio * kv.to(torch.float32) + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + qkv = torch.einsum( + "... n e, ... e d -> ... n d", + q[:, :, i : i + 1].to(torch.float32), + kv.to(torch.float32), + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + + return output.to(original_dtype), kv + + +configs = [ + # (batch_size, num_heads, dim, embed_dim) + (1, 8, 64, 64), + (2, 8, 64, 64), + (1, 32, 32, 64), + (2, 32, 32, 64), + (4, 32, 64, 64), + (4, 32, 64, 64), + (16, 64, 96, 96), + (64, 64, 96, 96), +] + +dtypes = [torch.float32, torch.float16, torch.bfloat16] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("batch_size,num_heads,dim,embed_dim", configs) +def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim): + device = torch.device("cuda") + + q = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype) + k = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype) + v = torch.randn(batch_size, num_heads, 1, embed_dim, device=device, dtype=dtype) + past_kv = torch.randn(batch_size, num_heads, dim, embed_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + ref_output, ref_new_kv = naive_lightning_attention_decode(q, k, v, past_kv, slope) + + output = torch.empty_like(ref_output) + new_kv = torch.empty_like(ref_new_kv) + lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + + rtol = 1e-2 + atol = 1e-2 + + torch.testing.assert_close( + output, + ref_output, + rtol=rtol, + atol=atol, + msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, " + f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}", + ) + + torch.testing.assert_close( + new_kv, + ref_new_kv, + rtol=rtol, + atol=atol, + msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, " + f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}", + )