Remove 200us slow concat kernel (part 1: kernel) (#7145)
This commit is contained in:
@@ -38,6 +38,7 @@ configs = list(itertools.product(bs_range, qlen_range))
 )
 def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
     d = 576
+    dn = 64
     dv = 512
 
     h_q_map = {
@@ -63,7 +64,11 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
     pack_factor = 128 // block_size
     block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
 
-    q = torch.randn(batch_size, h_q, d, dtype=torch.bfloat16, device="cuda") * 100.0
+    qn = (
+        torch.randn(h_q, batch_size, d - dn, dtype=torch.bfloat16, device="cuda")
+        * 100.0
+    )
+    qr = torch.randn(batch_size, h_q, dn, dtype=torch.bfloat16, device="cuda") * 100.0
     block_table = torch.randint(
         0,
         batch_size * block_num,
@@ -84,16 +89,22 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
     quantiles = [0.5, 0.2, 0.8]
     ms, min_ms, max_ms = triton.testing.do_bench(
         lambda: cutlass_mla_decode(
-            q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
+            qn.transpose(0, 1),
+            qr,
+            kv_cache,
+            seq_lens,
+            block_table,
+            workspace,
+            num_kv_splits,
         ),
         quantiles=quantiles,
     )
 
+    q_size = qn.numel() * qn.element_size() + qr.numel() * qr.element_size()
+
     gbps = (
         lambda ms: (
-            q.numel() * q.element_size()
-            + q.numel() * q.element_size() * dv / d
-            + kv_cache.numel() * kv_cache.element_size()
+            q_size + q_size * dv / d + kv_cache.numel() * kv_cache.element_size()
         )
         * 1e-9
         / (ms * 1e-3)
Reference in New Issue
Block a user