From 77830a265e9f61dfa0a51540625925fe61112e9f Mon Sep 17 00:00:00 2001
From: lukec <118525388+sleepcoo@users.noreply.github.com>
Date: Thu, 25 Sep 2025 21:12:09 +0800
Subject: [PATCH] Add fuse_moe per-channel tune (#10915)

---
 .../tuning_fused_moe_triton.py                | 23 ++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
index 40a1979a1..eecc3ca2b 100644
--- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
+++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
@@ -47,6 +47,7 @@ def benchmark_config(
     use_fp8_w8a8: bool,
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
+    per_channel_quant: bool,
     block_shape: List[int] = None,
     num_iters: int = 100,
 ) -> float:
@@ -152,6 +153,7 @@ def benchmark_config(
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            per_channel_quant=per_channel_quant,
             block_shape=block_shape,
         )
 
@@ -261,6 +263,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a8: bool,
         use_int8_w8a16: bool,
+        per_channel_quant: bool,
         block_shape: List[int],
     ) -> Tuple[Dict[str, int], float]:
         torch.cuda.manual_seed_all(0)
@@ -272,7 +275,12 @@ class BenchmarkWorker:
         block_n = block_shape[0] if block_shape else 0
         block_k = block_shape[1] if block_shape else 0
         op_config = get_moe_configs(
-            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
+            num_experts,
+            shard_intermediate_size // 2,
+            dtype_str,
+            block_n,
+            block_k,
+            per_channel_quant,
         )
         if op_config is None:
             config = get_default_config(
@@ -299,6 +307,7 @@ class BenchmarkWorker:
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
+            per_channel_quant,
             block_shape,
         )
         return config, kernel_time
@@ -314,6 +323,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a8: bool,
         use_int8_w8a16: bool,
+        per_channel_quant: bool,
         block_shape: List[int],
         search_space: List[Dict[str, int]],
     ) -> Dict[str, int]:
@@ -333,6 +343,7 @@ class BenchmarkWorker:
                     use_fp8_w8a8,
                     use_int8_w8a8,
                     use_int8_w8a16,
+                    per_channel_quant,
                     block_shape,
                     num_iters=10,
                 )
@@ -373,6 +384,7 @@ def save_configs(
     use_fp8_w8a8: bool,
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
+    per_channel_quant: bool,
     block_shape: List[int],
 ) -> None:
     dtype_str = get_config_dtype_str(
@@ -389,6 +401,7 @@ def save_configs(
         shard_intermediate_size // 2,
         dtype_str,
         block_shape,
+        per_channel_quant,
     )
 
     print(f"Writing best config to {filename}...")
@@ -471,6 +484,7 @@ def main(args: argparse.Namespace):
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a8 = args.dtype == "int8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
+    per_channel_quant = args.per_channel_quant
     block_shape = None
     if (
         hasattr(config, "quantization_config")
@@ -543,6 +557,7 @@ def main(args: argparse.Namespace):
                 use_fp8_w8a8,
                 use_int8_w8a8,
                 use_int8_w8a16,
+                per_channel_quant,
                 block_shape,
                 search_space,
             )
@@ -562,6 +577,7 @@ def main(args: argparse.Namespace):
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
+            per_channel_quant,
             block_shape,
         )
         end = time.perf_counter()
@@ -580,6 +596,7 @@ def main(args: argparse.Namespace):
                 use_fp8_w8a8,
                 use_int8_w8a8,
                 use_int8_w8a16,
+                per_channel_quant,
                 block_shape,
             )
             for batch_size in batch_sizes
@@ -603,6 +620,10 @@ if __name__ == "__main__":
         choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
         default="auto",
     )
+    parser.add_argument(
+        "--per-channel-quant",
+        action="store_true",
+    )
     parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
     parser.add_argument("--tune", action="store_true")
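
Example invocation exercising the new flag (a sketch: the --model and
--tp-size arguments are assumed from the unmodified script and sit outside
this patch's hunks, and the model path is a placeholder):

    # Tune per-channel int8 fused-MoE configs instead of per-tensor ones.
    python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
        --model <path-to-moe-model> \
        --tp-size 2 \
        --dtype int8_w8a8 \
        --per-channel-quant \
        --tune

With --tune, the sweep runs benchmark_config() with per_channel_quant=True,
and save_configs() forwards the flag to get_config_file_name() (see the
hunk at @@ -389,6 +401,7 @@ above), so the tuned JSON written out is named
for the per-channel quantization mode rather than overwriting the
per-tensor configs.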