diff --git a/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
index 53219ea86..6fb9e83f3 100644
--- a/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
+++ b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
@@ -151,7 +151,10 @@ typename T::Fmha::Arguments args_from_options(
        page_size},
       {static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
       hw_info,
-      -1,       // split_kv
+      // TODO(trevor-m): Change split_kv back to -1 when
+      // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
+      // perform worse with larger context length and smaller batch sizes.
+      1,        // split_kv
       nullptr,  // is_var_split_kv
   };
   // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
diff --git a/sgl-kernel/tests/test_cutlass_mla.py b/sgl-kernel/tests/test_cutlass_mla.py
index a8aab8d15..d6e506490 100644
--- a/sgl-kernel/tests/test_cutlass_mla.py
+++ b/sgl-kernel/tests/test_cutlass_mla.py
@@ -67,7 +67,7 @@ def test_cutlass_mla_decode(
     pack_factor = 128 // block_size
     block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
 
-    q = torch.randn(bs, h_q, d)
+    q = torch.randn(bs, h_q, d) * 100.0
     block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32)
     kv_cache = torch.randn(block_table.numel(), block_size, d)
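
Reviewer note: below is a minimal, self-contained sketch (PyTorch only; none of these names belong to sgl-kernel or CUTLASS) of the split-KV log-sum-exp merge that a split_kv > 1 schedule must perform, and of why the test now scales q by 100.0. Larger query magnitudes sharpen the softmax and amplify any error in the partial-result rescaling, so an incorrect merge is far more likely to exceed the test's tolerance.

    # Illustrative sketch only -- hypothetical helpers, not the sgl-kernel API.
    import torch

    torch.manual_seed(0)

    def attend(q, k, v):
        # Single-pass reference: softmax(q @ k^T) @ v, also returning the
        # log-sum-exp needed to merge partial results across KV splits.
        s = q @ k.transpose(-1, -2)
        lse = torch.logsumexp(s, dim=-1, keepdim=True)
        return torch.exp(s - lse) @ v, lse

    def attend_split(q, k, v, splits=2):
        # Attend to each KV chunk independently, then merge the partial
        # outputs with log-sum-exp weights -- the rescaling step a
        # split_kv > 1 kernel performs in its reduction.
        outs, lses = zip(
            *(attend(q, kc, vc)
              for kc, vc in zip(k.chunk(splits, dim=0), v.chunk(splits, dim=0)))
        )
        lse = torch.logsumexp(torch.stack(lses), dim=0)
        return sum(o * torch.exp(l - lse) for o, l in zip(outs, lses))

    q = torch.randn(8, 64) * 100.0  # large magnitudes, as in the updated test
    k, v = torch.randn(512, 64), torch.randn(512, 64)

    full, _ = attend(q, k, v)
    merged = attend_split(q, k, v)
    # The merge is mathematically exact, so this should print True up to
    # float32 rounding; a buggy merge would diverge badly at this scale.
    print(torch.allclose(full, merged, atol=1e-4))

With the workaround above, split_kv is pinned to 1 (a single chunk), which bypasses this merge step entirely, at the cost of the lost parallelism the in-code comment describes for long contexts and small batches.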