Fix correctness issue for triton decoding kernel (#2479)

Author: Ke Bao
Date: 2024-12-14 16:50:54 +08:00
Committed by: GitHub
Parent: 5282a4735f
Commit: 2f9bd0fafd
2 changed files with 30 additions and 18 deletions


@@ -232,9 +232,9 @@ class TestTritonAttention(unittest.TestCase):
         for B, H_Q, H_KV, D in configs:
             self._test_decode_attention_once(B, H_Q, H_KV, D)
 
-    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
+    def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V):
         dtype = torch.bfloat16
-        seq_len = 128  # This represents the number of tokens already in the sequence
+        seq_len = S  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
         num_kv_splits = 8
@@ -300,6 +300,7 @@ class TestTritonAttention(unittest.TestCase):
         self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))
 
     def test_grouped_decode_attention(self):
+        seq_lens = [5, 100, 128, 500]
         configs = [
             (2, 16, 16, 64, 64),
             (2, 16, 1, 64, 64),
@@ -309,8 +310,9 @@ class TestTritonAttention(unittest.TestCase):
             (2, 128, 1, 576, 512),
         ]
 
-        for B, H_Q, H_KV, D, D_V in configs:
-            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V)
+        for S in seq_lens:
+            for B, H_Q, H_KV, D, D_V in configs:
+                self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)
 
 
 if __name__ == "__main__":
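
The hunks above only widen the test coverage; per the file counts, the kernel fix itself sits in the second changed file, which is not shown here. The new seq_lens = [5, 100, 128, 500] matter because the test pins num_kv_splits = 8. Below is a rough sketch of why, assuming the decode kernel ceil-divides each sequence's KV cache across a fixed number of splits; kv_split_ranges and its exact arithmetic are illustrative, not the kernel's actual code.

def kv_split_ranges(seq_len: int, num_kv_splits: int):
    # Ceil-divide so every token is covered even when the split count
    # does not divide seq_len evenly.
    kv_len_per_split = -(-seq_len // num_kv_splits)
    ranges = []
    for i in range(num_kv_splits):
        start = i * kv_len_per_split
        end = min(start + kv_len_per_split, seq_len)
        if start < end:  # splits past the end of the sequence are empty
            ranges.append((start, end))
    return ranges

print(kv_split_ranges(5, 8))    # [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
print(kv_split_ranges(500, 8))  # seven ranges of 63 tokens, then (441, 500)

With the old fixed seq_len = 128, every split covered exactly 16 tokens. The new lengths also exercise empty splits (seq_len = 5 leaves three of the eight splits with no tokens) and ragged final splits (100 and 500 are not divisible by 8), which is where split-boundary bugs in a split-KV decode kernel tend to hide.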