[Build] Fix cuda12.8 build error in nvfp4_scaled_mm_kernels.cu (#4953)
This commit is contained in:
@@ -159,7 +159,7 @@ typename T::Gemm::Arguments args_from_options(
|
||||
using StrideA = typename T::StrideA;
|
||||
using StrideB = typename T::StrideB;
|
||||
using StrideD = typename T::StrideD;
|
||||
using Sm100BlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
using Sm1xxBlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
int m = static_cast<int>(M);
|
||||
int n = static_cast<int>(N);
|
||||
@@ -168,8 +168,8 @@ typename T::Gemm::Arguments args_from_options(
|
||||
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
|
||||
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
|
||||
|
||||
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
|
||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
|
||||
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
|
||||
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
|
||||
|
||||
typename T::Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
|
||||
Reference in New Issue
Block a user