Add fuse_moe per-channel tune (#10915)
This commit is contained in:
@@ -47,6 +47,7 @@ def benchmark_config(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: List[int] = None,
|
||||
num_iters: int = 100,
|
||||
) -> float:
|
||||
@@ -152,6 +153,7 @@ def benchmark_config(
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
@@ -261,6 +263,7 @@ class BenchmarkWorker:
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: List[int],
|
||||
) -> Tuple[Dict[str, int], float]:
|
||||
torch.cuda.manual_seed_all(0)
|
||||
@@ -272,7 +275,12 @@ class BenchmarkWorker:
|
||||
block_n = block_shape[0] if block_shape else 0
|
||||
block_k = block_shape[1] if block_shape else 0
|
||||
op_config = get_moe_configs(
|
||||
num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
|
||||
num_experts,
|
||||
shard_intermediate_size // 2,
|
||||
dtype_str,
|
||||
block_n,
|
||||
block_k,
|
||||
per_channel_quant,
|
||||
)
|
||||
if op_config is None:
|
||||
config = get_default_config(
|
||||
@@ -299,6 +307,7 @@ class BenchmarkWorker:
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
per_channel_quant,
|
||||
block_shape,
|
||||
)
|
||||
return config, kernel_time
|
||||
@@ -314,6 +323,7 @@ class BenchmarkWorker:
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: List[int],
|
||||
search_space: List[Dict[str, int]],
|
||||
) -> Dict[str, int]:
|
||||
@@ -333,6 +343,7 @@ class BenchmarkWorker:
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
per_channel_quant,
|
||||
block_shape,
|
||||
num_iters=10,
|
||||
)
|
||||
@@ -373,6 +384,7 @@ def save_configs(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: List[int],
|
||||
) -> None:
|
||||
dtype_str = get_config_dtype_str(
|
||||
@@ -389,6 +401,7 @@ def save_configs(
|
||||
shard_intermediate_size // 2,
|
||||
dtype_str,
|
||||
block_shape,
|
||||
per_channel_quant,
|
||||
)
|
||||
|
||||
print(f"Writing best config to {filename}...")
|
||||
@@ -471,6 +484,7 @@ def main(args: argparse.Namespace):
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a8 = args.dtype == "int8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
per_channel_quant = args.per_channel_quant
|
||||
block_shape = None
|
||||
if (
|
||||
hasattr(config, "quantization_config")
|
||||
@@ -543,6 +557,7 @@ def main(args: argparse.Namespace):
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
per_channel_quant,
|
||||
block_shape,
|
||||
search_space,
|
||||
)
|
||||
@@ -562,6 +577,7 @@ def main(args: argparse.Namespace):
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
per_channel_quant,
|
||||
block_shape,
|
||||
)
|
||||
end = time.perf_counter()
|
||||
@@ -580,6 +596,7 @@ def main(args: argparse.Namespace):
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
per_channel_quant,
|
||||
block_shape,
|
||||
)
|
||||
for batch_size in batch_sizes
|
||||
@@ -603,6 +620,10 @@ if __name__ == "__main__":
|
||||
choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
|
||||
default="auto",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per-channel-quant",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--tune", action="store_true")
|
||||
|
||||
Reference in New Issue
Block a user