Fix Triton decode kernel & ut (#1819)
This commit is contained in:
@@ -296,12 +296,18 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
Lk: tl.constexpr,
|
Lk: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
cur_kv_head = tl.program_id(1)
|
cur_head_id = tl.program_id(1)
|
||||||
|
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
|
||||||
start_n = tl.program_id(2)
|
start_n = tl.program_id(2)
|
||||||
|
|
||||||
reduce_dtype = Att_Out.dtype.element_ty
|
reduce_dtype = Att_Out.dtype.element_ty
|
||||||
cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
|
|
||||||
mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
|
if BLOCK_H < kv_group_num:
|
||||||
|
VALID_BLOCK_H: tl.constexpr = BLOCK_H
|
||||||
|
else:
|
||||||
|
VALID_BLOCK_H: tl.constexpr = kv_group_num
|
||||||
|
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
|
||||||
|
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
|
||||||
mask_h = mask_h & (cur_head < q_head_num)
|
mask_h = mask_h & (cur_head < q_head_num)
|
||||||
|
|
||||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||||
@@ -400,10 +406,15 @@ def _fwd_grouped_kernel_stage2(
|
|||||||
Lv: tl.constexpr,
|
Lv: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
cur_kv_head = tl.program_id(1)
|
cur_head_id = tl.program_id(1)
|
||||||
|
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
|
||||||
|
|
||||||
cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
|
if BLOCK_H < kv_group_num:
|
||||||
mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
|
VALID_BLOCK_H: tl.constexpr = BLOCK_H
|
||||||
|
else:
|
||||||
|
VALID_BLOCK_H: tl.constexpr = kv_group_num
|
||||||
|
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
|
||||||
|
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
|
||||||
mask_h = mask_h & (cur_head < q_head_num)
|
mask_h = mask_h & (cur_head < q_head_num)
|
||||||
|
|
||||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||||
@@ -485,7 +496,7 @@ def _decode_grouped_att_m_fwd(
|
|||||||
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
||||||
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
||||||
|
|
||||||
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
|
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
|
||||||
grid = (
|
grid = (
|
||||||
batch,
|
batch,
|
||||||
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
|
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
|
||||||
@@ -534,7 +545,7 @@ def _decode_grouped_softmax_reducev_fwd(
|
|||||||
BLOCK = 128
|
BLOCK = 128
|
||||||
batch, head_num = b_seq_len.shape[0], logits.shape[0]
|
batch, head_num = b_seq_len.shape[0], logits.shape[0]
|
||||||
kv_group_num = logits.shape[0] // v_buffer.shape[1]
|
kv_group_num = logits.shape[0] // v_buffer.shape[1]
|
||||||
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
|
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
|
||||||
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
|
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
|
||||||
|
|
||||||
num_warps = 8
|
num_warps = 8
|
||||||
@@ -567,6 +578,80 @@ def _decode_grouped_softmax_reducev_fwd(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_attention_fwd_normal(
|
||||||
|
q,
|
||||||
|
k_buffer,
|
||||||
|
v_buffer,
|
||||||
|
o,
|
||||||
|
req_to_token,
|
||||||
|
b_req_idx,
|
||||||
|
b_start_loc,
|
||||||
|
b_seq_len,
|
||||||
|
attn_logits,
|
||||||
|
max_len_in_batch,
|
||||||
|
sm_scale,
|
||||||
|
logit_cap=0.0,
|
||||||
|
):
|
||||||
|
_decode_att_m_fwd(
|
||||||
|
q,
|
||||||
|
k_buffer,
|
||||||
|
attn_logits,
|
||||||
|
req_to_token,
|
||||||
|
b_req_idx,
|
||||||
|
b_start_loc,
|
||||||
|
b_seq_len,
|
||||||
|
max_len_in_batch,
|
||||||
|
sm_scale,
|
||||||
|
logit_cap,
|
||||||
|
)
|
||||||
|
_decode_softmax_reducev_fwd(
|
||||||
|
attn_logits,
|
||||||
|
v_buffer,
|
||||||
|
o,
|
||||||
|
req_to_token,
|
||||||
|
b_req_idx,
|
||||||
|
b_start_loc,
|
||||||
|
b_seq_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_attention_fwd_grouped(
|
||||||
|
q,
|
||||||
|
k_buffer,
|
||||||
|
v_buffer,
|
||||||
|
o,
|
||||||
|
req_to_token,
|
||||||
|
b_req_idx,
|
||||||
|
b_start_loc,
|
||||||
|
b_seq_len,
|
||||||
|
attn_logits,
|
||||||
|
max_len_in_batch,
|
||||||
|
sm_scale,
|
||||||
|
logit_cap=0.0,
|
||||||
|
):
|
||||||
|
_decode_grouped_att_m_fwd(
|
||||||
|
q,
|
||||||
|
k_buffer,
|
||||||
|
attn_logits,
|
||||||
|
req_to_token,
|
||||||
|
b_req_idx,
|
||||||
|
b_start_loc,
|
||||||
|
b_seq_len,
|
||||||
|
max_len_in_batch,
|
||||||
|
sm_scale,
|
||||||
|
logit_cap,
|
||||||
|
)
|
||||||
|
_decode_grouped_softmax_reducev_fwd(
|
||||||
|
attn_logits,
|
||||||
|
v_buffer,
|
||||||
|
o,
|
||||||
|
req_to_token,
|
||||||
|
b_req_idx,
|
||||||
|
b_start_loc,
|
||||||
|
b_seq_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def decode_attention_fwd(
|
def decode_attention_fwd(
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
@@ -585,47 +670,33 @@ def decode_attention_fwd(
|
|||||||
|
|
||||||
if kv_group_num == 1:
|
if kv_group_num == 1:
|
||||||
# MHA
|
# MHA
|
||||||
_decode_att_m_fwd(
|
decode_attention_fwd_normal(
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
attn_logits,
|
|
||||||
req_to_token,
|
|
||||||
b_req_idx,
|
|
||||||
b_start_loc,
|
|
||||||
b_seq_len,
|
|
||||||
max_len_in_batch,
|
|
||||||
sm_scale,
|
|
||||||
logit_cap,
|
|
||||||
)
|
|
||||||
_decode_softmax_reducev_fwd(
|
|
||||||
attn_logits,
|
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
b_req_idx,
|
b_req_idx,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
|
attn_logits,
|
||||||
|
max_len_in_batch,
|
||||||
|
sm_scale,
|
||||||
|
logit_cap,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# GQA/MQA/MLA
|
# GQA/MQA/MLA
|
||||||
_decode_grouped_att_m_fwd(
|
decode_attention_fwd_grouped(
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
attn_logits,
|
|
||||||
req_to_token,
|
|
||||||
b_req_idx,
|
|
||||||
b_start_loc,
|
|
||||||
b_seq_len,
|
|
||||||
max_len_in_batch,
|
|
||||||
sm_scale,
|
|
||||||
logit_cap,
|
|
||||||
)
|
|
||||||
_decode_grouped_softmax_reducev_fwd(
|
|
||||||
attn_logits,
|
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
b_req_idx,
|
b_req_idx,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
|
attn_logits,
|
||||||
|
max_len_in_batch,
|
||||||
|
sm_scale,
|
||||||
|
logit_cap,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ def _fwd_kernel(
|
|||||||
def context_attention_fwd(
|
def context_attention_fwd(
|
||||||
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
|
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
|
||||||
):
|
):
|
||||||
if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
if is_cuda_available and CUDA_CAPABILITY[0] > 8:
|
||||||
BLOCK = 128
|
BLOCK = 128
|
||||||
else:
|
else:
|
||||||
BLOCK = 64
|
BLOCK = 64
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ suites = {
|
|||||||
"test_srt_endpoint.py",
|
"test_srt_endpoint.py",
|
||||||
"test_torch_compile.py",
|
"test_torch_compile.py",
|
||||||
"test_torchao.py",
|
"test_torchao.py",
|
||||||
"test_triton_attn_backend.py",
|
"test_triton_attention_kernels.py",
|
||||||
|
"test_triton_attention_backend.py",
|
||||||
"test_update_weights.py",
|
"test_update_weights.py",
|
||||||
"test_vision_openai_server.py",
|
"test_vision_openai_server.py",
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -3,7 +3,11 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
|
from sglang.srt.layers.attention.triton_ops.decode_attention import (
|
||||||
|
decode_attention_fwd,
|
||||||
|
decode_attention_fwd_grouped,
|
||||||
|
decode_attention_fwd_normal,
|
||||||
|
)
|
||||||
from sglang.srt.layers.attention.triton_ops.extend_attention import (
|
from sglang.srt.layers.attention.triton_ops.extend_attention import (
|
||||||
extend_attention_fwd,
|
extend_attention_fwd,
|
||||||
redundant_attention,
|
redundant_attention,
|
||||||
@@ -13,7 +17,7 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestExtendAttention(unittest.TestCase):
|
class TestTritonAttention(unittest.TestCase):
|
||||||
|
|
||||||
def _set_all_seeds(self, seed):
|
def _set_all_seeds(self, seed):
|
||||||
"""Set all random seeds for reproducibility."""
|
"""Set all random seeds for reproducibility."""
|
||||||
@@ -127,7 +131,7 @@ class TestExtendAttention(unittest.TestCase):
|
|||||||
for value in attention_values:
|
for value in attention_values:
|
||||||
self._test_extend_attention_once(19, 12331, 12, 4, value)
|
self._test_extend_attention_once(19, 12331, 12, 4, value)
|
||||||
|
|
||||||
def _test_context_attention_once(self, head_dim):
|
def _test_context_attention_once(self, head_dim, is_causal):
|
||||||
# Set up a simple test case
|
# Set up a simple test case
|
||||||
num_heads = 4
|
num_heads = 4
|
||||||
seq_lens = [8, 12]
|
seq_lens = [8, 12]
|
||||||
@@ -143,15 +147,35 @@ class TestExtendAttention(unittest.TestCase):
|
|||||||
b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda")
|
b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda")
|
||||||
b_seq_len = torch.tensor(seq_lens, device="cuda")
|
b_seq_len = torch.tensor(seq_lens, device="cuda")
|
||||||
|
|
||||||
context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len)
|
context_attention_fwd(
|
||||||
|
q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
|
||||||
|
)
|
||||||
|
|
||||||
|
cu_seq_lens = [0] * (len(seq_lens) + 1)
|
||||||
|
for i, seq_len in enumerate(seq_lens):
|
||||||
|
cu_seq_lens[i + 1] = cu_seq_lens[i] + seq_len
|
||||||
|
|
||||||
|
for i in range(len(seq_lens)):
|
||||||
|
start, end = cu_seq_lens[i], cu_seq_lens[i + 1]
|
||||||
|
o_torch = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
q[start:end].permute(1, 0, 2),
|
||||||
|
k[start:end].permute(1, 0, 2),
|
||||||
|
v[start:end].permute(1, 0, 2),
|
||||||
|
is_causal=is_causal,
|
||||||
|
).permute(1, 0, 2)
|
||||||
|
|
||||||
|
cos_sim = torch.nn.functional.cosine_similarity(
|
||||||
|
o[start:end].flatten(), o_torch.flatten(), dim=0
|
||||||
|
)
|
||||||
|
self.assertTrue(cos_sim.item() > 1 - (1e-5))
|
||||||
|
self.assertTrue(torch.allclose(o[start:end], o_torch, atol=1e-2))
|
||||||
|
|
||||||
def test_context_attention(self):
|
def test_context_attention(self):
|
||||||
# Here we just to ensure there is no error
|
|
||||||
# TODO: correctnesss test
|
|
||||||
head_dim = [128, 96, 80, 13]
|
head_dim = [128, 96, 80, 13]
|
||||||
|
|
||||||
for dim in head_dim:
|
for dim in head_dim:
|
||||||
self._test_context_attention_once(dim)
|
for is_causal in [True, False]:
|
||||||
|
self._test_context_attention_once(dim, is_causal)
|
||||||
|
|
||||||
def _test_decode_attention_once(self, B, H_Q, H_KV, D):
|
def _test_decode_attention_once(self, B, H_Q, H_KV, D):
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
@@ -174,6 +198,12 @@ class TestExtendAttention(unittest.TestCase):
|
|||||||
b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
|
b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
|
||||||
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
||||||
|
|
||||||
|
attn_logits = torch.empty(
|
||||||
|
(H_Q, total_tokens),
|
||||||
|
dtype=dtype,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
decode_attention_fwd(
|
decode_attention_fwd(
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
@@ -183,8 +213,8 @@ class TestExtendAttention(unittest.TestCase):
|
|||||||
b_req_idx,
|
b_req_idx,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
|
attn_logits,
|
||||||
seq_len,
|
seq_len,
|
||||||
total_tokens,
|
|
||||||
sm_scale,
|
sm_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -203,6 +233,80 @@ class TestExtendAttention(unittest.TestCase):
|
|||||||
for B, H_Q, H_KV, D in configs:
|
for B, H_Q, H_KV, D in configs:
|
||||||
self._test_decode_attention_once(B, H_Q, H_KV, D)
|
self._test_decode_attention_once(B, H_Q, H_KV, D)
|
||||||
|
|
||||||
|
def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
seq_len = 10 # This represents the number of tokens already in the sequence
|
||||||
|
total_tokens = B * seq_len
|
||||||
|
sm_scale = 1.0 / (D**0.5)
|
||||||
|
|
||||||
|
# q represents the new token being generated, one per batch
|
||||||
|
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
# k_buffer and v_buffer represent all previous tokens
|
||||||
|
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
|
||||||
|
v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
# o will have the same shape as q
|
||||||
|
o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
|
||||||
|
o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
|
||||||
|
b_req_idx = torch.arange(B, device="cuda")
|
||||||
|
b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
|
||||||
|
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
||||||
|
|
||||||
|
attn_logits = torch.empty(
|
||||||
|
(H_Q, total_tokens),
|
||||||
|
dtype=dtype,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
decode_attention_fwd_normal(
|
||||||
|
q,
|
||||||
|
k_buffer,
|
||||||
|
v_buffer,
|
||||||
|
o,
|
||||||
|
req_to_token,
|
||||||
|
b_req_idx,
|
||||||
|
b_start_loc,
|
||||||
|
b_seq_len,
|
||||||
|
attn_logits,
|
||||||
|
seq_len,
|
||||||
|
sm_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
decode_attention_fwd_grouped(
|
||||||
|
q,
|
||||||
|
k_buffer,
|
||||||
|
v_buffer,
|
||||||
|
o_grouped,
|
||||||
|
req_to_token,
|
||||||
|
b_req_idx,
|
||||||
|
b_start_loc,
|
||||||
|
b_seq_len,
|
||||||
|
attn_logits,
|
||||||
|
seq_len,
|
||||||
|
sm_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
cos_sim = torch.nn.functional.cosine_similarity(
|
||||||
|
o.flatten(), o_grouped.flatten(), dim=0
|
||||||
|
)
|
||||||
|
self.assertTrue(cos_sim.item() > 0.99)
|
||||||
|
self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))
|
||||||
|
|
||||||
|
def test_grouped_decode_attention(self):
|
||||||
|
configs = [
|
||||||
|
(2, 16, 1, 64, 64),
|
||||||
|
(2, 64, 1, 13, 13),
|
||||||
|
(2, 128, 1, 80, 80),
|
||||||
|
(2, 128, 2, 512, 512),
|
||||||
|
(2, 128, 1, 576, 512),
|
||||||
|
]
|
||||||
|
|
||||||
|
for B, H_Q, H_KV, D, D_V in configs:
|
||||||
|
self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user