Fix correctness issue for triton decoding kernel (#2479)

This commit is contained in:
Ke Bao
2024-12-14 16:50:54 +08:00
committed by GitHub
parent 5282a4735f
commit 2f9bd0fafd
2 changed files with 30 additions and 18 deletions

View File

@@ -32,7 +32,7 @@ is_hip_ = is_hip()
 logger = logging.getLogger(__name__)

 # TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
-logger.warn(
+logger.warning(
     "The following error message 'operation scheduled before its operands' can be ignored."
 )
@@ -474,6 +474,7 @@ def _decode_grouped_att_m_fwd(
 def _fwd_kernel_stage2(
     Mid_O,
     O,
+    B_Seqlen,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -486,6 +487,8 @@ def _fwd_kernel_stage2(
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
@@ -497,19 +500,24 @@ def _fwd_kernel_stage2(
     offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

     for split_kv_id in range(0, NUM_KV_SPLITS):
-        tv = tl.load(
-            Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
-        )
-        tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
-        n_e_max = tl.maximum(tlogic, e_max)
-
-        old_scale = tl.exp(e_max - n_e_max)
-        acc *= old_scale
-        exp_logic = tl.exp(tlogic - n_e_max)
-        acc += exp_logic * tv
-
-        e_sum = e_sum * old_scale + exp_logic
-        e_max = n_e_max
+        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+        split_kv_start = kv_len_per_split * split_kv_id
+        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
+
+        if split_kv_end > split_kv_start:
+            tv = tl.load(
+                Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
+            )
+            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
+            n_e_max = tl.maximum(tlogic, e_max)
+
+            old_scale = tl.exp(e_max - n_e_max)
+            acc *= old_scale
+            exp_logic = tl.exp(tlogic - n_e_max)
+            acc += exp_logic * tv
+
+            e_sum = e_sum * old_scale + exp_logic
+            e_max = n_e_max

     tl.store(
         O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
@@ -523,6 +531,7 @@ def _decode_softmax_reducev_fwd(
     q,
     o,
     v_buffer,
+    b_seq_len,
     num_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
@@ -541,6 +550,7 @@ def _decode_softmax_reducev_fwd(
     _fwd_kernel_stage2[grid](
         logits,
         o,
+        b_seq_len,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -580,7 +590,7 @@ def decode_attention_fwd_normal(
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)


 def decode_attention_fwd_grouped(
@@ -608,7 +618,7 @@ def decode_attention_fwd_grouped(
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)


 def decode_attention_fwd(

View File

@@ -232,9 +232,9 @@ class TestTritonAttention(unittest.TestCase):
         for B, H_Q, H_KV, D in configs:
             self._test_decode_attention_once(B, H_Q, H_KV, D)

-    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
+    def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V):
         dtype = torch.bfloat16
-        seq_len = 128  # This represents the number of tokens already in the sequence
+        seq_len = S  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
         num_kv_splits = 8
@@ -300,6 +300,7 @@ class TestTritonAttention(unittest.TestCase):
         self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))

     def test_grouped_decode_attention(self):
+        seq_lens = [5, 100, 128, 500]
         configs = [
             (2, 16, 16, 64, 64),
             (2, 16, 1, 64, 64),
@@ -309,8 +310,9 @@ class TestTritonAttention(unittest.TestCase):
             (2, 128, 1, 576, 512),
         ]

-        for B, H_Q, H_KV, D, D_V in configs:
-            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V)
+        for S in seq_lens:
+            for B, H_Q, H_KV, D, D_V in configs:
+                self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)


 if __name__ == "__main__":