Fix correctness issue for triton decoding kernel (#2479)
This commit is contained in:
@@ -232,9 +232,9 @@ class TestTritonAttention(unittest.TestCase):
         for B, H_Q, H_KV, D in configs:
             self._test_decode_attention_once(B, H_Q, H_KV, D)

-    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
+    def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V):
         dtype = torch.bfloat16
-        seq_len = 128  # This represents the number of tokens already in the sequence
+        seq_len = S  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
         num_kv_splits = 8
@@ -300,6 +300,7 @@ class TestTritonAttention(unittest.TestCase):
         self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))

     def test_grouped_decode_attention(self):
+        seq_lens = [5, 100, 128, 500]
         configs = [
             (2, 16, 16, 64, 64),
             (2, 16, 1, 64, 64),
@@ -309,8 +310,9 @@ class TestTritonAttention(unittest.TestCase):
             (2, 128, 1, 576, 512),
         ]

-        for B, H_Q, H_KV, D, D_V in configs:
-            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V)
+        for S in seq_lens:
+            for B, H_Q, H_KV, D, D_V in configs:
+                self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)


 if __name__ == "__main__":
Reference in New Issue
Block a user