Remove 200us slow concat kernel (part 1: kernel) (#7145)

This commit is contained in:
fzyzcjy
2025-06-13 16:58:29 +08:00
committed by GitHub
parent 2f4ec752bc
commit aa46ed34d2
6 changed files with 79 additions and 48 deletions

View File

@@ -38,6 +38,7 @@ configs = list(itertools.product(bs_range, qlen_range))
)
def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
d = 576
dn = 64
dv = 512
h_q_map = {
@@ -63,7 +64,11 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
pack_factor = 128 // block_size
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
q = torch.randn(batch_size, h_q, d, dtype=torch.bfloat16, device="cuda") * 100.0
qn = (
torch.randn(h_q, batch_size, d - dn, dtype=torch.bfloat16, device="cuda")
* 100.0
)
qr = torch.randn(batch_size, h_q, dn, dtype=torch.bfloat16, device="cuda") * 100.0
block_table = torch.randint(
0,
batch_size * block_num,
@@ -84,16 +89,22 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: cutlass_mla_decode(
q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
qn.transpose(0, 1),
qr,
kv_cache,
seq_lens,
block_table,
workspace,
num_kv_splits,
),
quantiles=quantiles,
)
q_size = qn.numel() * qn.element_size() + qr.numel() * qr.element_size()
gbps = (
lambda ms: (
q.numel() * q.element_size()
+ q.numel() * q.element_size() * dv / d
+ kv_cache.numel() * kv_cache.element_size()
q_size + q_size * dv / d + kv_cache.numel() * kv_cache.element_size()
)
* 1e-9
/ (ms * 1e-3)