[fix]: fix cutlass moe ut and and Opt H20 cutlass groupGemm performance (#9272)
Co-authored-by: wanghanpei <wanghanpei@bytedance.com>
This commit is contained in:
@@ -153,9 +153,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False, # Use False for benchmarking to avoid side effects if run multiple times
|
||||
(topk_weights, topk_ids, "dummy"),
|
||||
inplace=False,
|
||||
activation="silu", # Assuming SiLU activation common in MoEs
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
@@ -221,8 +220,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
||||
x,
|
||||
w1, # Original shape
|
||||
w2, # Original shape
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
(topk_weights, topk_ids, "dummy"),
|
||||
inplace=False, # Important: Use False to get output tensor
|
||||
activation="silu",
|
||||
use_fp8_w8a8=True,
|
||||
@@ -266,7 +264,7 @@ if __name__ == "__main__":
|
||||
"--batch-sizes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[1, 4, 8, 16, 32, 64, 128, 256, 512], # Adjusted default
|
||||
default=[1, 4, 8, 16, 32, 64, 128, 256, 512, 1024], # Adjusted default
|
||||
help="List of batch sizes to test",
|
||||
)
|
||||
parser.add_argument("--check", action="store_true", help="Enable check mode")
|
||||
|
||||
@@ -437,6 +437,34 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
}
|
||||
}
|
||||
|
||||
#define JOIN_STRUCT_PP_NAME(m, n, k, a, b, c) sm90_fp8_pp_config##_##m##_##n##_##k##_##a##_##b##_##c
|
||||
|
||||
#define JOIN_STRUCT_CO_NAME(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
|
||||
|
||||
#define GENERATE_SM90_FP8_PP_CONFIG(M, N, K, A, B, C) \
|
||||
struct JOIN_STRUCT_PP_NAME(M, N, K, A, B, C) { \
|
||||
using ElementA = cutlass::float_e4m3_t; \
|
||||
using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
|
||||
using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; \
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; \
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); \
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); \
|
||||
};
|
||||
|
||||
#define GENERATE_SM90_FP8_CO_CONFIG(M, N, K, A, B, C) \
|
||||
struct JOIN_STRUCT_CO_NAME(M, N, K, A, B, C) { \
|
||||
using ElementA = cutlass::float_e4m3_t; \
|
||||
using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
|
||||
using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; \
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; \
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); \
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); \
|
||||
};
|
||||
|
||||
template <typename OutType>
|
||||
void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
torch::Tensor& output,
|
||||
@@ -481,13 +509,24 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
// [NOTE] Tuned for H20
|
||||
GENERATE_SM90_FP8_PP_CONFIG(64, 128, 128, 1, 2, 1)
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
||||
|
||||
if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
|
||||
// For H20 with K > 128, use Pingpong Schedule
|
||||
run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
|
||||
bool tuning_H20_kernel = getBoolEnv("SGL_TUNE_DEVICE_KERNEL");
|
||||
|
||||
const std::string H20_device_type_str = "NVIDIA H20";
|
||||
bool is_h20 = isDeviceType(H20_device_type_str);
|
||||
|
||||
if (is_h20 && tuning_H20_kernel) {
|
||||
using execute_gemm_config = sm90_fp8_pp_config_64_128_128_1_2_1;
|
||||
run_get_group_gemm_starts<
|
||||
execute_gemm_config::LayoutSFA,
|
||||
execute_gemm_config::LayoutSFB,
|
||||
execute_gemm_config::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -503,7 +542,8 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
|
||||
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, execute_gemm_config, cutlass::layout::RowMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -518,37 +558,71 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
expert_offsets,
|
||||
workspace);
|
||||
} else {
|
||||
// For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
|
||||
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
output,
|
||||
scales_a,
|
||||
scales_b,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
workspace);
|
||||
if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
|
||||
// For H20 with K > 128, use Pingpong Schedule
|
||||
run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
output,
|
||||
scales_a,
|
||||
scales_b,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
workspace);
|
||||
} else {
|
||||
// For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
|
||||
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
output,
|
||||
scales_a,
|
||||
scales_b,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
workspace);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -254,6 +254,25 @@ inline int getSMVersion() {
|
||||
return sm_major * 10 + sm_minor;
|
||||
}
|
||||
|
||||
inline bool isDeviceType(const std::string& device_type) {
|
||||
int deviceCount;
|
||||
CHECK_CUDA_SUCCESS(cudaGetDeviceCount(&deviceCount));
|
||||
|
||||
int device_id = -1;
|
||||
if (deviceCount >= 1) {
|
||||
CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
cudaDeviceProp prop;
|
||||
CHECK_CUDA_SUCCESS(cudaGetDeviceProperties(&prop, device_id));
|
||||
if (device_type == std::string(prop.name)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool getBoolEnv(char const* name) {
|
||||
char const* env = std::getenv(name);
|
||||
return env && env[0] == '1' && env[1] == '\0';
|
||||
|
||||
Reference in New Issue
Block a user