Optimize nvfp4 block scaled gemm kernel when M is small. (#10101)
This commit is contained in:
@@ -38,27 +38,74 @@ limitations under the License.
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
// Kernel Perf config
|
||||
// Config(half_t/bfloat16_t) for M <= 128
|
||||
template <typename T>
|
||||
struct KernelTraits {
|
||||
struct KernelConfigM128 {
|
||||
using OutputType = T;
|
||||
using MmaTileShape = Shape<_128, _256, _256>;
|
||||
using ClusterShape = Shape<int, int, _1>;
|
||||
using EpilogueTile = Shape<_128, _64>; // Avoid register spilling
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
|
||||
const static dim3 preferred_cluster;
|
||||
const static dim3 fallback_cluster;
|
||||
};
|
||||
template <typename T>
|
||||
const dim3 KernelConfigM128<T>::preferred_cluster(1, 4, 1);
|
||||
template <typename T>
|
||||
const dim3 KernelConfigM128<T>::fallback_cluster(1, 2, 1);
|
||||
|
||||
// Config(half_t/bfloat16_t) for M <= 256
|
||||
template <typename T>
|
||||
struct KernelConfigM256 {
|
||||
using OutputType = T;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<int, int, _1>;
|
||||
using EpilogueTile = Shape<_128, _64>;
|
||||
using EpilogueTile = Shape<_128, _64>; // Avoid register spilling
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
|
||||
const static dim3 preferred_cluster;
|
||||
const static dim3 fallback_cluster;
|
||||
};
|
||||
template <typename T>
|
||||
const dim3 KernelConfigM256<T>::preferred_cluster(2, 4, 1);
|
||||
template <typename T>
|
||||
const dim3 KernelConfigM256<T>::fallback_cluster(2, 1, 1);
|
||||
|
||||
template <>
|
||||
struct KernelTraits<float> {
|
||||
// Default config(half_t/bfloat16_t) for M > 256
|
||||
template <typename T>
|
||||
struct KernelConfigDefault {
|
||||
using OutputType = T;
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<int, int, _1>;
|
||||
using EpilogueTile = Shape<_128, _64>; // Avoid register spilling
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
|
||||
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100;
|
||||
const static dim3 preferred_cluster;
|
||||
const static dim3 fallback_cluster;
|
||||
};
|
||||
template <typename T>
|
||||
const dim3 KernelConfigDefault<T>::preferred_cluster(4, 4, 1);
|
||||
template <typename T>
|
||||
const dim3 KernelConfigDefault<T>::fallback_cluster(2, 1, 1);
|
||||
|
||||
struct KernelConfigFp32 {
|
||||
using OutputType = float;
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<int, int, _1>;
|
||||
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm;
|
||||
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100;
|
||||
const static dim3 preferred_cluster;
|
||||
const static dim3 fallback_cluster;
|
||||
};
|
||||
const dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1);
|
||||
const dim3 KernelConfigFp32::fallback_cluster = dim3(1, 2, 1);
|
||||
|
||||
template <typename T>
|
||||
template <typename KernelConfig>
|
||||
struct Fp4GemmSm100 {
|
||||
using Config = KernelConfig; // For generating args
|
||||
using OutputType = typename KernelConfig::OutputType;
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||
using LayoutATag = cutlass::layout::RowMajor;
|
||||
@@ -70,8 +117,8 @@ struct Fp4GemmSm100 {
|
||||
static constexpr int AlignmentB = 32;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = T;
|
||||
using ElementC = T;
|
||||
using ElementD = OutputType;
|
||||
using ElementC = OutputType;
|
||||
using LayoutCTag = cutlass::layout::RowMajor;
|
||||
using LayoutDTag = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
@@ -82,15 +129,15 @@ struct Fp4GemmSm100 {
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
|
||||
using ClusterShape = typename KernelTraits<T>::ClusterShape;
|
||||
using EpilogueTile = typename KernelTraits<T>::EpilogueTile;
|
||||
using EpilogueSchedule = typename KernelTraits<T>::EpilogueSchedule;
|
||||
using MainloopSchedule = typename KernelTraits<T>::MainloopSchedule;
|
||||
using MmaTileShape = typename KernelConfig::MmaTileShape;
|
||||
using ClusterShape = typename KernelConfig::ClusterShape;
|
||||
using EpilogueTile = typename KernelConfig::EpilogueTile;
|
||||
using EpilogueSchedule = typename KernelConfig::EpilogueSchedule;
|
||||
using MainloopSchedule = typename KernelConfig::MainloopSchedule;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
OperatorClass,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
EpilogueTile,
|
||||
@@ -182,19 +229,15 @@ typename T::Gemm::Arguments args_from_options(
|
||||
layout_SFB},
|
||||
{ // Epilogue arguments
|
||||
{}, // epilogue.thread
|
||||
static_cast<ElementD const*>(D.data_ptr()),
|
||||
nullptr,
|
||||
stride_D,
|
||||
static_cast<ElementD*>(D.data_ptr()),
|
||||
stride_D}};
|
||||
auto& fusion_args = arguments.epilogue.thread;
|
||||
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
arguments.hw_info.cluster_shape = dim3(1, 4, 1);
|
||||
arguments.hw_info.cluster_shape_fallback = dim3(1, 1, 1);
|
||||
} else {
|
||||
arguments.hw_info.cluster_shape = dim3(4, 4, 1);
|
||||
arguments.hw_info.cluster_shape_fallback = dim3(2, 1, 1);
|
||||
}
|
||||
using KernelConfig = typename T::Config;
|
||||
arguments.hw_info.cluster_shape = KernelConfig::preferred_cluster;
|
||||
arguments.hw_info.cluster_shape_fallback = KernelConfig::fallback_cluster;
|
||||
return arguments;
|
||||
}
|
||||
|
||||
@@ -210,11 +253,10 @@ void runGemm(
|
||||
int64_t n,
|
||||
int64_t k,
|
||||
cudaStream_t stream) {
|
||||
typename Fp4GemmSm100<T>::Gemm gemm;
|
||||
typename T::Gemm gemm;
|
||||
auto arguments = args_from_options<T>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||
|
||||
auto arguments = args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||
|
||||
size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
|
||||
size_t workspace_size = T::Gemm::get_workspace_size(arguments);
|
||||
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
@@ -224,9 +266,51 @@ void runGemm(
|
||||
|
||||
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
// Dispatch function to select appropriate config based on M
|
||||
template <typename OutType>
|
||||
void cutlassFp4GemmDispatch(
|
||||
torch::Tensor& D,
|
||||
torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t k,
|
||||
cudaStream_t stream) {
|
||||
if (m <= 128) {
|
||||
// m in [1, 128]
|
||||
runGemm<Fp4GemmSm100<KernelConfigM128<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (m <= 256) {
|
||||
// m in (128, 256]
|
||||
runGemm<Fp4GemmSm100<KernelConfigM256<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else {
|
||||
// m in (256, inf)
|
||||
runGemm<Fp4GemmSm100<KernelConfigDefault<OutType>>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch function to select appropriate config based on M
|
||||
template <>
|
||||
void cutlassFp4GemmDispatch<float>(
|
||||
torch::Tensor& D,
|
||||
torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t k,
|
||||
cudaStream_t stream) {
|
||||
runGemm<Fp4GemmSm100<KernelConfigFp32>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
}
|
||||
|
||||
#else
|
||||
template <typename T>
|
||||
void runGemm(
|
||||
void cutlassFp4GemmDispatch(
|
||||
at::Tensor& D,
|
||||
at::Tensor const& A,
|
||||
at::Tensor const& B,
|
||||
@@ -358,11 +442,11 @@ void cutlass_scaled_fp4_mm_sm100a(
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||
|
||||
if (out_dtype == at::ScalarType::Half) {
|
||||
runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
cutlassFp4GemmDispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (out_dtype == at::ScalarType::BFloat16) {
|
||||
runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
cutlassFp4GemmDispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (out_dtype == at::ScalarType::Float) {
|
||||
runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
cutlassFp4GemmDispatch<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user