add configs for block fp8 related kernels (#2628)

Co-authored-by: HandH1998 <1335248067@qq.com>
Authored by Yineng Zhang on 2024-12-28 23:12:04 +08:00, committed by GitHub
parent 333e3bfde5
commit 7863e4368a
37 changed files with 5131 additions and 50 deletions

@@ -39,6 +39,7 @@ def benchmark_config(
     dtype: torch.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
+    block_shape: List[int] = None,
     num_iters: int = 100,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
@@ -83,10 +84,23 @@ def benchmark_config(
         )
         w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
     if use_fp8_w8a8:
-        w1_scale = torch.randn(num_experts, dtype=torch.float32)
-        w2_scale = torch.randn(num_experts, dtype=torch.float32)
-        a1_scale = torch.randn(1, dtype=torch.float32)
-        a2_scale = torch.randn(1, dtype=torch.float32)
+        if block_shape is None:
+            w1_scale = torch.randn(num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+            a1_scale = torch.randn(1, dtype=torch.float32)
+            a2_scale = torch.randn(1, dtype=torch.float32)
+        else:
+            block_n, block_k = block_shape[0], block_shape[1]
+            n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
+            n_tiles_w2 = (hidden_size + block_n - 1) // block_n
+            k_tiles_w1 = (hidden_size + block_k - 1) // block_k
+            k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
+            w1_scale = torch.rand(
+                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
+            )
+            w2_scale = torch.rand(
+                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
+            )
 
         w1 = w1.to(torch.float8_e4m3fn)
         w2 = w2.to(torch.float8_e4m3fn)
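With block quantization, each weight matrix carries one float32 scale per block_n x block_k tile instead of a single per-tensor scale, so the scale tensors above are sized by ceiling division over the weight dimensions. A minimal sketch of that arithmetic, using illustrative DeepSeek-V3-like sizes (the concrete numbers and the cdiv helper are assumptions for illustration, not values from this diff):

# Illustrative sizes only; the benchmark derives them from the model config and the TP size.
num_experts = 256
hidden_size = 7168
shard_intermediate_size = 2 * 2048  # 2 * moe_intermediate_size // tp_size
block_n, block_k = 128, 128         # weight_block_size from quantization_config


def cdiv(a: int, b: int) -> int:
    # Ceiling division, same as the (x + block - 1) // block pattern above.
    return (a + b - 1) // b


# w1 has shape (E, shard_intermediate_size, hidden_size): one scale per tile.
w1_scale_shape = (num_experts, cdiv(shard_intermediate_size, block_n), cdiv(hidden_size, block_k))
# w2 has shape (E, hidden_size, shard_intermediate_size // 2).
w2_scale_shape = (num_experts, cdiv(hidden_size, block_n), cdiv(shard_intermediate_size // 2, block_k))

print(w1_scale_shape)  # (256, 32, 56)
print(w2_scale_shape)  # (256, 56, 16)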
@@ -114,6 +128,7 @@ def benchmark_config(
                 w2_scale=w2_scale,
                 a1_scale=a1_scale,
                 a2_scale=a2_scale,
+                block_shape=block_shape,
             )
 
     # JIT compilation & warmup
@@ -192,6 +207,7 @@ class BenchmarkWorker:
         dtype: torch.dtype,
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
+        block_shape: List[int],
     ) -> Tuple[Dict[str, int], float]:
         torch.cuda.manual_seed_all(0)
         dtype_str = get_config_dtype_str(
@@ -199,8 +215,10 @@ class BenchmarkWorker:
         )
         # NOTE(woosuk): The current naming convention uses w2.shape[2], which
         # is the intermediate size after silu_and_mul.
+        block_n = block_shape[0] if block_shape else 0
+        block_k = block_shape[1] if block_shape else 0
         op_config = get_moe_configs(
-            num_experts, shard_intermediate_size // 2, dtype_str
+            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
         )
         if op_config is None:
             config = get_default_config(
@@ -223,6 +241,7 @@ class BenchmarkWorker:
             dtype,
             use_fp8_w8a8,
             use_int8_w8a16,
+            block_shape,
         )
         return config, kernel_time
 
@@ -236,6 +255,7 @@ class BenchmarkWorker:
         dtype: torch.dtype,
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
+        block_shape: List[int],
         search_space: List[Dict[str, int]],
     ) -> Dict[str, int]:
         best_config = None
@@ -252,6 +272,7 @@ class BenchmarkWorker:
                     dtype,
                     use_fp8_w8a8,
                     use_int8_w8a16,
+                    block_shape,
                     num_iters=10,
                 )
             except triton.runtime.autotuner.OutOfResources:
@@ -287,6 +308,7 @@ def save_configs(
     dtype: torch.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
+    block_shape: List[int],
 ) -> None:
     dtype_str = get_config_dtype_str(
         dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
@@ -295,7 +317,10 @@ def save_configs(
     # NOTE(woosuk): The current naming convention uses w2.shape[2], which
     # is the intermediate size after silu_and_mul.
     filename = get_config_file_name(
-        num_experts, shard_intermediate_size // 2, dtype_str
+        num_experts,
+        shard_intermediate_size // 2,
+        dtype_str,
+        block_shape,
     )
     print(f"Writing best config to {filename}...")
 
@@ -323,10 +348,10 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    elif config.architectures[0] == "DeepseekV2ForCausalLM":
+    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
-        intermediate_size = config.intermediate_size
+        intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
         # Default: Mixtral
@@ -339,6 +364,13 @@ def main(args: argparse.Namespace):
     dtype = config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
+    block_shape = None
+    if (
+        hasattr(config, "quantization_config")
+        and "weight_block_size" in config.quantization_config
+    ):
+        block_shape = config.quantization_config["weight_block_size"]
+        assert len(block_shape) == 2
 
     if args.batch_size is None:
         batch_sizes = [
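The block shape is read straight from the checkpoint's Hugging Face config. For a block-quantized fp8 checkpoint such as DeepSeek-V3, config.quantization_config looks roughly like the dict below (a hedged illustration; only weight_block_size is consumed by this script):

# Roughly what config.quantization_config contains for a block-fp8 checkpoint
# (illustrative values; only "weight_block_size" is used above).
quantization_config = {
    "activation_scheme": "dynamic",
    "fmt": "e4m3",
    "quant_method": "fp8",
    "weight_block_size": [128, 128],
}
block_shape = quantization_config["weight_block_size"]  # -> [128, 128], i.e. [block_n, block_k]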
@@ -381,6 +413,14 @@ def main(args: argparse.Namespace):
 
     if args.tune:
         search_space = get_configs_compute_bound()
+        if block_shape is not None:
+            block_n, block_k = block_shape[0], block_shape[1]
+            search_space = [
+                config
+                for config in search_space
+                if block_n % config["BLOCK_SIZE_N"] == 0
+                and block_k % config["BLOCK_SIZE_K"] == 0
+            ]
         print(f"Start tuning over {len(search_space)} configurations...")
 
         start = time.time()
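The filter keeps only Triton tile sizes that evenly divide the quantization block, so a kernel tile never straddles two quantization blocks (and therefore two scale entries). A small self-contained sketch of the same filter, with made-up candidate configs (the real search space comes from get_configs_compute_bound):

# Hypothetical candidates; the real ones come from get_configs_compute_bound().
search_space = [
    {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64},    # kept: 128 % 32 == 0 and 128 % 64 == 0
    {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128},   # kept
    {"BLOCK_SIZE_N": 96, "BLOCK_SIZE_K": 128},   # dropped: 128 % 96 != 0
    {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256},  # dropped: 128 % 256 != 0
]
block_n, block_k = 128, 128  # weight_block_size from the model config

search_space = [
    config
    for config in search_space
    if block_n % config["BLOCK_SIZE_N"] == 0 and block_k % config["BLOCK_SIZE_K"] == 0
]
print(len(search_space))  # 2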
@@ -396,6 +436,7 @@ def main(args: argparse.Namespace):
                     dtype,
                     use_fp8_w8a8,
                     use_int8_w8a16,
+                    block_shape,
                     search_space,
                 )
                 for batch_size in batch_sizes
@@ -413,6 +454,7 @@ def main(args: argparse.Namespace):
             dtype,
             use_fp8_w8a8,
             use_int8_w8a16,
+            block_shape,
         )
         end = time.time()
         print(f"Tuning took {end - start:.2f} seconds")
@@ -429,6 +471,7 @@ def main(args: argparse.Namespace):
                     dtype,
                     use_fp8_w8a8,
                     use_int8_w8a16,
+                    block_shape,
                 )
                 for batch_size in batch_sizes
             ],