Optimize nvfp4 block scaled gemm kernel when M is small. (#10101)
This commit is contained in:
@@ -38,27 +38,74 @@ limitations under the License.
|
|||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||||
// Kernel Perf config
|
// Config(half_t/bfloat16_t) for M <= 128
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct KernelTraits {
|
struct KernelConfigM128 {
|
||||||
|
using OutputType = T;
|
||||||
|
using MmaTileShape = Shape<_128, _256, _256>;
|
||||||
|
using ClusterShape = Shape<int, int, _1>;
|
||||||
|
using EpilogueTile = Shape<_128, _64>; // Avoid register spilling
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||||
|
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
|
||||||
|
const static dim3 preferred_cluster;
|
||||||
|
const static dim3 fallback_cluster;
|
||||||
|
};
|
||||||
|
template <typename T>
|
||||||
|
const dim3 KernelConfigM128<T>::preferred_cluster(1, 4, 1);
|
||||||
|
template <typename T>
|
||||||
|
const dim3 KernelConfigM128<T>::fallback_cluster(1, 2, 1);
|
||||||
|
|
||||||
|
// Config(half_t/bfloat16_t) for 128 < M <= 256
|
||||||
|
template <typename T>
|
||||||
|
struct KernelConfigM256 {
|
||||||
|
using OutputType = T;
|
||||||
using MmaTileShape = Shape<_256, _256, _256>;
|
using MmaTileShape = Shape<_256, _256, _256>;
|
||||||
using ClusterShape = Shape<int, int, _1>;
|
using ClusterShape = Shape<int, int, _1>;
|
||||||
using EpilogueTile = Shape<_128, _64>;
|
using EpilogueTile = Shape<_128, _64>; // Avoid register spilling
|
||||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||||
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
|
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
|
||||||
|
const static dim3 preferred_cluster;
|
||||||
|
const static dim3 fallback_cluster;
|
||||||
};
|
};
|
||||||
|
template <typename T>
|
||||||
|
const dim3 KernelConfigM256<T>::preferred_cluster(2, 4, 1);
|
||||||
|
template <typename T>
|
||||||
|
const dim3 KernelConfigM256<T>::fallback_cluster(2, 1, 1);
|
||||||
|
|
||||||
template <>
|
// Default config(half_t/bfloat16_t) for M > 256
|
||||||
struct KernelTraits<float> {
|
template <typename T>
|
||||||
|
struct KernelConfigDefault {
|
||||||
|
using OutputType = T;
|
||||||
|
using MmaTileShape = Shape<_256, _256, _256>;
|
||||||
|
using ClusterShape = Shape<int, int, _1>;
|
||||||
|
using EpilogueTile = Shape<_128, _64>; // Avoid register spilling
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||||
|
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
|
||||||
|
const static dim3 preferred_cluster;
|
||||||
|
const static dim3 fallback_cluster;
|
||||||
|
};
|
||||||
|
template <typename T>
|
||||||
|
const dim3 KernelConfigDefault<T>::preferred_cluster(4, 4, 1);
|
||||||
|
template <typename T>
|
||||||
|
const dim3 KernelConfigDefault<T>::fallback_cluster(2, 1, 1);
|
||||||
|
|
||||||
|
struct KernelConfigFp32 {
|
||||||
|
using OutputType = float;
|
||||||
using MmaTileShape = Shape<_128, _128, _256>;
|
using MmaTileShape = Shape<_128, _128, _256>;
|
||||||
using ClusterShape = Shape<int, int, _1>;
|
using ClusterShape = Shape<int, int, _1>;
|
||||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||||
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
|
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
|
||||||
|
const static dim3 preferred_cluster;
|
||||||
|
const static dim3 fallback_cluster;
|
||||||
};
|
};
|
||||||
|
const dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1);
|
||||||
|
const dim3 KernelConfigFp32::fallback_cluster = dim3(1, 2, 1);
|
||||||
|
|
||||||
template <typename T>
|
template <typename KernelConfig>
|
||||||
struct Fp4GemmSm100 {
|
struct Fp4GemmSm100 {
|
||||||
|
using Config = KernelConfig; // For generating args
|
||||||
|
using OutputType = typename KernelConfig::OutputType;
|
||||||
// A matrix configuration
|
// A matrix configuration
|
||||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||||
using LayoutATag = cutlass::layout::RowMajor;
|
using LayoutATag = cutlass::layout::RowMajor;
|
||||||
@@ -70,8 +117,8 @@ struct Fp4GemmSm100 {
|
|||||||
static constexpr int AlignmentB = 32;
|
static constexpr int AlignmentB = 32;
|
||||||
|
|
||||||
// C/D matrix configuration
|
// C/D matrix configuration
|
||||||
using ElementD = T;
|
using ElementD = OutputType;
|
||||||
using ElementC = T;
|
using ElementC = OutputType;
|
||||||
using LayoutCTag = cutlass::layout::RowMajor;
|
using LayoutCTag = cutlass::layout::RowMajor;
|
||||||
using LayoutDTag = cutlass::layout::RowMajor;
|
using LayoutDTag = cutlass::layout::RowMajor;
|
||||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||||
@@ -82,15 +129,15 @@ struct Fp4GemmSm100 {
|
|||||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
|
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
|
||||||
|
|
||||||
// Kernel Perf config
|
// Kernel Perf config
|
||||||
using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
|
using MmaTileShape = typename KernelConfig::MmaTileShape;
|
||||||
using ClusterShape = typename KernelTraits<T>::ClusterShape;
|
using ClusterShape = typename KernelConfig::ClusterShape;
|
||||||
using EpilogueTile = typename KernelTraits<T>::EpilogueTile;
|
using EpilogueTile = typename KernelConfig::EpilogueTile;
|
||||||
using EpilogueSchedule = typename KernelTraits<T>::EpilogueSchedule;
|
using EpilogueSchedule = typename KernelConfig::EpilogueSchedule;
|
||||||
using MainloopSchedule = typename KernelTraits<T>::MainloopSchedule;
|
using MainloopSchedule = typename KernelConfig::MainloopSchedule;
|
||||||
|
|
||||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
ArchTag,
|
ArchTag,
|
||||||
cutlass::arch::OpClassTensorOp,
|
OperatorClass,
|
||||||
MmaTileShape,
|
MmaTileShape,
|
||||||
ClusterShape,
|
ClusterShape,
|
||||||
EpilogueTile,
|
EpilogueTile,
|
||||||
@@ -182,19 +229,15 @@ typename T::Gemm::Arguments args_from_options(
|
|||||||
layout_SFB},
|
layout_SFB},
|
||||||
{ // Epilogue arguments
|
{ // Epilogue arguments
|
||||||
{}, // epilogue.thread
|
{}, // epilogue.thread
|
||||||
static_cast<ElementD const*>(D.data_ptr()),
|
nullptr,
|
||||||
stride_D,
|
stride_D,
|
||||||
static_cast<ElementD*>(D.data_ptr()),
|
static_cast<ElementD*>(D.data_ptr()),
|
||||||
stride_D}};
|
stride_D}};
|
||||||
auto& fusion_args = arguments.epilogue.thread;
|
auto& fusion_args = arguments.epilogue.thread;
|
||||||
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
|
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
|
||||||
if constexpr (std::is_same_v<T, float>) {
|
using KernelConfig = typename T::Config;
|
||||||
arguments.hw_info.cluster_shape = dim3(1, 4, 1);
|
arguments.hw_info.cluster_shape = KernelConfig::preferred_cluster;
|
||||||
arguments.hw_info.cluster_shape_fallback = dim3(1, 1, 1);
|
arguments.hw_info.cluster_shape_fallback = KernelConfig::fallback_cluster;
|
||||||
} else {
|
|
||||||
arguments.hw_info.cluster_shape = dim3(4, 4, 1);
|
|
||||||
arguments.hw_info.cluster_shape_fallback = dim3(2, 1, 1);
|
|
||||||
}
|
|
||||||
return arguments;
|
return arguments;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,11 +253,10 @@ void runGemm(
|
|||||||
int64_t n,
|
int64_t n,
|
||||||
int64_t k,
|
int64_t k,
|
||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
typename Fp4GemmSm100<T>::Gemm gemm;
|
typename T::Gemm gemm;
|
||||||
|
auto arguments = args_from_options<T>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||||
|
|
||||||
auto arguments = args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
size_t workspace_size = T::Gemm::get_workspace_size(arguments);
|
||||||
|
|
||||||
size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
|
|
||||||
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
||||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||||
|
|
||||||
@@ -224,9 +266,51 @@ void runGemm(
|
|||||||
|
|
||||||
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dispatch function to select appropriate config based on M
|
||||||
|
template <typename OutType>
|
||||||
|
void cutlassFp4GemmDispatch(
|
||||||
|
torch::Tensor& D,
|
||||||
|
torch::Tensor const& A,
|
||||||
|
torch::Tensor const& B,
|
||||||
|
torch::Tensor const& A_sf,
|
||||||
|
torch::Tensor const& B_sf,
|
||||||
|
torch::Tensor const& alpha,
|
||||||
|
int64_t m,
|
||||||
|
int64_t n,
|
||||||
|
int64_t k,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
if (m <= 128) {
|
||||||
|
// m in [1, 128]
|
||||||
|
runGemm<Fp4GemmSm100<KernelConfigM128<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||||
|
} else if (m <= 256) {
|
||||||
|
// m in (128, 256]
|
||||||
|
runGemm<Fp4GemmSm100<KernelConfigM256<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||||
|
} else {
|
||||||
|
// m in (256, inf)
|
||||||
|
runGemm<Fp4GemmSm100<KernelConfigDefault<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Specialization for float output: always uses KernelConfigFp32, independent of M
|
||||||
|
template <>
|
||||||
|
void cutlassFp4GemmDispatch<float>(
|
||||||
|
torch::Tensor& D,
|
||||||
|
torch::Tensor const& A,
|
||||||
|
torch::Tensor const& B,
|
||||||
|
torch::Tensor const& A_sf,
|
||||||
|
torch::Tensor const& B_sf,
|
||||||
|
torch::Tensor const& alpha,
|
||||||
|
int64_t m,
|
||||||
|
int64_t n,
|
||||||
|
int64_t k,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
runGemm<Fp4GemmSm100<KernelConfigFp32>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||||
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void runGemm(
|
void cutlassFp4GemmDispatch(
|
||||||
at::Tensor& D,
|
at::Tensor& D,
|
||||||
at::Tensor const& A,
|
at::Tensor const& A,
|
||||||
at::Tensor const& B,
|
at::Tensor const& B,
|
||||||
@@ -358,11 +442,11 @@ void cutlass_scaled_fp4_mm_sm100a(
|
|||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||||
|
|
||||||
if (out_dtype == at::ScalarType::Half) {
|
if (out_dtype == at::ScalarType::Half) {
|
||||||
runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
cutlassFp4GemmDispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||||
} else if (out_dtype == at::ScalarType::BFloat16) {
|
} else if (out_dtype == at::ScalarType::BFloat16) {
|
||||||
runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
cutlassFp4GemmDispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||||
} else if (out_dtype == at::ScalarType::Float) {
|
} else if (out_dtype == at::ScalarType::Float) {
|
||||||
runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
cutlassFp4GemmDispatch<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
|
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user