support lightning_attention_decode in sgl-kernel for MiniMax-Text-01 (#3030)
This commit is contained in:
@@ -9,6 +9,7 @@ import torch.nn.functional as F
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -332,7 +333,6 @@ def test_lightning_attention_implementations(model_params):
|
|||||||
model_params["num_attention_heads"],
|
model_params["num_attention_heads"],
|
||||||
d,
|
d,
|
||||||
d,
|
d,
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -350,7 +350,13 @@ def test_lightning_attention_implementations(model_params):
|
|||||||
q = q.transpose(1, 2)
|
q = q.transpose(1, 2)
|
||||||
k = k.transpose(1, 2)
|
k = k.transpose(1, 2)
|
||||||
v = v.transpose(1, 2)
|
v = v.transpose(1, 2)
|
||||||
|
q = q.contiguous()
|
||||||
|
k = k.contiguous()
|
||||||
|
v = v.contiguous()
|
||||||
|
past_kv = past_kv.contiguous()
|
||||||
|
slope_rate = slope_rate.contiguous()
|
||||||
|
|
||||||
|
# Test Triton implementation
|
||||||
triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
|
triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
|
||||||
triton_output = triton_output.transpose(1, 2).contiguous()
|
triton_output = triton_output.transpose(1, 2).contiguous()
|
||||||
triton_output = triton_output.view(batch_size, seq_len, -1)
|
triton_output = triton_output.view(batch_size, seq_len, -1)
|
||||||
@@ -358,22 +364,50 @@ def test_lightning_attention_implementations(model_params):
|
|||||||
triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
|
triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
|
||||||
triton_output = model_attn.out_proj(triton_output)
|
triton_output = model_attn.out_proj(triton_output)
|
||||||
|
|
||||||
|
# Test SGL implementation
|
||||||
|
sgl_output = torch.empty_like(v)
|
||||||
|
sgl_new_kv = torch.empty_like(past_kv)
|
||||||
|
sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)
|
||||||
|
|
||||||
|
sgl_output = sgl_output.transpose(1, 2).contiguous()
|
||||||
|
sgl_output = sgl_output.view(batch_size, seq_len, -1)
|
||||||
|
sgl_output = model_attn.norm(sgl_output)
|
||||||
|
sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
|
||||||
|
sgl_output = model_attn.out_proj(sgl_output)
|
||||||
|
|
||||||
|
# Verify Triton implementation results
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
model_output,
|
model_output,
|
||||||
triton_output,
|
triton_output,
|
||||||
rtol=1e-3,
|
rtol=1e-3,
|
||||||
atol=1e-2,
|
atol=1e-2,
|
||||||
msg="Lightning attention implementations produce different output results",
|
msg="Triton lightning attention implementation produces different output results",
|
||||||
)
|
)
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
new_kv,
|
new_kv,
|
||||||
triton_new_kv,
|
triton_new_kv,
|
||||||
rtol=1e-3,
|
rtol=1e-3,
|
||||||
atol=1e-2,
|
atol=1e-2,
|
||||||
msg="Lightning attention implementations produce different kv results",
|
msg="Triton lightning attention implementation produces different kv results",
|
||||||
)
|
)
|
||||||
|
|
||||||
print("✅ Two implementations match")
|
# Verify SGL implementation results
|
||||||
|
torch.testing.assert_close(
|
||||||
|
model_output,
|
||||||
|
sgl_output,
|
||||||
|
rtol=1e-3,
|
||||||
|
atol=1e-2,
|
||||||
|
msg="SGL lightning attention implementation produces different output results",
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
new_kv,
|
||||||
|
sgl_new_kv,
|
||||||
|
rtol=1e-3,
|
||||||
|
atol=1e-2,
|
||||||
|
msg="SGL lightning attention implementation produces different kv results",
|
||||||
|
)
|
||||||
|
|
||||||
|
print("✅ All implementations match")
|
||||||
|
|
||||||
|
|
||||||
def _build_slope_tensor(n_attention_heads: int):
|
def _build_slope_tensor(n_attention_heads: int):
|
||||||
@@ -408,12 +442,13 @@ def get_benchmark():
|
|||||||
x_names=["batch_size", "seq_len"],
|
x_names=["batch_size", "seq_len"],
|
||||||
x_vals=[list(_) for _ in configs],
|
x_vals=[list(_) for _ in configs],
|
||||||
line_arg="provider",
|
line_arg="provider",
|
||||||
line_vals=["Original", "Triton"],
|
line_vals=["Original", "Triton", "SGL"],
|
||||||
line_names=[
|
line_names=[
|
||||||
"Original PyTorch Implementation",
|
"Original PyTorch Implementation",
|
||||||
"Triton Implementation",
|
"Triton Implementation",
|
||||||
|
"SGL Implementation",
|
||||||
],
|
],
|
||||||
styles=[("blue", "-"), ("green", "-")],
|
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
||||||
ylabel="us",
|
ylabel="us",
|
||||||
plot_name="lightning-attention-decode-performance",
|
plot_name="lightning-attention-decode-performance",
|
||||||
args={},
|
args={},
|
||||||
@@ -446,7 +481,6 @@ def get_benchmark():
|
|||||||
params["num_attention_heads"],
|
params["num_attention_heads"],
|
||||||
d,
|
d,
|
||||||
d,
|
d,
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -461,7 +495,7 @@ def get_benchmark():
|
|||||||
),
|
),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
else:
|
elif provider == "Triton":
|
||||||
|
|
||||||
def run_triton():
|
def run_triton():
|
||||||
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
|
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
|
||||||
@@ -483,6 +517,33 @@ def get_benchmark():
|
|||||||
run_triton,
|
run_triton,
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
|
else: # SGL
|
||||||
|
|
||||||
|
def run_sgl():
|
||||||
|
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
|
||||||
|
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
|
||||||
|
qkv = qkv.view(*new_shape)
|
||||||
|
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
|
||||||
|
q = q.transpose(1, 2).contiguous()
|
||||||
|
k = k.transpose(1, 2).contiguous()
|
||||||
|
v = v.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
output = torch.empty_like(v)
|
||||||
|
new_kv = torch.empty_like(past_kv)
|
||||||
|
sgl_lightning_attention_decode(
|
||||||
|
q, k, v, past_kv, slope_rate, output, new_kv
|
||||||
|
)
|
||||||
|
|
||||||
|
output = output.transpose(1, 2).contiguous()
|
||||||
|
output = output.view(batch_size, seq_len, -1)
|
||||||
|
output = model_attn.norm(output)
|
||||||
|
output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
|
||||||
|
return model_attn.out_proj(output)
|
||||||
|
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
run_sgl,
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
|
||||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
|||||||
299
sgl-kernel/benchmark/bench_lightning_attention_decode.py
Normal file
299
sgl-kernel/benchmark/bench_lightning_attention_decode.py
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from sgl_kernel import lightning_attention_decode
|
||||||
|
|
||||||
|
|
||||||
|
def next_power_of_2(n):
    """Return the smallest power of two that is greater than or equal to ``n``.

    The result is computed with exact integer arithmetic. The previous
    ``math.log(n, 2)`` formulation is subject to floating-point rounding and
    can overshoot for exact powers of two (for example ``2 ** 29``).

    Args:
        n: Positive size to round up. Non-integer values are rounded up to
            the next integer first.

    Returns:
        An ``int`` power of two ``p`` with ``p >= n``.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError("next_power_of_2 expects a positive value")
    # (m - 1).bit_length() is the exponent of the smallest power of two >= m.
    return 1 << (math.ceil(n) - 1).bit_length()
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _decode_kernel(
    Q,
    K,
    V,
    KV,
    Out,
    S,
    b: tl.constexpr,
    h: tl.constexpr,
    n: tl.constexpr,
    d: tl.constexpr,
    d_original: tl.constexpr,
    e: tl.constexpr,
    e_original: tl.constexpr,
):
    """Single decode step of lightning attention for one (batch, head) pair.

    Each program updates its KV state in place as
    ``KV = exp(-S[head]) * KV + k^T v`` and writes ``Out = q @ KV``.
    ``d`` / ``e`` are the padded tile extents used for addressing, while
    ``d_original`` / ``e_original`` bound the elements actually loaded and
    stored.
    """
    pid = tl.program_id(0)
    head = pid % h

    # Base element offsets of this program's slices in each buffer.
    qk_base = pid * n * d
    v_base = pid * n * e
    out_base = pid * n * e
    kv_base = pid * d * e

    # Per-head decay factor.
    slope = tl.load(S + head)
    decay = tl.exp(-slope)

    rows = tl.arange(0, d)
    cols = tl.arange(0, e)

    # Masks selecting the un-padded part of each axis / of the 2D tile.
    row_ok = rows < d_original
    col_ok = cols < e_original
    tile_ok = row_ok[:, None] & col_ok[None, :]
    kv_ptrs = KV + kv_base + rows[:, None] * e + cols[None, :]

    # Masked loads; padded lanes read as zero.
    q_vec = tl.load(Q + qk_base + rows, mask=row_ok, other=0.0)
    k_vec = tl.load(K + qk_base + rows, mask=row_ok, other=0.0)
    v_vec = tl.load(V + v_base + cols, mask=col_ok, other=0.0)
    state = tl.load(kv_ptrs, mask=tile_ok, other=0.0)

    # Decay the previous state and add the outer product k^T v.
    outer = k_vec[:, None] * v_vec[None, :]
    state = decay * state + outer

    # Write the updated state back in place.
    tl.store(kv_ptrs, state.to(KV.dtype.element_ty), mask=tile_ok)

    # Vector-matrix product q @ state as a broadcast multiply plus reduction.
    out_vec = tl.sum(q_vec[:, None] * state, axis=0)

    tl.store(Out + out_base + cols, out_vec.to(Out.dtype.element_ty), mask=col_ok)
|
||||||
|
|
||||||
|
|
||||||
|
def triton_lightning_attn_decode(q, k, v, kv, s):
    """Run the Triton lightning-attention decode kernel on un-padded inputs.

    The head dimensions are rounded up to powers of two, the inputs are
    copied into padded scratch buffers, the kernel is launched with one
    program per (batch, head) pair, and the un-padded views of the results
    are returned.

    Args:
        q, k: Tensors of shape ``(b, h, 1, d)``.
        v: Tensor of shape ``(b, h, 1, e)``.
        kv: Previous KV state, copied into a float32 ``(b, h, d, e)`` buffer.
        s: Per-head slope values, indexed by head inside the kernel.

    Returns:
        ``(o, kv_out)``: the attention output ``(b, h, 1, e)`` and the
        updated float32 KV state ``(b, h, d, e)``, both as views into the
        padded buffers.
    """
    b, h, n, d = q.shape
    e = v.shape[-1]
    assert n == 1, "Sequence length must be 1 in decode mode"

    # Triton tiles need power-of-two extents.
    d_padded = next_power_of_2(d)
    e_padded = next_power_of_2(e)

    # Scratch buffers; the padded tail stays uninitialised because the
    # kernel masks it out.
    q_padded = q.new_empty((b, h, n, d_padded))
    k_padded = k.new_empty((b, h, n, d_padded))
    v_padded = v.new_empty((b, h, n, e_padded))
    o_padded = v.new_empty((b, h, n, e_padded))
    kv_padded = kv.new_empty((b, h, d_padded, e_padded), dtype=torch.float32)

    # Fill the un-padded region of each buffer.
    q_padded[..., :d].copy_(q)
    k_padded[..., :d].copy_(k)
    v_padded[..., :e].copy_(v)
    kv_padded[..., :d, :e].copy_(kv)

    # One program per (batch, head).
    launch_grid = (b * h, 1)
    _decode_kernel[launch_grid](
        q_padded,
        k_padded,
        v_padded,
        kv_padded,
        o_padded,
        s,
        b=b,
        h=h,
        n=n,
        d=d_padded,
        d_original=d,
        e=e_padded,
        e_original=e,
    )

    # Strip the padding again.
    return o_padded[..., :e], kv_padded[..., :d, :e]
|
||||||
|
|
||||||
|
|
||||||
|
def lightning_attention_decode_naive(q, k, v, past_kv, slope):
    """Reference PyTorch implementation of lightning attention decode.

    For every position ``i`` along dim 2 the state is updated as
    ``kv = exp(-slope) * kv + k_i^T v_i`` (kept in float32) and the output
    for that position is ``q_i @ kv``.

    Returns:
        ``(output, kv)``: the outputs concatenated along dim -2 and cast
        back to ``q``'s dtype, and the final float32 KV state.
    """
    out_dtype = q.dtype
    decay = torch.exp(-slope)  # [h, 1, 1]

    _, _, steps, _ = q.shape
    state = past_kv
    per_step = []
    for pos in range(steps):
        k_i = k.narrow(2, pos, 1)
        v_i = v.narrow(2, pos, 1)
        q_i = q.narrow(2, pos, 1).to(torch.float32)

        # The outer product is taken in the input dtype and then accumulated
        # into the float32 state.
        kv_update = torch.einsum("... n d, ... n e -> ... d e", k_i, v_i)
        state = decay * state.to(torch.float32) + kv_update

        per_step.append(
            torch.einsum("... n e, ... e d -> ... n d", q_i, state.to(torch.float32))
        )

    stacked = torch.cat(per_step, dim=-2)
    return stacked.to(out_dtype), state
|
||||||
|
|
||||||
|
|
||||||
|
def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv):
    """Forward all arguments unchanged to ``sgl_kernel.lightning_attention_decode``.

    ``output`` and ``new_kv`` are caller-allocated buffers passed through to
    the kernel; whatever the kernel wrapper returns is returned as-is.
    """
    result = lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_diff(batch_size):
    """Compare the naive, SGL-kernel and Triton implementations on random data.

    Uses bfloat16 inputs with 64 heads, head dimension 96 and sequence
    length 1 on CUDA, and prints whether all outputs and updated KV states
    agree with the naive reference within ``atol=rtol=1e-2``.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_heads, head_dim, seq_len = 64, 96, 1

    token_shape = (batch_size, num_heads, seq_len, head_dim)
    q = torch.randn(*token_shape, device=device, dtype=dtype)
    k = torch.randn(*token_shape, device=device, dtype=dtype)
    v = torch.randn(*token_shape, device=device, dtype=dtype)
    past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    def fresh_inputs():
        # Every implementation gets its own copies of the inputs.
        return q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()

    output_naive, new_kv_naive = lightning_attention_decode_naive(*fresh_inputs())

    output_kernel = torch.empty_like(output_naive)
    new_kv_kernel = torch.empty_like(new_kv_naive)
    lightning_attention_decode_kernel(*fresh_inputs(), output_kernel, new_kv_kernel)

    output_triton, new_kv_triton = triton_lightning_attn_decode(*fresh_inputs())

    # (reference, candidate) pairs, checked in order with short-circuiting.
    comparisons = (
        (output_naive, output_kernel),
        (output_naive, output_triton),
        (new_kv_naive, new_kv_kernel),
        (new_kv_naive, new_kv_triton),
    )
    if all(
        torch.allclose(ref, got, atol=1e-2, rtol=1e-2) for ref, got in comparisons
    ):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
