support lightning_attention_decode in sgl-kernel for MiniMax-Text-01 (#3030)
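A minimal usage sketch of the new op (a sketch based on the unit test added below, not part of the commit; shapes and dtypes follow that test, and the kernel writes into caller-allocated output/new_kv buffers instead of returning tensors):

    import torch
    from sgl_kernel import lightning_attention_decode

    b, h, d, e = 2, 8, 64, 64
    q = torch.randn(b, h, 1, d, device="cuda", dtype=torch.float16)  # decode step: seq_len == 1
    k = torch.randn(b, h, 1, d, device="cuda", dtype=torch.float16)
    v = torch.randn(b, h, 1, e, device="cuda", dtype=torch.float16)
    past_kv = torch.randn(b, h, d, e, device="cuda")   # float32 recurrent state
    slope = torch.randn(h, 1, 1, device="cuda")        # float32 per-head decay slopes
    output = torch.empty_like(v)
    new_kv = torch.empty_like(past_kv)
    lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)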
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 import triton
 import triton.language as tl
 from einops import rearrange
+from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode


@triton.jit
@@ -332,7 +333,6 @@ def test_lightning_attention_implementations(model_params):
        model_params["num_attention_heads"],
        d,
        d,
        dtype=dtype,
        device=device,
    )
    with torch.no_grad():
@@ -350,7 +350,13 @@ def test_lightning_attention_implementations(model_params):
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    past_kv = past_kv.contiguous()
    slope_rate = slope_rate.contiguous()

    # Test Triton implementation
    triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
    triton_output = triton_output.transpose(1, 2).contiguous()
    triton_output = triton_output.view(batch_size, seq_len, -1)
@@ -358,22 +364,50 @@ def test_lightning_attention_implementations(model_params):
    triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
    triton_output = model_attn.out_proj(triton_output)

    # Test SGL implementation
    sgl_output = torch.empty_like(v)
    sgl_new_kv = torch.empty_like(past_kv)
    sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)

    sgl_output = sgl_output.transpose(1, 2).contiguous()
    sgl_output = sgl_output.view(batch_size, seq_len, -1)
    sgl_output = model_attn.norm(sgl_output)
    sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
    sgl_output = model_attn.out_proj(sgl_output)

    # Verify Triton implementation results
    torch.testing.assert_close(
        model_output,
        triton_output,
        rtol=1e-3,
        atol=1e-2,
-        msg="Lightning attention implementations produce different output results",
+        msg="Triton lightning attention implementation produces different output results",
    )
    torch.testing.assert_close(
        new_kv,
        triton_new_kv,
        rtol=1e-3,
        atol=1e-2,
-        msg="Lightning attention implementations produce different kv results",
+        msg="Triton lightning attention implementation produces different kv results",
    )

-    print("✅ Two implementations match")
    # Verify SGL implementation results
    torch.testing.assert_close(
        model_output,
        sgl_output,
        rtol=1e-3,
        atol=1e-2,
        msg="SGL lightning attention implementation produces different output results",
    )
    torch.testing.assert_close(
        new_kv,
        sgl_new_kv,
        rtol=1e-3,
        atol=1e-2,
        msg="SGL lightning attention implementation produces different kv results",
    )

+    print("✅ All implementations match")


def _build_slope_tensor(n_attention_heads: int):
@@ -408,12 +442,13 @@ def get_benchmark():
        x_names=["batch_size", "seq_len"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
-        line_vals=["Original", "Triton"],
+        line_vals=["Original", "Triton", "SGL"],
        line_names=[
            "Original PyTorch Implementation",
            "Triton Implementation",
            "SGL Implementation",
        ],
-        styles=[("blue", "-"), ("green", "-")],
+        styles=[("blue", "-"), ("green", "-"), ("red", "-")],
        ylabel="us",
        plot_name="lightning-attention-decode-performance",
        args={},
@@ -446,7 +481,6 @@ def get_benchmark():
            params["num_attention_heads"],
            d,
            d,
            dtype=dtype,
            device=device,
        )

@@ -461,7 +495,7 @@ def get_benchmark():
                ),
                quantiles=quantiles,
            )
-        else:
+        elif provider == "Triton":

            def run_triton():
                qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
@@ -483,6 +517,33 @@ def get_benchmark():
                run_triton,
                quantiles=quantiles,
            )
        else:  # SGL

            def run_sgl():
                qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
                new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
                qkv = qkv.view(*new_shape)
                q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
                q = q.transpose(1, 2).contiguous()
                k = k.transpose(1, 2).contiguous()
                v = v.transpose(1, 2).contiguous()

                output = torch.empty_like(v)
                new_kv = torch.empty_like(past_kv)
                sgl_lightning_attention_decode(
                    q, k, v, past_kv, slope_rate, output, new_kv
                )

                output = output.transpose(1, 2).contiguous()
                output = output.view(batch_size, seq_len, -1)
                output = model_attn.norm(output)
                output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
                return model_attn.out_proj(output)

            ms, min_ms, max_ms = triton.testing.do_bench(
                run_sgl,
                quantiles=quantiles,
            )

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms

sgl-kernel/benchmark/bench_lightning_attention_decode.py (new file, 299 lines)
@@ -0,0 +1,299 @@
import itertools
import math

import torch
import triton
import triton.language as tl
from sgl_kernel import lightning_attention_decode


def next_power_of_2(n):
    return 2 ** (int(math.ceil(math.log(n, 2))))
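# Illustration (not part of the commit): the benchmark below uses head_dim = 96,
# and next_power_of_2(96) == 128, so the Triton path pads q/k/v/kv from 96 to
# 128 columns before the launch; tl.arange in the kernel requires power-of-2
# extents, which is why this padding exists at all.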


@triton.jit
def _decode_kernel(
    Q,
    K,
    V,
    KV,
    Out,
    S,
    b: tl.constexpr,
    h: tl.constexpr,
    n: tl.constexpr,
    d: tl.constexpr,
    d_original: tl.constexpr,
    e: tl.constexpr,
    e_original: tl.constexpr,
):
    off_bh = tl.program_id(0)
    off_h = off_bh % h

    qk_offset = off_bh * n * d
    v_offset = off_bh * n * e
    o_offset = off_bh * n * e
    kv_offset = off_bh * d * e

    s = tl.load(S + off_h)
    ratio = tl.exp(-s)

    d_idx = tl.arange(0, d)
    e_idx = tl.arange(0, e)

    # Create masks for original dimensions
    d_mask = d_idx < d_original
    e_mask = e_idx < e_original

    # Load with masking
    q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
    k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
    v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)

    # Load KV with 2D masking
    kv = tl.load(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        mask=(d_mask[:, None] & e_mask[None, :]),
        other=0.0,
    )

    # Compute outer product using element-wise operations
    k_v_prod = k[:, None] * v[None, :]
    kv = ratio * kv + k_v_prod

    # Store KV with 2D masking
    tl.store(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        kv.to(KV.dtype.element_ty),
        mask=(d_mask[:, None] & e_mask[None, :]),
    )

    # Compute matrix-vector multiplication using element-wise operations and reduction
    o = tl.sum(q[:, None] * kv, axis=0)

    # Store output with masking
    tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
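# In PyTorch terms, each program instance above handles one (batch, head) pair
# and computes the single-step linear-attention decode recurrence
#     kv_new = exp(-s) * kv + k^T @ v    # [d, e] rank-1 state update
#     o      = q @ kv_new                # [1, e] attention output
# padded to power-of-2 tiles and masked back to the original d/e sizes; this
# matches lightning_attention_decode_naive below.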


def triton_lightning_attn_decode(q, k, v, kv, s):
    """Triton implementation of Lightning Attention decode operation"""
    b, h, n, d = q.shape
    e = v.shape[-1]
    assert n == 1, "Sequence length must be 1 in decode mode"

    # Get padded dimensions (power of 2)
    d_padded = next_power_of_2(d)
    e_padded = next_power_of_2(e)

    # Create output tensor (padded)
    o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)

    # Create padded tensors without actually padding the data
    q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
    k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
    v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
    kv_padded = torch.empty(
        b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
    )

    # Copy data to padded tensors
    q_padded[..., :d] = q
    k_padded[..., :d] = k
    v_padded[..., :e] = v
    kv_padded[..., :d, :e] = kv

    # Launch kernel
    grid = (b * h, 1)
    _decode_kernel[grid](
        q_padded,
        k_padded,
        v_padded,
        kv_padded,
        o_padded,
        s,
        b=b,
        h=h,
        n=n,
        d=d_padded,
        d_original=d,
        e=e_padded,
        e_original=e,
    )

    # Get unpadded outputs
    o = o_padded[..., :e]
    kv_out = kv_padded[..., :d, :e]

    return o, kv_out


def lightning_attention_decode_naive(q, k, v, past_kv, slope):
    """Naive implementation of lightning attention decode"""
    original_dtype = q.dtype
    ratio = torch.exp(-slope)  # [h, 1, 1]

    kv = past_kv
    b, h, n, d = q.shape

    output = []
    for i in range(n):
        kv = ratio * kv.to(torch.float32) + torch.einsum(
            "... n d, ... n e -> ... d e",
            k[:, :, i : i + 1],
            v[:, :, i : i + 1],
        )
        qkv = torch.einsum(
            "... n e, ... e d -> ... n d",
            q[:, :, i : i + 1].to(torch.float32),
            kv.to(torch.float32),
        )
        output.append(qkv)
    output = torch.concat(output, dim=-2)

    return output.to(original_dtype), kv


def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv):
    return lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)


def calculate_diff(batch_size):
    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_heads = 64
    head_dim = 96
    seq_len = 1

    q = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    k = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    v = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    output_naive, new_kv_naive = lightning_attention_decode_naive(
        q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
    )

    output_kernel = torch.empty_like(output_naive)
    new_kv_kernel = torch.empty_like(new_kv_naive)
    lightning_attention_decode_kernel(
        q.clone(),
        k.clone(),
        v.clone(),
        past_kv.clone(),
        slope.clone(),
        output_kernel,
        new_kv_kernel,
    )

    output_triton, new_kv_triton = triton_lightning_attn_decode(
        q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
    )

    if (
        torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2)
        and torch.allclose(output_naive, output_triton, atol=1e-2, rtol=1e-2)
        and torch.allclose(new_kv_naive, new_kv_kernel, atol=1e-2, rtol=1e-2)
        and torch.allclose(new_kv_naive, new_kv_triton, atol=1e-2, rtol=1e-2)
    ):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")

batch_size_range = [i for i in range(1, 65)]  # 1 to 64
configs = [(bs,) for bs in batch_size_range]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["naive", "kernel", "triton"],
        line_names=["PyTorch Naive", "SGL Kernel", "Triton"],
        styles=[("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="lightning-attention-decode-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_heads = 64
    head_dim = 96
    seq_len = 1

    q = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    k = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    v = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "naive":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: lightning_attention_decode_naive(
                q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
            ),
            quantiles=quantiles,
        )
    elif provider == "kernel":
        output = torch.empty(
            batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
        )
        new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: lightning_attention_decode_kernel(
                q.clone(),
                k.clone(),
                v.clone(),
                past_kv.clone(),
                slope.clone(),
                output,
                new_kv,
            ),
            quantiles=quantiles,
        )
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_lightning_attn_decode(
                q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
            ),
            quantiles=quantiles,
        )

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/lightning_attention_decode_sgl/",
        help="Path to save lightning attention decode benchmark results",
    )
    args = parser.parse_args()

    # Run correctness test
    calculate_diff(batch_size=4)

    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)

sgl-kernel/setup.py
@@ -100,6 +100,7 @@ ext_modules = [
        "src/sgl-kernel/csrc/moe_align_kernel.cu",
        "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
        "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
+        "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
        "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
        "src/sgl-kernel/csrc/rotary_embedding.cu",
        "3rdparty/flashinfer/csrc/activation.cu",

sgl-kernel/python/sgl_kernel/__init__.py
@@ -10,6 +10,7 @@ from sgl_kernel.ops import (
    get_graph_buffer_ipc_meta,
    init_custom_reduce,
    int8_scaled_mm,
+    lightning_attention_decode,
    moe_align_block_size,
    register_graph_buffers,
    rmsnorm,
@@ -35,5 +36,6 @@ __all__ = [
    "rmsnorm",
    "rotary_embedding",
    "sampling_scaling_penalties",
+    "lightning_attention_decode",
    "silu_and_mul",
]

sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu (new file, 119 lines)
@@ -0,0 +1,119 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "utils.h"

#define THREADS_PER_BLOCK 128

template <typename T>
__global__ void lightning_attention_decode_kernel(const T* __restrict__ q,  // [b, h, 1, d]
                                                  const T* __restrict__ k,  // [b, h, 1, d]
                                                  const T* __restrict__ v,  // [b, h, 1, e]
                                                  const float* __restrict__ past_kv,  // [b, h, d, e]
                                                  const float* __restrict__ slope,    // [h, 1, 1]
                                                  T* __restrict__ output,     // [b, h, 1, e]
                                                  float* __restrict__ new_kv,  // [b, h, d, e]
                                                  const int batch_size, const int num_heads, const int qk_dim,
                                                  const int v_dim) {
  extern __shared__ char smem[];
  T* q_shared = reinterpret_cast<T*>(smem);
  T* k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
  T* v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
  float* new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
  T* output_shared =
      reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));

  const int32_t tid = threadIdx.x;
  const int32_t current_head = blockIdx.x;
  const int32_t b = current_head / num_heads;
  const int32_t h = current_head % num_heads;

  if (b >= batch_size) return;

  const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim;
  const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
  const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;

  for (int d = tid; d < qk_dim; d += blockDim.x) {
    q_shared[d] = q[qk_offset + d];
    k_shared[d] = k[qk_offset + d];
  }
  for (int e = tid; e < v_dim; e += blockDim.x) {
    v_shared[e] = v[v_offset + e];
  }

  __syncthreads();

  const float ratio = expf(-1.0f * slope[h]);

  for (int d = tid; d < qk_dim; d += blockDim.x) {
    T k_val = k_shared[d];
    for (int e = 0; e < v_dim; ++e) {
      int past_kv_idx = kv_offset + d * v_dim + e;
      T v_val = v_shared[e];
      float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
      int shared_idx = d * (v_dim + 1) + e;
      new_kv_shared[shared_idx] = new_val;
    }
  }

  __syncthreads();

  for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
    int d = idx / v_dim;
    int e = idx % v_dim;
    int shared_idx = d * (v_dim + 1) + e;
    int global_idx = kv_offset + idx;
    new_kv[global_idx] = new_kv_shared[shared_idx];
  }

  __syncthreads();

  for (int e = tid; e < v_dim; e += blockDim.x) {
    float sum = 0.0f;
    for (int d = 0; d < qk_dim; ++d) {
      int shared_idx = d * (v_dim + 1) + e;
      sum += q_shared[d] * new_kv_shared[shared_idx];
    }
    output_shared[e] = static_cast<T>(sum);
  }

  __syncthreads();

  if (tid == 0) {
    for (int e = 0; e < v_dim; ++e) {
      output[v_offset + e] = output_shared[e];
    }
  }
}

void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
                                const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
                                torch::Tensor new_kv) {
  TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
  TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
  TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
  TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous");

  auto batch_size = q.size(0);
  auto num_heads = q.size(1);
  auto qk_dim = q.size(3);
  auto v_dim = v.size(3);

  dim3 block(THREADS_PER_BLOCK);
  dim3 grid(batch_size * num_heads);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
        size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
        lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
            q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), v.data_ptr<scalar_t>(), past_kv.data_ptr<float>(),
            slope.data_ptr<float>(), output.data_ptr<scalar_t>(), new_kv.data_ptr<float>(), batch_size, num_heads,
            qk_dim, v_dim);
      }));
}
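For reference, the dynamic shared-memory layout the kernel carves out of smem, read off the pointer arithmetic above (the v_dim + 1 row stride of new_kv_shared pads each row, plausibly to reduce shared-memory bank conflicts):

    q_shared       qk_dim               elements of T
    k_shared       qk_dim               elements of T
    v_shared       v_dim                elements of T
    new_kv_shared  qk_dim * (v_dim + 1) floats
    output_shared  v_dim                elements of T

The total is (2 * qk_dim + 2 * v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float) bytes, which matches the smem_size the host computes inside AT_DISPATCH at launch.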
@@ -26,6 +26,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
                             const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                             const c10::optional<torch::Tensor>& bias);

+// lightning_attention_decode
+void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
+                                const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
+                                torch::Tensor new_kv);
+
// rotary embedding
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
                      torch::Tensor& cos_sin_cache, bool is_neox);
@@ -69,6 +74,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
  // int8_scaled_mm
  m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
+  // lightning_attention_decode
+  m.def("lightning_attention_decode", &lightning_attention_decode, "Lightning Attention Decode (CUDA)");
  // rotary embedding
  m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
  // rms norm

sgl-kernel/python/sgl_kernel/ops/__init__.py
@@ -14,6 +14,9 @@ from sgl_kernel.ops._kernels import (
)
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
+from sgl_kernel.ops._kernels import (
+    lightning_attention_decode as _lightning_attention_decode,
+)
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm
@@ -86,6 +89,10 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    )


+def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
+    _lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
+
+
def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
    return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)

sgl-kernel/tests/test_lightning_attention_decode.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import pytest
import torch
from sgl_kernel import lightning_attention_decode


def naive_lightning_attention_decode(q, k, v, past_kv, slope):
    """Naive implementation of lightning attention decode"""
    original_dtype = q.dtype
    ratio = torch.exp(-slope)  # [h, 1, 1]

    kv = past_kv
    b, h, n, d = q.shape

    output = []
    for i in range(n):
        kv = ratio * kv.to(torch.float32) + torch.einsum(
            "... n d, ... n e -> ... d e",
            k[:, :, i : i + 1],
            v[:, :, i : i + 1],
        )
        qkv = torch.einsum(
            "... n e, ... e d -> ... n d",
            q[:, :, i : i + 1].to(torch.float32),
            kv.to(torch.float32),
        )
        output.append(qkv)
    output = torch.concat(output, dim=-2)

    return output.to(original_dtype), kv


configs = [
    # (batch_size, num_heads, dim, embed_dim)
    (1, 8, 64, 64),
    (2, 8, 64, 64),
    (1, 32, 32, 64),
    (2, 32, 32, 64),
    (4, 32, 64, 64),
    (4, 32, 64, 64),
    (16, 64, 96, 96),
    (64, 64, 96, 96),
]

dtypes = [torch.float32, torch.float16, torch.bfloat16]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("batch_size,num_heads,dim,embed_dim", configs)
def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim):
    device = torch.device("cuda")

    q = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
    k = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
    v = torch.randn(batch_size, num_heads, 1, embed_dim, device=device, dtype=dtype)
    past_kv = torch.randn(batch_size, num_heads, dim, embed_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    ref_output, ref_new_kv = naive_lightning_attention_decode(q, k, v, past_kv, slope)

    output = torch.empty_like(ref_output)
    new_kv = torch.empty_like(ref_new_kv)
    lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)

    rtol = 1e-2
    atol = 1e-2

    torch.testing.assert_close(
        output,
        ref_output,
        rtol=rtol,
        atol=atol,
        msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, "
        f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
    )

    torch.testing.assert_close(
        new_kv,
        ref_new_kv,
        rtol=rtol,
        atol=atol,
        msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, "
        f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
    )
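The new test runs under a plain pytest invocation on a CUDA machine, e.g.:

    pytest sgl-kernel/tests/test_lightning_attention_decode.py -v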