Remove 200us slow concat kernel (part 1: kernel) (#7145)
This commit is contained in:
@@ -86,10 +86,14 @@ def test_cutlass_mla_decode(
|
||||
)
|
||||
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
|
||||
|
||||
q_nope = torch.empty((h_q, bs, dv)).transpose(0, 1)
|
||||
q_nope.copy_(q[:, :, :dv])
|
||||
q_pe = q[:, :, dv:].clone()
|
||||
|
||||
out_ref = q.new_zeros(bs, h_q, dv)
|
||||
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
|
||||
out = cutlass_mla_decode(
|
||||
q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
|
||||
q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
|
||||
|
||||
Reference in New Issue
Block a user