Fix Triton decode kernel & ut (#1819)
This commit is contained in:
@@ -296,12 +296,18 @@ def _fwd_grouped_kernel_stage1(
|
||||
Lk: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_kv_head = tl.program_id(1)
|
||||
cur_head_id = tl.program_id(1)
|
||||
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
|
||||
start_n = tl.program_id(2)
|
||||
|
||||
reduce_dtype = Att_Out.dtype.element_ty
|
||||
cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
|
||||
mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
|
||||
|
||||
if BLOCK_H < kv_group_num:
|
||||
VALID_BLOCK_H: tl.constexpr = BLOCK_H
|
||||
else:
|
||||
VALID_BLOCK_H: tl.constexpr = kv_group_num
|
||||
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
|
||||
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
|
||||
mask_h = mask_h & (cur_head < q_head_num)
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
@@ -400,10 +406,15 @@ def _fwd_grouped_kernel_stage2(
|
||||
Lv: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_kv_head = tl.program_id(1)
|
||||
cur_head_id = tl.program_id(1)
|
||||
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
|
||||
|
||||
cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
|
||||
mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
|
||||
if BLOCK_H < kv_group_num:
|
||||
VALID_BLOCK_H: tl.constexpr = BLOCK_H
|
||||
else:
|
||||
VALID_BLOCK_H: tl.constexpr = kv_group_num
|
||||
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
|
||||
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
|
||||
mask_h = mask_h & (cur_head < q_head_num)
|
||||
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
@@ -485,7 +496,7 @@ def _decode_grouped_att_m_fwd(
|
||||
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
||||
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
||||
|
||||
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
|
||||
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
|
||||
grid = (
|
||||
batch,
|
||||
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
|
||||
@@ -534,7 +545,7 @@ def _decode_grouped_softmax_reducev_fwd(
|
||||
BLOCK = 128
|
||||
batch, head_num = b_seq_len.shape[0], logits.shape[0]
|
||||
kv_group_num = logits.shape[0] // v_buffer.shape[1]
|
||||
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
|
||||
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
|
||||
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
|
||||
|
||||
num_warps = 8
|
||||
@@ -567,6 +578,80 @@ def _decode_grouped_softmax_reducev_fwd(
|
||||
)
|
||||
|
||||
|
||||
def decode_attention_fwd_normal(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
max_len_in_batch,
|
||||
sm_scale,
|
||||
logit_cap=0.0,
|
||||
):
|
||||
_decode_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
attn_logits,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
max_len_in_batch,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
)
|
||||
_decode_softmax_reducev_fwd(
|
||||
attn_logits,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
)
|
||||
|
||||
|
||||
def decode_attention_fwd_grouped(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
max_len_in_batch,
|
||||
sm_scale,
|
||||
logit_cap=0.0,
|
||||
):
|
||||
_decode_grouped_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
attn_logits,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
max_len_in_batch,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
)
|
||||
_decode_grouped_softmax_reducev_fwd(
|
||||
attn_logits,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
)
|
||||
|
||||
|
||||
def decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
@@ -585,47 +670,33 @@ def decode_attention_fwd(
|
||||
|
||||
if kv_group_num == 1:
|
||||
# MHA
|
||||
_decode_att_m_fwd(
|
||||
decode_attention_fwd_normal(
|
||||
q,
|
||||
k_buffer,
|
||||
attn_logits,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
max_len_in_batch,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
)
|
||||
_decode_softmax_reducev_fwd(
|
||||
attn_logits,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
max_len_in_batch,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
)
|
||||
else:
|
||||
# GQA/MQA/MLA
|
||||
_decode_grouped_att_m_fwd(
|
||||
decode_attention_fwd_grouped(
|
||||
q,
|
||||
k_buffer,
|
||||
attn_logits,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
max_len_in_batch,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
)
|
||||
_decode_grouped_softmax_reducev_fwd(
|
||||
attn_logits,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
max_len_in_batch,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
)
|
||||
|
||||
@@ -168,7 +168,7 @@ def _fwd_kernel(
|
||||
def context_attention_fwd(
|
||||
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
|
||||
):
|
||||
if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||
if is_cuda_available and CUDA_CAPABILITY[0] > 8:
|
||||
BLOCK = 128
|
||||
else:
|
||||
BLOCK = 64
|
||||
|
||||
@@ -26,7 +26,8 @@ suites = {
|
||||
"test_srt_endpoint.py",
|
||||
"test_torch_compile.py",
|
||||
"test_torchao.py",
|
||||
"test_triton_attn_backend.py",
|
||||
"test_triton_attention_kernels.py",
|
||||
"test_triton_attention_backend.py",
|
||||
"test_update_weights.py",
|
||||
"test_vision_openai_server.py",
|
||||
],
|
||||
|
||||
@@ -3,7 +3,11 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
|
||||
from sglang.srt.layers.attention.triton_ops.decode_attention import (
|
||||
decode_attention_fwd,
|
||||
decode_attention_fwd_grouped,
|
||||
decode_attention_fwd_normal,
|
||||
)
|
||||
from sglang.srt.layers.attention.triton_ops.extend_attention import (
|
||||
extend_attention_fwd,
|
||||
redundant_attention,
|
||||
@@ -13,7 +17,7 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||
)
|
||||
|
||||
|
||||
class TestExtendAttention(unittest.TestCase):
|
||||
class TestTritonAttention(unittest.TestCase):
|
||||
|
||||
def _set_all_seeds(self, seed):
|
||||
"""Set all random seeds for reproducibility."""
|
||||
@@ -127,7 +131,7 @@ class TestExtendAttention(unittest.TestCase):
|
||||
for value in attention_values:
|
||||
self._test_extend_attention_once(19, 12331, 12, 4, value)
|
||||
|
||||
def _test_context_attention_once(self, head_dim):
|
||||
def _test_context_attention_once(self, head_dim, is_causal):
|
||||
# Set up a simple test case
|
||||
num_heads = 4
|
||||
seq_lens = [8, 12]
|
||||
@@ -143,15 +147,35 @@ class TestExtendAttention(unittest.TestCase):
|
||||
b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda")
|
||||
b_seq_len = torch.tensor(seq_lens, device="cuda")
|
||||
|
||||
context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len)
|
||||
context_attention_fwd(
|
||||
q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
|
||||
)
|
||||
|
||||
cu_seq_lens = [0] * (len(seq_lens) + 1)
|
||||
for i, seq_len in enumerate(seq_lens):
|
||||
cu_seq_lens[i + 1] = cu_seq_lens[i] + seq_len
|
||||
|
||||
for i in range(len(seq_lens)):
|
||||
start, end = cu_seq_lens[i], cu_seq_lens[i + 1]
|
||||
o_torch = torch.nn.functional.scaled_dot_product_attention(
|
||||
q[start:end].permute(1, 0, 2),
|
||||
k[start:end].permute(1, 0, 2),
|
||||
v[start:end].permute(1, 0, 2),
|
||||
is_causal=is_causal,
|
||||
).permute(1, 0, 2)
|
||||
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
o[start:end].flatten(), o_torch.flatten(), dim=0
|
||||
)
|
||||
self.assertTrue(cos_sim.item() > 1 - (1e-5))
|
||||
self.assertTrue(torch.allclose(o[start:end], o_torch, atol=1e-2))
|
||||
|
||||
def test_context_attention(self):
|
||||
# Here we just to ensure there is no error
|
||||
# TODO: correctnesss test
|
||||
head_dim = [128, 96, 80, 13]
|
||||
|
||||
for dim in head_dim:
|
||||
self._test_context_attention_once(dim)
|
||||
for is_causal in [True, False]:
|
||||
self._test_context_attention_once(dim, is_causal)
|
||||
|
||||
def _test_decode_attention_once(self, B, H_Q, H_KV, D):
|
||||
dtype = torch.bfloat16
|
||||
@@ -174,6 +198,12 @@ class TestExtendAttention(unittest.TestCase):
|
||||
b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
|
||||
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(H_Q, total_tokens),
|
||||
dtype=dtype,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
@@ -183,8 +213,8 @@ class TestExtendAttention(unittest.TestCase):
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
seq_len,
|
||||
total_tokens,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
@@ -203,6 +233,80 @@ class TestExtendAttention(unittest.TestCase):
|
||||
for B, H_Q, H_KV, D in configs:
|
||||
self._test_decode_attention_once(B, H_Q, H_KV, D)
|
||||
|
||||
def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
|
||||
dtype = torch.bfloat16
|
||||
seq_len = 10 # This represents the number of tokens already in the sequence
|
||||
total_tokens = B * seq_len
|
||||
sm_scale = 1.0 / (D**0.5)
|
||||
|
||||
# q represents the new token being generated, one per batch
|
||||
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
|
||||
|
||||
# k_buffer and v_buffer represent all previous tokens
|
||||
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
|
||||
v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")
|
||||
|
||||
# o will have the same shape as q
|
||||
o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
|
||||
o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
|
||||
|
||||
req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
|
||||
b_req_idx = torch.arange(B, device="cuda")
|
||||
b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
|
||||
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(H_Q, total_tokens),
|
||||
dtype=dtype,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
decode_attention_fwd_normal(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
seq_len,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
decode_attention_fwd_grouped(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o_grouped,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
seq_len,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
o.flatten(), o_grouped.flatten(), dim=0
|
||||
)
|
||||
self.assertTrue(cos_sim.item() > 0.99)
|
||||
self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))
|
||||
|
||||
def test_grouped_decode_attention(self):
|
||||
configs = [
|
||||
(2, 16, 1, 64, 64),
|
||||
(2, 64, 1, 13, 13),
|
||||
(2, 128, 1, 80, 80),
|
||||
(2, 128, 2, 512, 512),
|
||||
(2, 128, 1, 576, 512),
|
||||
]
|
||||
|
||||
for B, H_Q, H_KV, D, D_V in configs:
|
||||
self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user