[fix] fix cutlass_mla_backend with cuda_graph and add sm_scale for sgl-kernel cutlass_mla (#7184)

This commit is contained in:
JieXin Liang
2025-06-15 03:45:41 +08:00
committed by GitHub
parent ed54bf9d19
commit ab1a4fa5cb
7 changed files with 29 additions and 17 deletions

View File

@@ -93,7 +93,7 @@ def test_cutlass_mla_decode(
out_ref = q.new_zeros(bs, h_q, dv)
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
out = cutlass_mla_decode(
q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits
q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, scale, num_kv_splits
)
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)