Remove 200us slow concat kernel (part 1: kernel) (#7145)

This commit is contained in:
fzyzcjy
2025-06-13 16:58:29 +08:00
committed by GitHub
parent 2f4ec752bc
commit aa46ed34d2
6 changed files with 79 additions and 48 deletions

View File

@@ -86,10 +86,14 @@ def test_cutlass_mla_decode(
)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
q_nope = torch.empty((h_q, bs, dv)).transpose(0, 1)
q_nope.copy_(q[:, :, :dv])
q_pe = q[:, :, dv:].clone()
out_ref = q.new_zeros(bs, h_q, dv)
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
out = cutlass_mla_decode(
q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits
)
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)