[Build] Fix cuda12.8 build error in nvfp4_scaled_mm_kernels.cu (#4953)

This commit is contained in:
Yuhong Guo
2025-04-01 03:00:34 +08:00
committed by GitHub
parent 6384d31776
commit ee47a6c1c3

View File

@@ -159,7 +159,7 @@ typename T::Gemm::Arguments args_from_options(
using StrideA = typename T::StrideA;
using StrideB = typename T::StrideB;
using StrideD = typename T::StrideD;
using Sm100BlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
using Sm1xxBlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
int m = static_cast<int>(M);
int n = static_cast<int>(N);
@@ -168,8 +168,8 @@ typename T::Gemm::Arguments args_from_options(
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
typename T::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,