Add fuse_moe per-channel tune (#10915)

Author: lukec
Date: 2025-09-25 21:12:09 +08:00
Committed by: GitHub
Parent: fce170480a
Commit: 77830a265e

@@ -47,6 +47,7 @@ def benchmark_config(
     use_fp8_w8a8: bool,
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
+    per_channel_quant: bool,
     block_shape: List[int] = None,
     num_iters: int = 100,
 ) -> float:
@@ -152,6 +153,7 @@ def benchmark_config(
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            per_channel_quant=per_channel_quant,
             block_shape=block_shape,
         )

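For readers unfamiliar with the option: per-channel quantization keeps one scale per output channel of each weight matrix rather than a single per-tensor scale, and the flag tells the fused MoE kernel which scale layout to expect. A minimal PyTorch sketch of the two granularities for int8 weights (illustrative only, not code from this commit):

    import torch

    def per_tensor_int8_scale(w: torch.Tensor) -> torch.Tensor:
        # One scalar scale shared by the whole tensor.
        return w.abs().amax().clamp(min=1e-8) / 127.0

    def per_channel_int8_scales(w: torch.Tensor) -> torch.Tensor:
        # One scale per output channel (row): shape [out_channels, 1].
        return w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0

    w = torch.randn(256, 512)
    w_q = torch.round(w / per_channel_int8_scales(w)).clamp(-128, 127).to(torch.int8)
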
@@ -261,6 +263,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a8: bool,
         use_int8_w8a16: bool,
+        per_channel_quant: bool,
         block_shape: List[int],
     ) -> Tuple[Dict[str, int], float]:
         torch.cuda.manual_seed_all(0)
@@ -272,7 +275,12 @@ class BenchmarkWorker:
         block_n = block_shape[0] if block_shape else 0
         block_k = block_shape[1] if block_shape else 0
         op_config = get_moe_configs(
-            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
+            num_experts,
+            shard_intermediate_size // 2,
+            dtype_str,
+            block_n,
+            block_k,
+            per_channel_quant,
         )
         if op_config is None:
             config = get_default_config(
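Threading per_channel_quant into get_moe_configs means tuned configs for per-channel and per-tensor runs are resolved independently of each other. A hypothetical sketch of such a keyed lookup (the actual get_moe_configs internals are not part of this diff):

    from typing import Dict, Optional, Tuple

    # Hypothetical cache keyed on everything that changes the optimal config.
    _CONFIGS: Dict[Tuple, Dict[str, int]] = {}

    def lookup_moe_config(
        num_experts: int,
        inter_size: int,
        dtype_str: str,
        block_n: int,
        block_k: int,
        per_channel_quant: bool,
    ) -> Optional[Dict[str, int]]:
        # Keeping the quant granularity in the key prevents per-channel and
        # per-tensor tunings from shadowing each other.
        key = (num_experts, inter_size, dtype_str, block_n, block_k, per_channel_quant)
        return _CONFIGS.get(key)
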
@@ -299,6 +307,7 @@ class BenchmarkWorker:
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
+            per_channel_quant,
             block_shape,
         )
         return config, kernel_time
@@ -314,6 +323,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a8: bool,
         use_int8_w8a16: bool,
+        per_channel_quant: bool,
         block_shape: List[int],
         search_space: List[Dict[str, int]],
     ) -> Dict[str, int]:
@@ -333,6 +343,7 @@ class BenchmarkWorker:
                     use_fp8_w8a8,
                     use_int8_w8a8,
                     use_int8_w8a16,
+                    per_channel_quant,
                     block_shape,
                     num_iters=10,
                 )
@@ -373,6 +384,7 @@ def save_configs(
     use_fp8_w8a8: bool,
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
+    per_channel_quant: bool,
     block_shape: List[int],
 ) -> None:
     dtype_str = get_config_dtype_str(
@@ -389,6 +401,7 @@ def save_configs(
         shard_intermediate_size // 2,
         dtype_str,
         block_shape,
+        per_channel_quant,
     )

     print(f"Writing best config to {filename}...")
@@ -471,6 +484,7 @@ def main(args: argparse.Namespace):
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a8 = args.dtype == "int8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
+    per_channel_quant = args.per_channel_quant
     block_shape = None
     if (
         hasattr(config, "quantization_config")
@@ -543,6 +557,7 @@ def main(args: argparse.Namespace):
                     use_fp8_w8a8,
                     use_int8_w8a8,
                     use_int8_w8a16,
+                    per_channel_quant,
                     block_shape,
                     search_space,
                 )
@@ -562,6 +577,7 @@ def main(args: argparse.Namespace):
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
+            per_channel_quant,
             block_shape,
         )
         end = time.perf_counter()
@@ -580,6 +596,7 @@ def main(args: argparse.Namespace):
                     use_fp8_w8a8,
                     use_int8_w8a8,
                     use_int8_w8a16,
+                    per_channel_quant,
                     block_shape,
                 )
                 for batch_size in batch_sizes
@@ -603,6 +620,10 @@ if __name__ == "__main__":
         choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
         default="auto",
     )
+    parser.add_argument(
+        "--per-channel-quant",
+        action="store_true",
+    )
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--batch-size", type=int, required=False)
     parser.add_argument("--tune", action="store_true")