diff --git a/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu index c3da779f6..4fc4972dc 100644 --- a/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu +++ b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu @@ -159,7 +159,7 @@ typename T::Gemm::Arguments args_from_options( using StrideA = typename T::StrideA; using StrideB = typename T::StrideB; using StrideD = typename T::StrideD; - using Sm100BlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + using Sm1xxBlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; int m = static_cast(M); int n = static_cast(N); @@ -168,8 +168,8 @@ typename T::Gemm::Arguments args_from_options( auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1}); - auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); - auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); typename T::Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm,