support lightning_attention_decode in sgl-kernel for MiniMax-Text-01 (#3030)

Xiaoyu Zhang
2025-01-23 15:29:20 +08:00
committed by GitHub
parent 3e032c07cc
commit ac2dc35d0e
8 changed files with 588 additions and 8 deletions
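
The new sgl-kernel op is imported as `sgl_lightning_attention_decode` and called with pre-allocated output buffers, as the test diff below shows. A minimal usage sketch with shapes inferred from the test (the sizes, dtypes, and random slope values here are illustrative assumptions, not the test's exact setup):

    import torch
    from sgl_kernel import lightning_attention_decode

    batch, heads, d = 2, 8, 96  # illustrative sizes; the test derives them from model_params
    device = "cuda"

    # Decode step: seq_len == 1, layout (batch, heads, seq_len, head_dim), contiguous.
    q = torch.randn(batch, heads, 1, d, dtype=torch.bfloat16, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    past_kv = torch.randn(batch, heads, d, d, device=device)  # float32 recurrent state
    slope_rate = torch.rand(heads, 1, 1, device=device)  # stand-in per-head decay slopes

    # Outputs are passed in rather than returned.
    output = torch.empty_like(v)
    new_kv = torch.empty_like(past_kv)
    lightning_attention_decode(q, k, v, past_kv, slope_rate, output, new_kv)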


@@ -9,6 +9,7 @@ import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange
from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode
@triton.jit
@@ -332,7 +333,6 @@ def test_lightning_attention_implementations(model_params):
        model_params["num_attention_heads"],
        d,
        d,
        dtype=dtype,
        device=device,
    )
    with torch.no_grad():
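
Note that the diff drops `dtype=dtype` from this `torch.randn` call, so `past_kv` now defaults to float32 even when q/k/v run in bfloat16; presumably the kernel keeps the recurrent KV state in fp32 for accumulation accuracy. A quick check of the resulting dtype (shapes illustrative):

    # With no explicit dtype, torch.randn uses the default (float32),
    # decoupling the KV-state precision from the activation dtype.
    past_kv = torch.randn(2, 8, 96, 96, device="cuda")
    assert past_kv.dtype == torch.float32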
@@ -350,7 +350,13 @@ def test_lightning_attention_implementations(model_params):
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    past_kv = past_kv.contiguous()
    slope_rate = slope_rate.contiguous()
    # Test Triton implementation
    triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
    triton_output = triton_output.transpose(1, 2).contiguous()
    triton_output = triton_output.view(batch_size, seq_len, -1)
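
The added `.contiguous()` calls matter because `transpose(1, 2)` returns a strided view, while a CUDA kernel that walks raw data pointers expects dense row-major storage (a likely reason for the change, though the diff does not say so). A defensive helper in that spirit, with `as_contiguous` being a hypothetical name:

    def as_contiguous(*tensors):
        # Materialize any strided views so the kernel sees dense memory;
        # torch returns the tensor unchanged when it is already contiguous.
        return tuple(t if t.is_contiguous() else t.contiguous() for t in tensors)

    q, k, v, past_kv, slope_rate = as_contiguous(q, k, v, past_kv, slope_rate)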
@@ -358,22 +364,50 @@ def test_lightning_attention_implementations(model_params):
    triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
    triton_output = model_attn.out_proj(triton_output)
    # Test SGL implementation
    sgl_output = torch.empty_like(v)
    sgl_new_kv = torch.empty_like(past_kv)
    sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)
    sgl_output = sgl_output.transpose(1, 2).contiguous()
    sgl_output = sgl_output.view(batch_size, seq_len, -1)
    sgl_output = model_attn.norm(sgl_output)
    sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
    sgl_output = model_attn.out_proj(sgl_output)
    # Verify Triton implementation results
    torch.testing.assert_close(
        model_output,
        triton_output,
        rtol=1e-3,
        atol=1e-2,
        msg="Lightning attention implementations produce different output results",
        msg="Triton lightning attention implementation produces different output results",
    )
    torch.testing.assert_close(
        new_kv,
        triton_new_kv,
        rtol=1e-3,
        atol=1e-2,
        msg="Lightning attention implementations produce different kv results",
        msg="Triton lightning attention implementation produces different kv results",
    )
    print("✅ Two implementations match")
    # Verify SGL implementation results
    torch.testing.assert_close(
        model_output,
        sgl_output,
        rtol=1e-3,
        atol=1e-2,
        msg="SGL lightning attention implementation produces different output results",
    )
    torch.testing.assert_close(
        new_kv,
        sgl_new_kv,
        rtol=1e-3,
        atol=1e-2,
        msg="SGL lightning attention implementation produces different kv results",
    )
    print("✅ All implementations match")
def _build_slope_tensor(n_attention_heads: int):
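
The `_build_slope_tensor` helper above is unchanged context; its body is not shown in this diff. Lightning attention conventionally uses ALiBi-style geometric decay slopes, so a plausible sketch (not necessarily this file's exact code, especially for non-power-of-two head counts) is:

    import torch

    def build_slope_tensor_sketch(n_attention_heads: int) -> torch.Tensor:
        # ALiBi-style slopes: head i gets 1 / 2^(8 * (i + 1) / n).
        base = 2.0 ** (-8.0 / n_attention_heads)
        slopes = [base ** (i + 1) for i in range(n_attention_heads)]
        return torch.tensor(slopes, dtype=torch.float32).reshape(n_attention_heads, 1, 1)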
@@ -408,12 +442,13 @@ def get_benchmark():
            x_names=["batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["Original", "Triton"],
            line_vals=["Original", "Triton", "SGL"],
            line_names=[
                "Original PyTorch Implementation",
                "Triton Implementation",
                "SGL Implementation",
            ],
            styles=[("blue", "-"), ("green", "-")],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name="lightning-attention-decode-performance",
            args={},
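
For reference, `get_benchmark()` returns the `triton.testing.perf_report`-wrapped function configured above, and can be driven with Triton's standard `run` entry point (the `save_path` here is an illustrative choice):

    if __name__ == "__main__":
        benchmark = get_benchmark()
        # print_data tabulates the Original / Triton / SGL columns (in us, per ylabel).
        benchmark.run(print_data=True, save_path="./bench_results")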
@@ -446,7 +481,6 @@ def get_benchmark():
            params["num_attention_heads"],
            d,
            d,
            dtype=dtype,
            device=device,
        )
@@ -461,7 +495,7 @@ def get_benchmark():
                ),
                quantiles=quantiles,
            )
        else:
        elif provider == "Triton":
            def run_triton():
                qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
@@ -483,6 +517,33 @@ def get_benchmark():
                run_triton,
                quantiles=quantiles,
            )
        else:  # SGL
            def run_sgl():
                qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
                new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
                qkv = qkv.view(*new_shape)
                q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
                q = q.transpose(1, 2).contiguous()
                k = k.transpose(1, 2).contiguous()
                v = v.transpose(1, 2).contiguous()
                output = torch.empty_like(v)
                new_kv = torch.empty_like(past_kv)
                sgl_lightning_attention_decode(
                    q, k, v, past_kv, slope_rate, output, new_kv
                )
                output = output.transpose(1, 2).contiguous()
                output = output.view(batch_size, seq_len, -1)
                output = model_attn.norm(output)
                output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
                return model_attn.out_proj(output)
            ms, min_ms, max_ms = triton.testing.do_bench(
                run_sgl,
                quantiles=quantiles,
            )
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms
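
`do_bench` reports times in milliseconds, so the final `1000 *` scaling yields the microseconds promised by `ylabel="us"`. Since the sgl-kernel op writes into caller-allocated buffers, a small wrapper can also give it the same return-style signature as the Triton path; a minimal sketch under the shapes used above:

    def sgl_decode(q, k, v, past_kv, slope_rate):
        # Allocate destinations, let the kernel fill them in place, then return them.
        output = torch.empty_like(v)
        new_kv = torch.empty_like(past_kv)
        sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, output, new_kv)
        return output, new_kv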