Add fuse_moe per-channel tune (#10915)
This commit is contained in:
@@ -47,6 +47,7 @@ def benchmark_config(
|
|||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
use_int8_w8a8: bool,
|
use_int8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
|
per_channel_quant: bool,
|
||||||
block_shape: List[int] = None,
|
block_shape: List[int] = None,
|
||||||
num_iters: int = 100,
|
num_iters: int = 100,
|
||||||
) -> float:
|
) -> float:
|
||||||
@@ -152,6 +153,7 @@ def benchmark_config(
|
|||||||
w2_scale=w2_scale,
|
w2_scale=w2_scale,
|
||||||
a1_scale=a1_scale,
|
a1_scale=a1_scale,
|
||||||
a2_scale=a2_scale,
|
a2_scale=a2_scale,
|
||||||
|
per_channel_quant=per_channel_quant,
|
||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -261,6 +263,7 @@ class BenchmarkWorker:
|
|||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
use_int8_w8a8: bool,
|
use_int8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
|
per_channel_quant: bool,
|
||||||
block_shape: List[int],
|
block_shape: List[int],
|
||||||
) -> Tuple[Dict[str, int], float]:
|
) -> Tuple[Dict[str, int], float]:
|
||||||
torch.cuda.manual_seed_all(0)
|
torch.cuda.manual_seed_all(0)
|
||||||
@@ -272,7 +275,12 @@ class BenchmarkWorker:
|
|||||||
block_n = block_shape[0] if block_shape else 0
|
block_n = block_shape[0] if block_shape else 0
|
||||||
block_k = block_shape[1] if block_shape else 0
|
block_k = block_shape[1] if block_shape else 0
|
||||||
op_config = get_moe_configs(
|
op_config = get_moe_configs(
|
||||||
num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
|
num_experts,
|
||||||
|
shard_intermediate_size // 2,
|
||||||
|
dtype_str,
|
||||||
|
block_n,
|
||||||
|
block_k,
|
||||||
|
per_channel_quant,
|
||||||
)
|
)
|
||||||
if op_config is None:
|
if op_config is None:
|
||||||
config = get_default_config(
|
config = get_default_config(
|
||||||
@@ -299,6 +307,7 @@ class BenchmarkWorker:
|
|||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
|
per_channel_quant,
|
||||||
block_shape,
|
block_shape,
|
||||||
)
|
)
|
||||||
return config, kernel_time
|
return config, kernel_time
|
||||||
@@ -314,6 +323,7 @@ class BenchmarkWorker:
|
|||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
use_int8_w8a8: bool,
|
use_int8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
|
per_channel_quant: bool,
|
||||||
block_shape: List[int],
|
block_shape: List[int],
|
||||||
search_space: List[Dict[str, int]],
|
search_space: List[Dict[str, int]],
|
||||||
) -> Dict[str, int]:
|
) -> Dict[str, int]:
|
||||||
@@ -333,6 +343,7 @@ class BenchmarkWorker:
|
|||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
|
per_channel_quant,
|
||||||
block_shape,
|
block_shape,
|
||||||
num_iters=10,
|
num_iters=10,
|
||||||
)
|
)
|
||||||
@@ -373,6 +384,7 @@ def save_configs(
|
|||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
use_int8_w8a8: bool,
|
use_int8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
|
per_channel_quant: bool,
|
||||||
block_shape: List[int],
|
block_shape: List[int],
|
||||||
) -> None:
|
) -> None:
|
||||||
dtype_str = get_config_dtype_str(
|
dtype_str = get_config_dtype_str(
|
||||||
@@ -389,6 +401,7 @@ def save_configs(
|
|||||||
shard_intermediate_size // 2,
|
shard_intermediate_size // 2,
|
||||||
dtype_str,
|
dtype_str,
|
||||||
block_shape,
|
block_shape,
|
||||||
|
per_channel_quant,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Writing best config to {filename}...")
|
print(f"Writing best config to {filename}...")
|
||||||
@@ -471,6 +484,7 @@ def main(args: argparse.Namespace):
|
|||||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||||
use_int8_w8a8 = args.dtype == "int8_w8a8"
|
use_int8_w8a8 = args.dtype == "int8_w8a8"
|
||||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||||
|
per_channel_quant = args.per_channel_quant
|
||||||
block_shape = None
|
block_shape = None
|
||||||
if (
|
if (
|
||||||
hasattr(config, "quantization_config")
|
hasattr(config, "quantization_config")
|
||||||
@@ -543,6 +557,7 @@ def main(args: argparse.Namespace):
|
|||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
|
per_channel_quant,
|
||||||
block_shape,
|
block_shape,
|
||||||
search_space,
|
search_space,
|
||||||
)
|
)
|
||||||
@@ -562,6 +577,7 @@ def main(args: argparse.Namespace):
|
|||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
|
per_channel_quant,
|
||||||
block_shape,
|
block_shape,
|
||||||
)
|
)
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
@@ -580,6 +596,7 @@ def main(args: argparse.Namespace):
|
|||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
use_int8_w8a16,
|
use_int8_w8a16,
|
||||||
|
per_channel_quant,
|
||||||
block_shape,
|
block_shape,
|
||||||
)
|
)
|
||||||
for batch_size in batch_sizes
|
for batch_size in batch_sizes
|
||||||
@@ -603,6 +620,10 @@ if __name__ == "__main__":
|
|||||||
choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
|
choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
|
||||||
default="auto",
|
default="auto",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--per-channel-quant",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--batch-size", type=int, required=False)
|
parser.add_argument("--batch-size", type=int, required=False)
|
||||||
parser.add_argument("--tune", action="store_true")
|
parser.add_argument("--tune", action="store_true")
|
||||||
|
|||||||
Reference in New Issue
Block a user