[fix]: fix cutlass moe ut and Opt H20 cutlass groupGemm performance (#9272)
Co-authored-by: wanghanpei <wanghanpei@bytedance.com>
This commit is contained in:
@@ -153,9 +153,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
|||||||
x,
|
x,
|
||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
topk_weights,
|
(topk_weights, topk_ids, "dummy"),
|
||||||
topk_ids,
|
inplace=False,
|
||||||
inplace=False, # Use False for benchmarking to avoid side effects if run multiple times
|
|
||||||
activation="silu", # Assuming SiLU activation common in MoEs
|
activation="silu", # Assuming SiLU activation common in MoEs
|
||||||
use_fp8_w8a8=True,
|
use_fp8_w8a8=True,
|
||||||
w1_scale=w1_scale,
|
w1_scale=w1_scale,
|
||||||
@@ -221,8 +220,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
|||||||
x,
|
x,
|
||||||
w1, # Original shape
|
w1, # Original shape
|
||||||
w2, # Original shape
|
w2, # Original shape
|
||||||
topk_weights,
|
(topk_weights, topk_ids, "dummy"),
|
||||||
topk_ids,
|
|
||||||
inplace=False, # Important: Use False to get output tensor
|
inplace=False, # Important: Use False to get output tensor
|
||||||
activation="silu",
|
activation="silu",
|
||||||
use_fp8_w8a8=True,
|
use_fp8_w8a8=True,
|
||||||
@@ -266,7 +264,7 @@ if __name__ == "__main__":
|
|||||||
"--batch-sizes",
|
"--batch-sizes",
|
||||||
type=int,
|
type=int,
|
||||||
nargs="+",
|
nargs="+",
|
||||||
default=[1, 4, 8, 16, 32, 64, 128, 256, 512], # Adjusted default
|
default=[1, 4, 8, 16, 32, 64, 128, 256, 512, 1024], # Adjusted default
|
||||||
help="List of batch sizes to test",
|
help="List of batch sizes to test",
|
||||||
)
|
)
|
||||||
parser.add_argument("--check", action="store_true", help="Enable check mode")
|
parser.add_argument("--check", action="store_true", help="Enable check mode")
|
||||||
|
|||||||
@@ -437,6 +437,34 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define JOIN_STRUCT_PP_NAME(m, n, k, a, b, c) sm90_fp8_pp_config##_##m##_##n##_##k##_##a##_##b##_##c
|
||||||
|
|
||||||
|
#define JOIN_STRUCT_CO_NAME(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
|
||||||
|
|
||||||
|
#define GENERATE_SM90_FP8_PP_CONFIG(M, N, K, A, B, C) \
|
||||||
|
struct JOIN_STRUCT_PP_NAME(M, N, K, A, B, C) { \
|
||||||
|
using ElementA = cutlass::float_e4m3_t; \
|
||||||
|
using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
|
||||||
|
using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
|
||||||
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; \
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
|
||||||
|
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; \
|
||||||
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); \
|
||||||
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); \
|
||||||
|
};
|
||||||
|
|
||||||
|
#define GENERATE_SM90_FP8_CO_CONFIG(M, N, K, A, B, C) \
|
||||||
|
struct JOIN_STRUCT_CO_NAME(M, N, K, A, B, C) { \
|
||||||
|
using ElementA = cutlass::float_e4m3_t; \
|
||||||
|
using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
|
||||||
|
using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
|
||||||
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; \
|
||||||
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
|
||||||
|
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; \
|
||||||
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); \
|
||||||
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); \
|
||||||
|
};
|
||||||
|
|
||||||
template <typename OutType>
|
template <typename OutType>
|
||||||
void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||||
torch::Tensor& output,
|
torch::Tensor& output,
|
||||||
@@ -481,13 +509,24 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// [NOTE] Tuned for H20
|
||||||
|
GENERATE_SM90_FP8_PP_CONFIG(64, 128, 128, 1, 2, 1)
|
||||||
|
|
||||||
int num_experts = (int)expert_offsets.size(0);
|
int num_experts = (int)expert_offsets.size(0);
|
||||||
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||||
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
||||||
|
|
||||||
if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
|
bool tuning_H20_kernel = getBoolEnv("SGL_TUNE_DEVICE_KERNEL");
|
||||||
// For H20 with K > 128, use Pingpong Schedule
|
|
||||||
run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
|
const std::string H20_device_type_str = "NVIDIA H20";
|
||||||
|
bool is_h20 = isDeviceType(H20_device_type_str);
|
||||||
|
|
||||||
|
if (is_h20 && tuning_H20_kernel) {
|
||||||
|
using execute_gemm_config = sm90_fp8_pp_config_64_128_128_1_2_1;
|
||||||
|
run_get_group_gemm_starts<
|
||||||
|
execute_gemm_config::LayoutSFA,
|
||||||
|
execute_gemm_config::LayoutSFB,
|
||||||
|
execute_gemm_config::ScaleConfig>(
|
||||||
expert_offsets,
|
expert_offsets,
|
||||||
a_ptrs,
|
a_ptrs,
|
||||||
b_ptrs,
|
b_ptrs,
|
||||||
@@ -503,7 +542,8 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
layout_sfb,
|
layout_sfb,
|
||||||
problem_sizes,
|
problem_sizes,
|
||||||
problem_sizes_transpose);
|
problem_sizes_transpose);
|
||||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
|
|
||||||
|
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, execute_gemm_config, cutlass::layout::RowMajor>(
|
||||||
out_ptrs,
|
out_ptrs,
|
||||||
a_ptrs,
|
a_ptrs,
|
||||||
b_ptrs,
|
b_ptrs,
|
||||||
@@ -518,37 +558,71 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
expert_offsets,
|
expert_offsets,
|
||||||
workspace);
|
workspace);
|
||||||
} else {
|
} else {
|
||||||
// For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
|
if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
|
||||||
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
|
// For H20 with K > 128, use Pingpong Schedule
|
||||||
expert_offsets,
|
run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
|
||||||
a_ptrs,
|
expert_offsets,
|
||||||
b_ptrs,
|
a_ptrs,
|
||||||
out_ptrs,
|
b_ptrs,
|
||||||
a_scales_ptrs,
|
out_ptrs,
|
||||||
b_scales_ptrs,
|
a_scales_ptrs,
|
||||||
a,
|
b_scales_ptrs,
|
||||||
b,
|
a,
|
||||||
output,
|
b,
|
||||||
scales_a,
|
output,
|
||||||
scales_b,
|
scales_a,
|
||||||
layout_sfa,
|
scales_b,
|
||||||
layout_sfb,
|
layout_sfa,
|
||||||
problem_sizes,
|
layout_sfb,
|
||||||
problem_sizes_transpose);
|
problem_sizes,
|
||||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
|
problem_sizes_transpose);
|
||||||
out_ptrs,
|
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
|
||||||
a_ptrs,
|
out_ptrs,
|
||||||
b_ptrs,
|
a_ptrs,
|
||||||
a_scales_ptrs,
|
b_ptrs,
|
||||||
b_scales_ptrs,
|
a_scales_ptrs,
|
||||||
stride_a,
|
b_scales_ptrs,
|
||||||
stride_b,
|
stride_a,
|
||||||
stride_c,
|
stride_b,
|
||||||
layout_sfa,
|
stride_c,
|
||||||
layout_sfb,
|
layout_sfa,
|
||||||
problem_sizes,
|
layout_sfb,
|
||||||
expert_offsets,
|
problem_sizes,
|
||||||
workspace);
|
expert_offsets,
|
||||||
|
workspace);
|
||||||
|
} else {
|
||||||
|
// For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
|
||||||
|
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
|
||||||
|
expert_offsets,
|
||||||
|
a_ptrs,
|
||||||
|
b_ptrs,
|
||||||
|
out_ptrs,
|
||||||
|
a_scales_ptrs,
|
||||||
|
b_scales_ptrs,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
output,
|
||||||
|
scales_a,
|
||||||
|
scales_b,
|
||||||
|
layout_sfa,
|
||||||
|
layout_sfb,
|
||||||
|
problem_sizes,
|
||||||
|
problem_sizes_transpose);
|
||||||
|
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
|
||||||
|
out_ptrs,
|
||||||
|
a_ptrs,
|
||||||
|
b_ptrs,
|
||||||
|
a_scales_ptrs,
|
||||||
|
b_scales_ptrs,
|
||||||
|
stride_a,
|
||||||
|
stride_b,
|
||||||
|
stride_c,
|
||||||
|
layout_sfa,
|
||||||
|
layout_sfb,
|
||||||
|
problem_sizes,
|
||||||
|
expert_offsets,
|
||||||
|
workspace);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -254,6 +254,25 @@ inline int getSMVersion() {
|
|||||||
return sm_major * 10 + sm_minor;
|
return sm_major * 10 + sm_minor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool isDeviceType(const std::string& device_type) {
|
||||||
|
int deviceCount;
|
||||||
|
CHECK_CUDA_SUCCESS(cudaGetDeviceCount(&deviceCount));
|
||||||
|
|
||||||
|
int device_id = -1;
|
||||||
|
if (deviceCount >= 1) {
|
||||||
|
CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
cudaDeviceProp prop;
|
||||||
|
CHECK_CUDA_SUCCESS(cudaGetDeviceProperties(&prop, device_id));
|
||||||
|
if (device_type == std::string(prop.name)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
inline bool getBoolEnv(char const* name) {
|
inline bool getBoolEnv(char const* name) {
|
||||||
char const* env = std::getenv(name);
|
char const* env = std::getenv(name);
|
||||||
return env && env[0] == '1' && env[1] == '\0';
|
return env && env[0] == '1' && env[1] == '\0';
|
||||||
|
|||||||
Reference in New Issue
Block a user