[Build] Fix cuda12.8 build error in nvfp4_scaled_mm_kernels.cu (#4953)
This commit is contained in:
@@ -159,7 +159,7 @@ typename T::Gemm::Arguments args_from_options(
|
|||||||
using StrideA = typename T::StrideA;
|
using StrideA = typename T::StrideA;
|
||||||
using StrideB = typename T::StrideB;
|
using StrideB = typename T::StrideB;
|
||||||
using StrideD = typename T::StrideD;
|
using StrideD = typename T::StrideD;
|
||||||
using Sm100BlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
using Sm1xxBlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||||
|
|
||||||
int m = static_cast<int>(M);
|
int m = static_cast<int>(M);
|
||||||
int n = static_cast<int>(N);
|
int n = static_cast<int>(N);
|
||||||
@@ -168,8 +168,8 @@ typename T::Gemm::Arguments args_from_options(
|
|||||||
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
|
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
|
||||||
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
|
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
|
||||||
|
|
||||||
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
|
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
|
||||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
|
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
|
||||||
|
|
||||||
typename T::Gemm::Arguments arguments{
|
typename T::Gemm::Arguments arguments{
|
||||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||||
|
|||||||
Reference in New Issue
Block a user