[code clean] add H20 cutlass groupGemm default config (#9333)
Co-authored-by: wanghanpei <wanghanpei@bytedance.com>
This commit is contained in:
@@ -437,34 +437,6 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define JOIN_STRUCT_PP_NAME(m, n, k, a, b, c) sm90_fp8_pp_config##_##m##_##n##_##k##_##a##_##b##_##c
|
|
||||||
|
|
||||||
#define JOIN_STRUCT_CO_NAME(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
|
|
||||||
|
|
||||||
#define GENERATE_SM90_FP8_PP_CONFIG(M, N, K, A, B, C) \
|
|
||||||
struct JOIN_STRUCT_PP_NAME(M, N, K, A, B, C) { \
|
|
||||||
using ElementA = cutlass::float_e4m3_t; \
|
|
||||||
using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
|
|
||||||
using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
|
|
||||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; \
|
|
||||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
|
|
||||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; \
|
|
||||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); \
|
|
||||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); \
|
|
||||||
};
|
|
||||||
|
|
||||||
#define GENERATE_SM90_FP8_CO_CONFIG(M, N, K, A, B, C) \
|
|
||||||
struct JOIN_STRUCT_CO_NAME(M, N, K, A, B, C) { \
|
|
||||||
using ElementA = cutlass::float_e4m3_t; \
|
|
||||||
using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
|
|
||||||
using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
|
|
||||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; \
|
|
||||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
|
|
||||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; \
|
|
||||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); \
|
|
||||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); \
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename OutType>
|
template <typename OutType>
|
||||||
void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||||
torch::Tensor& output,
|
torch::Tensor& output,
|
||||||
@@ -509,20 +481,28 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||||
};
|
};
|
||||||
|
|
||||||
// [NOTE] Tuned for H20
|
// [NOTE] default for H20
|
||||||
GENERATE_SM90_FP8_PP_CONFIG(64, 128, 128, 1, 2, 1)
|
struct MmaConfigH20_default {
|
||||||
|
using ElementA = cutlass::float_e4m3_t;
|
||||||
|
using MmaTileShape = Shape<_64, _128, _128>;
|
||||||
|
using ClusterShape = Shape<_1, _2, _1>;
|
||||||
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||||
|
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
|
||||||
|
|
||||||
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||||
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||||
|
};
|
||||||
|
|
||||||
int num_experts = (int)expert_offsets.size(0);
|
int num_experts = (int)expert_offsets.size(0);
|
||||||
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||||
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
||||||
|
|
||||||
bool tuning_H20_kernel = getBoolEnv("SGL_TUNE_DEVICE_KERNEL");
|
|
||||||
|
|
||||||
const std::string H20_device_type_str = "NVIDIA H20";
|
const std::string H20_device_type_str = "NVIDIA H20";
|
||||||
bool is_h20 = isDeviceType(H20_device_type_str);
|
bool is_h20_device = isDeviceType(H20_device_type_str);
|
||||||
|
|
||||||
if (is_h20 && tuning_H20_kernel) {
|
if (is_h20_device) {
|
||||||
using execute_gemm_config = sm90_fp8_pp_config_64_128_128_1_2_1;
|
using execute_gemm_config = MmaConfigH20_default;
|
||||||
run_get_group_gemm_starts<
|
run_get_group_gemm_starts<
|
||||||
execute_gemm_config::LayoutSFA,
|
execute_gemm_config::LayoutSFA,
|
||||||
execute_gemm_config::LayoutSFB,
|
execute_gemm_config::LayoutSFB,
|
||||||
|
|||||||
Reference in New Issue
Block a user