[Perf]Use Cooperative Schedule for H100 & H200 & H800 in fp8_blockwise_scaled_grouped_mm (#8722)
This commit is contained in:
@@ -485,7 +485,8 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||||
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
||||||
|
|
||||||
if (a.size(1) > 128) {
|
if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
|
||||||
|
// For H20 with K > 128, use Pingpong Schedule
|
||||||
run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
|
run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
|
||||||
expert_offsets,
|
expert_offsets,
|
||||||
a_ptrs,
|
a_ptrs,
|
||||||
@@ -517,7 +518,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
|||||||
expert_offsets,
|
expert_offsets,
|
||||||
workspace);
|
workspace);
|
||||||
} else {
|
} else {
|
||||||
// Small K
|
// For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
|
||||||
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
|
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
|
||||||
expert_offsets,
|
expert_offsets,
|
||||||
a_ptrs,
|
a_ptrs,
|
||||||
|
|||||||
Reference in New Issue
Block a user