Optimize Triton decoding kernel for long context (#2394)
@@ -182,6 +182,7 @@ class TestTritonAttention(unittest.TestCase):
         seq_len = 10  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
+        num_kv_splits = 8
 
         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
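The new num_kv_splits knob is the signature of a flash-decoding style kernel: rather than one program walking a sequence's entire KV cache, the cache is cut into num_kv_splits chunks that are processed by parallel programs and merged afterwards, which restores GPU occupancy when the batch is small but the context is long. A minimal sketch of one plausible partition scheme (the kernel's actual chunking strategy is not shown in this diff):

def kv_split_ranges(seq_len: int, num_kv_splits: int):
    # Ceil-divide the sequence into near-equal chunks; trailing splits
    # may be empty when seq_len is small relative to num_kv_splits.
    split_size = (seq_len + num_kv_splits - 1) // num_kv_splits
    return [
        (min(i * split_size, seq_len), min((i + 1) * split_size, seq_len))
        for i in range(num_kv_splits)
    ]

# kv_split_ranges(10, 8) -> [(0, 2), (2, 4), (4, 6), (6, 8), (8, 10),
#                            (10, 10), (10, 10), (10, 10)]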
@@ -199,8 +200,8 @@ class TestTritonAttention(unittest.TestCase):
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
-            (H_Q, total_tokens),
-            dtype=dtype,
+            (B, H_Q, num_kv_splits, D + 1),
+            dtype=torch.float32,
             device="cuda",
         )
 
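The scratch buffer changes meaning entirely: the old (H_Q, total_tokens) layout held one logit per token, while the new (B, H_Q, num_kv_splits, D + 1) layout holds one partial result per split, in float32 regardless of the model dtype. A plausible reading (the slot assignment is an assumption based on the shape, not confirmed by this diff) is that [..., :D] stores each split's partial attention output and [..., D] its log-sum-exp, which is exactly what a cross-split merge needs:

import torch

def combine_kv_splits(attn_logits: torch.Tensor) -> torch.Tensor:
    # attn_logits: (B, H_Q, num_kv_splits, D + 1), float32.
    # Assumed layout: [..., :D] = per-split partial output (normalized
    # within its split), [..., D] = per-split log-sum-exp.
    partial = attn_logits[..., :-1]                     # (B, H_Q, S, D)
    lse = attn_logits[..., -1]                          # (B, H_Q, S)
    # Weight each split by its share of the global softmax normalizer.
    global_lse = torch.logsumexp(lse, dim=-1, keepdim=True)
    weights = torch.exp(lse - global_lse)               # sums to 1 over S
    return (weights.unsqueeze(-1) * partial).sum(dim=2)  # (B, H_Q, D)

The float32 dtype matters here: bfloat16 carries only 8 mantissa bits, so accumulating long-context partials in the model dtype would cost noticeable accuracy.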
@@ -215,6 +216,7 @@ class TestTritonAttention(unittest.TestCase):
             b_seq_len,
             attn_logits,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )
 
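decode_attention_fwd keeps its seq_len argument (the longest context in the batch) and now also takes num_kv_splits. The test pins it to 8; how a serving engine would choose it is not part of this diff, but the trade-off is easy to state: more splits means more parallel programs and more merge work. A purely illustrative heuristic (names and numbers are invented for the sketch, this is not sglang's policy):

def pick_num_kv_splits(batch_size: int, seq_len: int, max_splits: int = 16) -> int:
    # Hypothetical: target enough programs to keep the GPU busy, but never
    # use more splits than there are tokens, and cap the merge overhead.
    target_programs = 64  # invented program budget for the sketch
    splits = max(1, target_programs // max(batch_size, 1))
    return max(1, min(splits, max_splits, seq_len))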
@@ -235,9 +237,10 @@ class TestTritonAttention(unittest.TestCase):
 
     def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
         dtype = torch.bfloat16
-        seq_len = 10  # This represents the number of tokens already in the sequence
+        seq_len = 128  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
+        num_kv_splits = 8
 
         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
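Bumping seq_len from 10 to 128 makes this test exercise the split path in earnest: with num_kv_splits = 8 every split now covers a full 16 tokens, whereas at seq_len = 10 most splits were nearly or completely empty.

seq_len, num_kv_splits = 128, 8
tokens_per_split = (seq_len + num_kv_splits - 1) // num_kv_splits
assert tokens_per_split == 16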
@@ -247,8 +250,8 @@ class TestTritonAttention(unittest.TestCase):
         v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")
 
         # o will have the same shape as q
-        o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
-        o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
+        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
+        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
 
         req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
         b_req_idx = torch.arange(B, device="cuda")
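For the configs below, which all use D == D_V, this fixes a latent shape bug rather than a behavioral one: attention output inherits the value head dimension D_V, not the query/key dimension D, so the old allocations only worked by coincidence. (The surviving context comment "# o will have the same shape as q" is now slightly stale, since q keeps width D.) A plain reference computation, standard multi-head attention with names local to this sketch, shows the shape flow:

import torch

def ref_decode(q, k, v, sm_scale):
    # q: (B, H, D); k: (B, S, H, D); v: (B, S, H, D_V)
    scores = torch.einsum("bhd,bshd->bhs", q.float(), k.float()) * sm_scale
    probs = torch.softmax(scores, dim=-1)                     # (B, H, S)
    return torch.einsum("bhs,bshe->bhe", probs, v.float())    # (B, H, D_V)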
@@ -256,8 +259,8 @@ class TestTritonAttention(unittest.TestCase):
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
-            (H_Q, total_tokens),
-            dtype=dtype,
+            (B, H_Q, num_kv_splits, D_V + 1),
+            dtype=torch.float32,
             device="cuda",
         )
 
@@ -268,13 +271,19 @@ class TestTritonAttention(unittest.TestCase):
             o,
             req_to_token,
             b_req_idx,
             b_start_loc,
             b_seq_len,
             attn_logits,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )
 
+        attn_logits1 = torch.empty(
+            (B, H_Q, num_kv_splits, D_V + 1),
+            dtype=torch.float32,
+            device="cuda",
+        )
+
         decode_attention_fwd_grouped(
             q,
             k_buffer,
@@ -282,21 +291,23 @@ class TestTritonAttention(unittest.TestCase):
             o_grouped,
             req_to_token,
             b_req_idx,
             b_start_loc,
             b_seq_len,
-            attn_logits,
+            attn_logits1,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )
 
         cos_sim = torch.nn.functional.cosine_similarity(
             o.flatten(), o_grouped.flatten(), dim=0
         )
         print(cos_sim.item())
         self.assertTrue(cos_sim.item() > 0.99)
         self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))
 
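The grouped call gets its own scratch buffer (attn_logits1) so it cannot clobber the reference kernel's partials, and the pass bar, cosine similarity above 0.99 plus allclose at atol=3e-2, is reasonable for bfloat16 inputs reduced in float32. When the allclose assertion trips, the maximum absolute error tends to be more actionable than the cosine score; a line one might add while debugging (not part of the test):

max_err = (o.float() - o_grouped.float()).abs().max().item()
print(f"max abs err: {max_err:.4f}, cos sim: {cos_sim.item():.6f}")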
     def test_grouped_decode_attention(self):
         configs = [
             (2, 16, 16, 64, 64),
             (2, 16, 1, 64, 64),
+            (2, 64, 1, 13, 13),
+            (2, 128, 1, 80, 80),
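Each config tuple reads (B, H_Q, H_KV, D, D_V). The H_KV = 1 rows exercise multi-query attention, the case the grouped kernel exists for, and head dims 13 and 80 cover non-power-of-two tile paths. As a correctness oracle, grouped attention always reduces to plain multi-head attention by repeating KV heads across each query-head group (a sketch, assuming H_Q is divisible by H_KV):

group_size = H_Q // H_KV
# (total_tokens, H_KV, D) -> (total_tokens, H_Q, D): each KV head serves
# group_size consecutive query heads.
k_mha = k_buffer.repeat_interleave(group_size, dim=1)
v_mha = v_buffer.repeat_interleave(group_size, dim=1)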