Optimize Triton decoding kernel for dynamic workload (#4553)

This commit is contained in:
JieXin Liang
2025-03-19 12:25:38 +08:00
committed by GitHub
parent 588865f0e0
commit c0e9a36c5f
7 changed files with 277 additions and 57 deletions

View File

@@ -228,7 +228,8 @@ class TestTritonAttention(unittest.TestCase):
seq_len = 10 # This represents the number of tokens already in the sequence
total_tokens = B * seq_len
sm_scale = 1.0 / (D**0.5)
num_kv_splits = 8
max_kv_splits = 8
num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
@@ -247,7 +248,12 @@ class TestTritonAttention(unittest.TestCase):
kv_indices = torch.arange(total_tokens, device="cuda")
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D + 1),
(B, H_Q, max_kv_splits, D),
dtype=torch.float32,
device="cuda",
)
attn_lse = torch.empty(
(B, H_Q, max_kv_splits),
dtype=torch.float32,
device="cuda",
)
@@ -259,8 +265,9 @@ class TestTritonAttention(unittest.TestCase):
o,
kv_indptr,
kv_indices,
attn_logits,
(attn_logits, attn_lse),
num_kv_splits,
max_kv_splits,
sm_scale,
)
@@ -284,7 +291,8 @@ class TestTritonAttention(unittest.TestCase):
seq_len = S # This represents the number of tokens already in the sequence
total_tokens = B * seq_len
sm_scale = 1.0 / (D**0.5)
num_kv_splits = 8
max_kv_splits = 8
num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
@@ -304,7 +312,12 @@ class TestTritonAttention(unittest.TestCase):
kv_indices = torch.arange(total_tokens, device="cuda")
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
(B, H_Q, max_kv_splits, D_V),
dtype=torch.float32,
device="cuda",
)
attn_lse = torch.empty(
(B, H_Q, max_kv_splits),
dtype=torch.float32,
device="cuda",
)
@@ -316,13 +329,19 @@ class TestTritonAttention(unittest.TestCase):
o,
kv_indptr,
kv_indices,
attn_logits,
(attn_logits, attn_lse),
num_kv_splits,
max_kv_splits,
sm_scale,
)
attn_logits1 = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
(B, H_Q, max_kv_splits, D_V),
dtype=torch.float32,
device="cuda",
)
# Second buffer of the (logits, lse) pair passed to the grouped decode call.
# Shape matches the attn_lse allocation earlier in this test: no trailing D_V
# axis (that axis belongs to attn_logits1 only).
attn_lse1 = torch.empty(
(B, H_Q, max_kv_splits),
dtype=torch.float32,
device="cuda",
)
@@ -334,8 +353,9 @@ class TestTritonAttention(unittest.TestCase):
o_grouped,
kv_indptr,
kv_indices,
attn_logits1,
(attn_logits1, attn_lse1),
num_kv_splits,
max_kv_splits,
sm_scale,
)