Optimize Triton decoding kernel for dynamic workload (#4553)
This commit is contained in:
@@ -228,7 +228,8 @@ class TestTritonAttention(unittest.TestCase):
|
||||
seq_len = 10 # This represents the number of tokens already in the sequence
|
||||
total_tokens = B * seq_len
|
||||
sm_scale = 1.0 / (D**0.5)
|
||||
num_kv_splits = 8
|
||||
max_kv_splits = 8
|
||||
num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
|
||||
|
||||
# q represents the new token being generated, one per batch
|
||||
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
|
||||
@@ -247,7 +248,12 @@ class TestTritonAttention(unittest.TestCase):
|
||||
kv_indices = torch.arange(total_tokens, device="cuda")
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(B, H_Q, num_kv_splits, D + 1),
|
||||
(B, H_Q, max_kv_splits, D),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
attn_lse = torch.empty(
|
||||
(B, H_Q, max_kv_splits),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
@@ -259,8 +265,9 @@ class TestTritonAttention(unittest.TestCase):
|
||||
o,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
attn_logits,
|
||||
(attn_logits, attn_lse),
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
@@ -284,7 +291,8 @@ class TestTritonAttention(unittest.TestCase):
|
||||
seq_len = S # This represents the number of tokens already in the sequence
|
||||
total_tokens = B * seq_len
|
||||
sm_scale = 1.0 / (D**0.5)
|
||||
num_kv_splits = 8
|
||||
max_kv_splits = 8
|
||||
num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
|
||||
|
||||
# q represents the new token being generated, one per batch
|
||||
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
|
||||
@@ -304,7 +312,12 @@ class TestTritonAttention(unittest.TestCase):
|
||||
kv_indices = torch.arange(total_tokens, device="cuda")
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(B, H_Q, num_kv_splits, D_V + 1),
|
||||
(B, H_Q, max_kv_splits, D_V),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
attn_lse = torch.empty(
|
||||
(B, H_Q, max_kv_splits),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
@@ -316,13 +329,19 @@ class TestTritonAttention(unittest.TestCase):
|
||||
o,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
attn_logits,
|
||||
(attn_logits, attn_lse),
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
attn_logits1 = torch.empty(
|
||||
(B, H_Q, num_kv_splits, D_V + 1),
|
||||
(B, H_Q, max_kv_splits, D_V),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
attn_lse1 = torch.empty(
|
||||
(B, H_Q, max_kv_splits, D_V),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
@@ -334,8 +353,9 @@ class TestTritonAttention(unittest.TestCase):
|
||||
o_grouped,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
attn_logits1,
|
||||
(attn_logits1, attn_lse1),
|
||||
num_kv_splits,
|
||||
max_kv_splits,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user