[fix] fix cutlass_mla_backend with cuda_graph and add sm_scale for sgl-kernel cutlass_mla (#7184)
This commit is contained in:
@@ -93,7 +93,7 @@ def test_cutlass_mla_decode(
|
||||
out_ref = q.new_zeros(bs, h_q, dv)
|
||||
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
|
||||
out = cutlass_mla_decode(
|
||||
q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits
|
||||
q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, scale, num_kv_splits
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
|
||||
|
||||
Reference in New Issue
Block a user