Optimize Triton decoding kernel for long context (#2394)

Author: Ke Bao
Date: 2024-12-08 17:17:37 +08:00
Committed by: GitHub
Parent: 1f09e84b9a
Commit: 7dc66fcb40

4 changed files with 344 additions and 376 deletions


@@ -182,6 +182,7 @@ class TestTritonAttention(unittest.TestCase):
         seq_len = 10  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
+        num_kv_splits = 8

         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
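The new num_kv_splits value drives a split-KV ("flash-decoding" style) scheme: each sequence's cached keys and values are carved into chunks that are attended to in parallel and merged afterwards, which is what helps at long context. A minimal sketch of the chunking arithmetic (split_bounds is a hypothetical helper for illustration, not part of this diff):

    # Hypothetical sketch of how a KV length is carved into num_kv_splits chunks.
    def split_bounds(kv_len: int, num_kv_splits: int):
        chunk = -(-kv_len // num_kv_splits)  # ceil division
        return [(lo, min(lo + chunk, kv_len)) for lo in range(0, kv_len, chunk)]

    # split_bounds(128, 8) -> [(0, 16), (16, 32), ..., (112, 128)]
    # Shorter sequences simply yield fewer non-empty chunks.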
@@ -199,8 +200,8 @@ class TestTritonAttention(unittest.TestCase):
         b_seq_len = torch.full((B,), seq_len, device="cuda")

         attn_logits = torch.empty(
-            (H_Q, total_tokens),
-            dtype=dtype,
+            (B, H_Q, num_kv_splits, D + 1),
+            dtype=torch.float32,
             device="cuda",
         )
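The intermediate buffer becomes (B, H_Q, num_kv_splits, D + 1) in float32: presumably D lanes for each split's partial attention output and one extra lane for that split's log-sum-exp, kept in full precision so the cross-split reduction does not lose mass to bfloat16 rounding. Assuming that layout, the splits can be merged with a numerically stable logsumexp weighting; this is a sketch of the technique, not the kernel's actual reduce:

    import torch

    def combine_splits(partials: torch.Tensor) -> torch.Tensor:
        # partials: (B, H_Q, num_kv_splits, D + 1), float32.
        # [..., :D]: per-split output, softmax-normalized within the split.
        # [..., D]:  per-split log-sum-exp of the attention scores.
        out, lse = partials[..., :-1], partials[..., -1]
        w = torch.softmax(lse, dim=-1)             # each split's share of total mass
        return (out * w.unsqueeze(-1)).sum(dim=2)  # (B, H_Q, D)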
@@ -215,6 +216,7 @@ class TestTritonAttention(unittest.TestCase):
             b_seq_len,
             attn_logits,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )
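For reference, the quantity these kernels compute per batch element is ordinary scaled-dot-product attention of a single query token against the cached keys and values; a naive PyTorch equivalent, useful only as a correctness sketch:

    import torch

    def naive_decode_attention(q, k_cache, v_cache, sm_scale):
        # q: (H_Q, D); k_cache: (L, H_KV, D); v_cache: (L, H_KV, D_V)
        # Assumes H_Q == H_KV; grouped-query heads would repeat k/v first.
        scores = torch.einsum("hd,lhd->hl", q.float(), k_cache.float()) * sm_scale
        probs = torch.softmax(scores, dim=-1)
        return torch.einsum("hl,lhv->hv", probs, v_cache.float())  # (H_Q, D_V)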
@@ -235,9 +237,10 @@ class TestTritonAttention(unittest.TestCase):
     def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
         dtype = torch.bfloat16
-        seq_len = 10  # This represents the number of tokens already in the sequence
+        seq_len = 128  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
+        num_kv_splits = 8

        # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
@@ -247,8 +250,8 @@ class TestTritonAttention(unittest.TestCase):
         v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")

         # o will have the same shape as q
-        o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
-        o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
+        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
+        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

         req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
         b_req_idx = torch.arange(B, device="cuda")
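Note the shape fix: the attention output inherits its last dimension from v, not from q, so o must be (B, H_Q, D_V). The retained comment about o matching q's shape now only holds when D == D_V, as in the configs below, but the two differ in general. A two-line shape check:

    import torch
    q, k, v = torch.randn(16, 64), torch.randn(128, 64), torch.randn(128, 80)
    o = torch.softmax(q @ k.T / 64**0.5, dim=-1) @ v
    assert o.shape == (16, 80)  # (H_Q, D_V), independent of D = 64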
@@ -256,8 +259,8 @@ class TestTritonAttention(unittest.TestCase):
         b_seq_len = torch.full((B,), seq_len, device="cuda")

         attn_logits = torch.empty(
-            (H_Q, total_tokens),
-            dtype=dtype,
+            (B, H_Q, num_kv_splits, D_V + 1),
+            dtype=torch.float32,
             device="cuda",
         )
@@ -268,13 +271,19 @@ class TestTritonAttention(unittest.TestCase):
             o,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
             attn_logits,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )

+        attn_logits1 = torch.empty(
+            (B, H_Q, num_kv_splits, D_V + 1),
+            dtype=torch.float32,
+            device="cuda",
+        )
+
         decode_attention_fwd_grouped(
             q,
             k_buffer,
@@ -282,21 +291,23 @@ class TestTritonAttention(unittest.TestCase):
             o_grouped,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
-            attn_logits,
+            attn_logits1,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )

         cos_sim = torch.nn.functional.cosine_similarity(
             o.flatten(), o_grouped.flatten(), dim=0
         )
         print(cos_sim.item())
         self.assertTrue(cos_sim.item() > 0.99)
+        self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))

     def test_grouped_decode_attention(self):
         configs = [
             (2, 16, 16, 64, 64),
             (2, 16, 1, 64, 64),
             (2, 64, 1, 13, 13),
+            (2, 128, 1, 80, 80),
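The added torch.allclose assertion closes a gap in the cosine-similarity check: cosine similarity only measures direction, so outputs that are proportional to the reference but wrongly scaled would still clear the 0.99 threshold. A small repro of the failure mode the new check guards against:

    import torch
    a = torch.randn(1024)
    b = 1.5 * a  # perfectly correlated, wrong magnitude
    print(torch.nn.functional.cosine_similarity(a, b, dim=0).item())  # 1.0, passes
    print(torch.allclose(a, b, atol=3e-2))  # False, caught by the new check

The atol=3e-2 tolerance is loose by float32 standards but plausible for bfloat16 outputs accumulated over a 128-token context.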