|
||||||
|
|
||||||
|
|
||||||
|
# Batch sizes swept by the benchmark: 1 to 64 inclusive.
batch_size_range = list(range(1, 65))
# One single-element tuple per benchmark configuration.
configs = [(bs,) for bs in batch_size_range]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["naive", "kernel", "triton"],
        line_names=["PyTorch Naive", "SGL Kernel", "Triton"],
        styles=[("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="lightning-attention-decode-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    """Time one lightning-attention decode step for the given provider.

    Inputs are bfloat16 with 64 heads, head dimension 96 and sequence
    length 1 on CUDA. The timed callable clones its inputs on every call,
    so the clone cost is included for all providers alike.

    Args:
        batch_size: Batch dimension of the generated inputs.
        provider: One of ``"naive"``, ``"kernel"`` or ``"triton"``.

    Returns:
        ``(median, low, high)`` latency in microseconds, where low/high are
        the 20th/80th percentiles, in the ``(mean, min, max)`` order that
        ``triton.testing.perf_report`` unpacks.

    Raises:
        ValueError: If ``provider`` is not a known implementation.
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    num_heads = 64
    head_dim = 96
    seq_len = 1

    q = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    k = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    v = torch.randn(
        batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
    )
    past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
    slope = torch.randn(num_heads, 1, 1, device=device)

    # do_bench returns one value per quantile, in this order.
    quantiles = [0.5, 0.2, 0.8]

    if provider == "naive":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: lightning_attention_decode_naive(
                q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
            ),
            quantiles=quantiles,
        )
    elif provider == "kernel":
        # Output buffers are allocated once, outside the timed region.
        output = torch.empty(
            batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
        )
        new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: lightning_attention_decode_kernel(
                q.clone(),
                k.clone(),
                v.clone(),
                past_kv.clone(),
                slope.clone(),
                output,
                new_kv,
            ),
            quantiles=quantiles,
        )
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_lightning_attn_decode(
                q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
            ),
            quantiles=quantiles,
        )
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    # Convert ms -> us. These are latencies, so the low quantile is the
    # minimum and the high quantile the maximum (not swapped, as it would be
    # for a throughput metric).
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/lightning_attention_decode_sgl/",
        help="Path to save lightning attention decode benchmark results",
    )
    args = parser.parse_args()

    # Run correctness test
    calculate_diff(batch_size=4)

    # Run performance benchmark. --save_path was previously parsed but never
    # used; pass it through so results are actually written there. The
    # directory is created up front in case the installed Triton version does
    # not create it itself.
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)
    benchmark.run(print_data=True, save_path=args.save_path)
|
||||||
@@ -100,6 +100,7 @@ ext_modules = [
|
|||||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
||||||
|
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
||||||
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
||||||
"3rdparty/flashinfer/csrc/activation.cu",
|
"3rdparty/flashinfer/csrc/activation.cu",
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from sgl_kernel.ops import (
|
|||||||
get_graph_buffer_ipc_meta,
|
get_graph_buffer_ipc_meta,
|
||||||
init_custom_reduce,
|
init_custom_reduce,
|
||||||
int8_scaled_mm,
|
int8_scaled_mm,
|
||||||
|
lightning_attention_decode,
|
||||||
moe_align_block_size,
|
moe_align_block_size,
|
||||||
register_graph_buffers,
|
register_graph_buffers,
|
||||||
rmsnorm,
|
rmsnorm,
|
||||||
@@ -35,5 +36,6 @@ __all__ = [
|
|||||||
"rmsnorm",
|
"rmsnorm",
|
||||||
"rotary_embedding",
|
"rotary_embedding",
|
||||||
"sampling_scaling_penalties",
|
"sampling_scaling_penalties",
|
||||||
|
"lightning_attention_decode",
|
||||||
"silu_and_mul",
|
"silu_and_mul",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,119 @@
|
|||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
|
#define THREADS_PER_BLOCK 128
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void lightning_attention_decode_kernel(const T* __restrict__ q, // [b, h, 1, d]
|
||||||
|
const T* __restrict__ k, // [b, h, 1, d]
|
||||||
|
const T* __restrict__ v, // [b, h, 1, e]
|
||||||
|
const float* __restrict__ past_kv, // [b, h, d, e]
|
||||||
|
const float* __restrict__ slope, // [h, 1, 1]
|
||||||
|
T* __restrict__ output, // [b, h, 1, e]
|
||||||
|
float* __restrict__ new_kv, // [b, h, d, e]
|
||||||
|
const int batch_size, const int num_heads, const int qk_dim,
|
||||||
|
const int v_dim) {
|
||||||
|
extern __shared__ char smem[];
|
||||||
|
T* q_shared = reinterpret_cast<T*>(smem);
|
||||||
|
T* k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
|
||||||
|
T* v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
|
||||||
|
float* new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
|
||||||
|
T* output_shared =
|
||||||
|
reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));
|
||||||
|
|
||||||
|
const int32_t tid = threadIdx.x;
|
||||||
|
const int32_t current_head = blockIdx.x;
|
||||||
|
const int32_t b = current_head / num_heads;
|
||||||
|
const int32_t h = current_head % num_heads;
|
||||||
|
|
||||||
|
if (b >= batch_size) return;
|
||||||
|
|
||||||
|
const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim;
|
||||||
|
const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
|
||||||
|
const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;
|
||||||
|
|
||||||
|
for (int d = tid; d < qk_dim; d += blockDim.x) {
|
||||||
|
q_shared[d] = q[qk_offset + d];
|
||||||
|
k_shared[d] = k[qk_offset + d];
|
||||||
|
}
|
||||||
|
for (int e = tid; e < v_dim; e += blockDim.x) {
|
||||||
|
v_shared[e] = v[v_offset + e];
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const float ratio = expf(-1.0f * slope[h]);
|
||||||
|
|
||||||
|
for (int d = tid; d < qk_dim; d += blockDim.x) {
|
||||||
|
T k_val = k_shared[d];
|
||||||
|
for (int e = 0; e < v_dim; ++e) {
|
||||||
|
int past_kv_idx = kv_offset + d * v_dim + e;
|
||||||
|
T v_val = v_shared[e];
|
||||||
|
float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
|
||||||
|
int shared_idx = d * (v_dim + 1) + e;
|
||||||
|
new_kv_shared[shared_idx] = new_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
|
||||||
|
int d = idx / v_dim;
|
||||||
|
int e = idx % v_dim;
|
||||||
|
int shared_idx = d * (v_dim + 1) + e;
|
||||||
|
int global_idx = kv_offset + idx;
|
||||||
|
new_kv[global_idx] = new_kv_shared[shared_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int e = tid; e < v_dim; e += blockDim.x) {
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (int d = 0; d < qk_dim; ++d) {
|
||||||
|
int shared_idx = d * (v_dim + 1) + e;
|
||||||
|
sum += q_shared[d] * new_kv_shared[shared_idx];
|
||||||
|
}
|
||||||
|
output_shared[e] = static_cast<T>(sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
for (int e = 0; e < v_dim; ++e) {
|
||||||
|
output[v_offset + e] = output_shared[e];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Host launcher for lightning_attention_decode_kernel.
//
// q, k:     [b, h, 1, d]  (floating dtype, same for q/k/v/output)
// v:        [b, h, 1, e]
// past_kv:  [b, h, d, e]  float32
// slope:    [h, 1, 1]     float32
// output:   [b, h, 1, e]  written by the kernel
// new_kv:   [b, h, d, e]  float32, written by the kernel
//
// All tensors must be contiguous because the kernel indexes them as flat
// arrays. Raises (via TORCH_CHECK) on invalid inputs or on a failed launch.
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
                                const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
                                torch::Tensor new_kv) {
  TORCH_CHECK(q.is_cuda(), "q must be a CUDA tensor");
  TORCH_CHECK(q.dim() == 4 && v.dim() == 4, "q and v must be 4-dimensional [b, h, 1, d]");
  TORCH_CHECK(q.size(2) == 1, "lightning_attention_decode only supports sequence length 1");
  TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
  TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
  TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
  TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous");
  TORCH_CHECK(slope.is_contiguous(), "slope must be contiguous");
  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
  TORCH_CHECK(new_kv.is_contiguous(), "new_kv must be contiguous");

  auto batch_size = q.size(0);
  auto num_heads = q.size(1);
  auto qk_dim = q.size(3);
  auto v_dim = v.size(3);

  // A zero-sized grid is an invalid launch configuration; nothing to compute.
  if (batch_size * num_heads == 0) {
    return;
  }

  // Launch on the device that owns the inputs, not whatever device is current.
  const c10::cuda::CUDAGuard device_guard(q.device());

  dim3 block(THREADS_PER_BLOCK);
  dim3 grid(batch_size * num_heads);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
        size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);

        // More than 48 KB of dynamic shared memory must be requested
        // explicitly, otherwise the launch fails.
        if (smem_size > 48 * 1024) {
          const cudaError_t attr_err =
              cudaFuncSetAttribute(lightning_attention_decode_kernel<scalar_t>,
                                   cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(smem_size));
          TORCH_CHECK(attr_err == cudaSuccess, "lightning_attention_decode: cannot reserve ", smem_size,
                      " bytes of shared memory: ", cudaGetErrorString(attr_err));
        }

        lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
            q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), v.data_ptr<scalar_t>(), past_kv.data_ptr<float>(),
            slope.data_ptr<float>(), output.data_ptr<scalar_t>(), new_kv.data_ptr<float>(), batch_size, num_heads,
            qk_dim, v_dim);

        // Surface launch failures instead of silently leaving the outputs
        // unwritten.
        const cudaError_t launch_err = cudaGetLastError();
        TORCH_CHECK(launch_err == cudaSuccess,
                    "lightning_attention_decode kernel launch failed: ", cudaGetErrorString(launch_err));
      }));
}
|
||||||
@@ -26,6 +26,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
|
|||||||
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
|
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
|
||||||
const c10::optional<torch::Tensor>& bias);
|
const c10::optional<torch::Tensor>& bias);
|
||||||
|
|
||||||
|
// lightning_attention_decode
|
||||||
|
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
|
||||||
|
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
|
||||||
|
torch::Tensor new_kv);
|
||||||
|
|
||||||
// rotary embedding
|
// rotary embedding
|
||||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
|
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
|
||||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||||
@@ -69,6 +74,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
|
m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
|
||||||
// int8_scaled_mm
|
// int8_scaled_mm
|
||||||
m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
|
m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
|
||||||
|
// lightning_attention_decode
|
||||||
|
m.def("lightning_attention_decode", &lightning_attention_decode, "Lightning Attention Decode (CUDA)");
|
||||||
// rotary embedding
|
// rotary embedding
|
||||||
m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
|
m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
|
||||||
// rms norm
|
// rms norm
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ from sgl_kernel.ops._kernels import (
|
|||||||
)
|
)
|
||||||
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
|
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
|
||||||
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
|
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
|
||||||
|
from sgl_kernel.ops._kernels import (
|
||||||
|
lightning_attention_decode as _lightning_attention_decode,
|
||||||
|
)
|
||||||
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
|
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
|
||||||
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
|
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
|
||||||
from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm
|
from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm
|
||||||
@@ -86,6 +89,10 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
||||||
|
_lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
|
||||||
|
|
||||||
|
|
||||||
def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
|
def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
|
||||||
return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
|
return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
|
||||||
|
|
||||||
|
|||||||
84
sgl-kernel/tests/test_lightning_attention_decode.py
Normal file
84
sgl-kernel/tests/test_lightning_attention_decode.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from sgl_kernel import lightning_attention_decode
|
||||||
|
|
||||||
|
|
||||||
|
def naive_lightning_attention_decode(q, k, v, past_kv, slope):
    """Reference implementation used to validate the CUDA kernel.

    Walks over dim 2 of ``q``; at each position the float32 state becomes
    ``exp(-slope) * kv + k_i^T v_i`` and the output row is ``q_i @ kv``.

    Returns:
        ``(output, kv)``: per-position outputs concatenated along dim -2 in
        ``q``'s dtype, and the final float32 KV state.
    """
    result_dtype = q.dtype
    decay = torch.exp(-slope)  # [h, 1, 1]

    _, _, num_positions, _ = q.shape
    running_kv = past_kv
    rows = []
    for idx in range(num_positions):
        window = slice(idx, idx + 1)
        # The outer product is taken in the input dtype, then accumulated into
        # the float32 state.
        outer = torch.einsum(
            "... n d, ... n e -> ... d e", k[:, :, window], v[:, :, window]
        )
        running_kv = decay * running_kv.to(torch.float32) + outer
        rows.append(
            torch.einsum(
                "... n e, ... e d -> ... n d",
                q[:, :, window].to(torch.float32),
                running_kv.to(torch.float32),
            )
        )

    merged = torch.cat(rows, dim=-2)
    return merged.to(result_dtype), running_kv
|
||||||
|
|
||||||
|
|
||||||
|
# Shapes exercised by the test. The previously duplicated (4, 32, 64, 64)
# entry is listed once so the same case is not run twice.
configs = [
    # (batch_size, num_heads, dim, embed_dim)
    (1, 8, 64, 64),
    (2, 8, 64, 64),
    (1, 32, 32, 64),
    (2, 32, 32, 64),
    (4, 32, 64, 64),
    (16, 64, 96, 96),
    (64, 64, 96, 96),
]

# Input dtypes for q/k/v; past_kv and slope use torch.randn's default dtype.
dtypes = [torch.float32, torch.float16, torch.bfloat16]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("batch_size,num_heads,dim,embed_dim", configs)
def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim):
    """Check the CUDA kernel against the naive reference on random inputs.

    Both the attention output and the updated KV state must match the
    reference within ``rtol=atol=1e-2``.
    """
    device = torch.device("cuda")

    qk_shape = (batch_size, num_heads, 1, dim)
    q = torch.randn(qk_shape, device=device, dtype=dtype)
    k = torch.randn(qk_shape, device=device, dtype=dtype)
    v = torch.randn((batch_size, num_heads, 1, embed_dim), device=device, dtype=dtype)
    past_kv = torch.randn((batch_size, num_heads, dim, embed_dim), device=device)
    slope = torch.randn((num_heads, 1, 1), device=device)

    ref_output, ref_new_kv = naive_lightning_attention_decode(q, k, v, past_kv, slope)

    # The kernel writes into caller-allocated buffers.
    output = torch.empty_like(ref_output)
    new_kv = torch.empty_like(ref_new_kv)
    lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)

    tolerances = {"rtol": 1e-2, "atol": 1e-2}
    case = (
        f"batch_size={batch_size}, num_heads={num_heads}, "
        f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}"
    )

    checks = (
        ("Output", output, ref_output),
        ("New KV", new_kv, ref_new_kv),
    )
    for label, actual, expected in checks:
        torch.testing.assert_close(
            actual,
            expected,
            msg=f"{label} mismatch for {case}",
            **tolerances,
        )
|
||||||
Reference in New Issue
Block a user