support lightning_attention_decode in sgl-kernel for MiniMax-Text-01 (#3030)
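The diff below wires the new sgl-kernel lightning_attention_decode op into the MiniMax-Text-01 decode benchmark/test, adding an "SGL" provider alongside the existing PyTorch and Triton paths. For orientation, here is a minimal usage sketch: only the function name and argument order come from the test changes in the diff; all sizes, dtypes, and the recurrence noted in the comments are illustrative assumptions.

```python
# Minimal sketch, assuming MiniMax-Text-01-like sizes; shapes and dtypes are
# illustrative, only the argument order mirrors the test below.
import torch
from sgl_kernel import lightning_attention_decode

batch, num_heads, head_dim = 2, 64, 96  # assumed sizes
device, dtype = "cuda", torch.bfloat16  # assumed dtype

# Decode step: q/k/v hold a single token, laid out (batch, heads, seq_len=1, head_dim)
# and contiguous, matching the transposes done in the test.
q = torch.randn(batch, num_heads, 1, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, num_heads, 1, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, num_heads, 1, head_dim, device=device, dtype=dtype)

# Running per-head KV state (batch, heads, head_dim, head_dim) plus per-head decay
# slopes; the slope shape and float32 dtype are assumptions based on how the
# _build_slope_tensor output is used in the benchmark.
past_kv = torch.randn(batch, num_heads, head_dim, head_dim, device=device, dtype=dtype)
slope_rate = torch.rand(num_heads, 1, 1, device=device, dtype=torch.float32)

# The kernel writes into pre-allocated buffers instead of returning new tensors.
# Conceptually (linear-attention decode recurrence, stated as an assumption):
#   new_kv = exp(-slope_rate) * past_kv + k^T @ v,  output = q @ new_kv
output = torch.empty_like(v)
new_kv = torch.empty_like(past_kv)
lightning_attention_decode(q, k, v, past_kv, slope_rate, output, new_kv)
```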
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 import triton
 import triton.language as tl
 from einops import rearrange
+from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode


 @triton.jit
@@ -332,7 +333,6 @@ def test_lightning_attention_implementations(model_params):
         model_params["num_attention_heads"],
         d,
         d,
         dtype=dtype,
         device=device,
     )
     with torch.no_grad():
@@ -350,7 +350,13 @@ def test_lightning_attention_implementations(model_params):
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
+        q = q.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
+        past_kv = past_kv.contiguous()
+        slope_rate = slope_rate.contiguous()

+        # Test Triton implementation
         triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
         triton_output = triton_output.transpose(1, 2).contiguous()
         triton_output = triton_output.view(batch_size, seq_len, -1)
@@ -358,22 +364,50 @@ def test_lightning_attention_implementations(model_params):
         triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
         triton_output = model_attn.out_proj(triton_output)

+        # Test SGL implementation
+        sgl_output = torch.empty_like(v)
+        sgl_new_kv = torch.empty_like(past_kv)
+        sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)
+
+        sgl_output = sgl_output.transpose(1, 2).contiguous()
+        sgl_output = sgl_output.view(batch_size, seq_len, -1)
+        sgl_output = model_attn.norm(sgl_output)
+        sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
+        sgl_output = model_attn.out_proj(sgl_output)
+
+        # Verify Triton implementation results
         torch.testing.assert_close(
             model_output,
             triton_output,
             rtol=1e-3,
             atol=1e-2,
-            msg="Lightning attention implementations produce different output results",
+            msg="Triton lightning attention implementation produces different output results",
         )
         torch.testing.assert_close(
             new_kv,
             triton_new_kv,
             rtol=1e-3,
             atol=1e-2,
-            msg="Lightning attention implementations produce different kv results",
+            msg="Triton lightning attention implementation produces different kv results",
         )

-        print("✅ Two implementations match")
+        # Verify SGL implementation results
+        torch.testing.assert_close(
+            model_output,
+            sgl_output,
+            rtol=1e-3,
+            atol=1e-2,
+            msg="SGL lightning attention implementation produces different output results",
+        )
+        torch.testing.assert_close(
+            new_kv,
+            sgl_new_kv,
+            rtol=1e-3,
+            atol=1e-2,
+            msg="SGL lightning attention implementation produces different kv results",
+        )
+
+        print("✅ All implementations match")


 def _build_slope_tensor(n_attention_heads: int):
@@ -408,12 +442,13 @@ def get_benchmark():
             x_names=["batch_size", "seq_len"],
             x_vals=[list(_) for _ in configs],
             line_arg="provider",
-            line_vals=["Original", "Triton"],
+            line_vals=["Original", "Triton", "SGL"],
             line_names=[
                 "Original PyTorch Implementation",
                 "Triton Implementation",
+                "SGL Implementation",
             ],
-            styles=[("blue", "-"), ("green", "-")],
+            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
             ylabel="us",
             plot_name="lightning-attention-decode-performance",
             args={},
@@ -446,7 +481,6 @@ def get_benchmark():
             params["num_attention_heads"],
             d,
             d,
             dtype=dtype,
             device=device,
         )

@@ -461,7 +495,7 @@ def get_benchmark():
                 ),
                 quantiles=quantiles,
             )
-        else:
+        elif provider == "Triton":

             def run_triton():
                 qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
@@ -483,6 +517,33 @@ def get_benchmark():
                 run_triton,
                 quantiles=quantiles,
             )
+        else:  # SGL
+
+            def run_sgl():
+                qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
+                new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
+                qkv = qkv.view(*new_shape)
+                q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
+                q = q.transpose(1, 2).contiguous()
+                k = k.transpose(1, 2).contiguous()
+                v = v.transpose(1, 2).contiguous()
+
+                output = torch.empty_like(v)
+                new_kv = torch.empty_like(past_kv)
+                sgl_lightning_attention_decode(
+                    q, k, v, past_kv, slope_rate, output, new_kv
+                )
+
+                output = output.transpose(1, 2).contiguous()
+                output = output.view(batch_size, seq_len, -1)
+                output = model_attn.norm(output)
+                output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
+                return model_attn.out_proj(output)
+
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                run_sgl,
+                quantiles=quantiles,
+            )

         return 1000 * ms, 1000 * max_ms, 1000 * min_ms
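The decorated benchmark returned by get_benchmark() can then be run in the usual Triton way. A hedged driver sketch, assuming get_benchmark() returns the triton.testing.perf_report-wrapped function shown above:

```python
# Hypothetical driver: run the decode benchmark over all three providers
# (Original, Triton, SGL) and print the measured latencies (ylabel "us").
if __name__ == "__main__":
    benchmark = get_benchmark()
    benchmark.run(print_data=True)
```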