add configs for block fp8 related kernels (#2628)
Co-authored-by: HandH1998 <1335248067@qq.com>
@@ -39,6 +39,7 @@ def benchmark_config(
     dtype: torch.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
+    block_shape: List[int] = None,
     num_iters: int = 100,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
@@ -83,10 +84,23 @@ def benchmark_config(
     )
     w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
     if use_fp8_w8a8:
-        w1_scale = torch.randn(num_experts, dtype=torch.float32)
-        w2_scale = torch.randn(num_experts, dtype=torch.float32)
-        a1_scale = torch.randn(1, dtype=torch.float32)
-        a2_scale = torch.randn(1, dtype=torch.float32)
+        if block_shape is None:
+            w1_scale = torch.randn(num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+            a1_scale = torch.randn(1, dtype=torch.float32)
+            a2_scale = torch.randn(1, dtype=torch.float32)
+        else:
+            block_n, block_k = block_shape[0], block_shape[1]
+            n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
+            n_tiles_w2 = (hidden_size + block_n - 1) // block_n
+            k_tiles_w1 = (hidden_size + block_k - 1) // block_k
+            k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
+            w1_scale = torch.rand(
+                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
+            )
+            w2_scale = torch.rand(
+                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
+            )

         w1 = w1.to(torch.float8_e4m3fn)
         w2 = w2.to(torch.float8_e4m3fn)
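For intuition, the added branch allocates one float32 scale per (block_n x block_k) weight tile, with the tile counts obtained by ceil division, exactly as in the hunk above. A minimal standalone sketch; the sizes below are illustrative placeholders, not values from this diff:

import torch

# Illustrative sizes only (not taken from the diff).
num_experts = 8
hidden_size = 7168
shard_intermediate_size = 4096
block_n, block_k = 128, 128  # weight_block_size from the model config

# Ceil-divide each GEMM dimension by the block size to get the scale grid.
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n  # 32
k_tiles_w1 = (hidden_size + block_k - 1) // block_k              # 56
w1_scale = torch.rand((num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
print(w1_scale.shape)  # torch.Size([8, 32, 56])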
@@ -114,6 +128,7 @@ def benchmark_config(
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            block_shape=block_shape,
         )

     # JIT compilation & warmup
@@ -192,6 +207,7 @@ class BenchmarkWorker:
         dtype: torch.dtype,
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
+        block_shape: List[int],
     ) -> Tuple[Dict[str, int], float]:
         torch.cuda.manual_seed_all(0)
         dtype_str = get_config_dtype_str(
@@ -199,8 +215,10 @@ class BenchmarkWorker:
         )
         # NOTE(woosuk): The current naming convention uses w2.shape[2], which
         # is the intermediate size after silu_and_mul.
+        block_n = block_shape[0] if block_shape else 0
+        block_k = block_shape[1] if block_shape else 0
         op_config = get_moe_configs(
-            num_experts, shard_intermediate_size // 2, dtype_str
+            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
         )
         if op_config is None:
             config = get_default_config(
@@ -223,6 +241,7 @@ class BenchmarkWorker:
             dtype,
             use_fp8_w8a8,
             use_int8_w8a16,
+            block_shape,
         )
         return config, kernel_time

@@ -236,6 +255,7 @@ class BenchmarkWorker:
         dtype: torch.dtype,
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
+        block_shape: List[int],
         search_space: List[Dict[str, int]],
     ) -> Dict[str, int]:
         best_config = None
@@ -252,6 +272,7 @@ class BenchmarkWorker:
                     dtype,
                     use_fp8_w8a8,
                     use_int8_w8a16,
+                    block_shape,
                     num_iters=10,
                 )
             except triton.runtime.autotuner.OutOfResources:
@@ -287,6 +308,7 @@ def save_configs(
     dtype: torch.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
+    block_shape: List[int],
 ) -> None:
     dtype_str = get_config_dtype_str(
         dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
@@ -295,7 +317,10 @@ def save_configs(
     # NOTE(woosuk): The current naming convention uses w2.shape[2], which
     # is the intermediate size after silu_and_mul.
     filename = get_config_file_name(
-        num_experts, shard_intermediate_size // 2, dtype_str
+        num_experts,
+        shard_intermediate_size // 2,
+        dtype_str,
+        block_shape,
     )

     print(f"Writing best config to {filename}...")
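Passing block_shape into get_config_file_name keeps block-quantized and per-tensor-quantized models from sharing the same tuned-config file. A hypothetical illustration of the distinction; the exact naming format is defined by get_config_file_name, not by this diff:

# Hypothetical filenames for illustration only.
per_tensor_cfg = "E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8.json"
block_fp8_cfg = "E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json"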
@@ -323,10 +348,10 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    elif config.architectures[0] == "DeepseekV2ForCausalLM":
+    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
-        intermediate_size = config.intermediate_size
+        intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
         # Default: Mixtral
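Besides recognizing DeepseekV3ForCausalLM, this hunk also switches the Deepseek branch from config.intermediate_size (the dense FFN width) to config.moe_intermediate_size (the per-expert FFN width), which is the dimension the MoE kernel actually works with. A hedged sketch of the relevant fields as exposed by transformers; the values in comments are the published DeepSeek-V3 ones and are shown purely for illustration:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)
print(cfg.architectures[0])       # "DeepseekV3ForCausalLM"
print(cfg.n_routed_experts)       # 256 routed experts
print(cfg.num_experts_per_tok)    # 8
print(cfg.moe_intermediate_size)  # 2048 (per-expert FFN width)
print(cfg.intermediate_size)      # 18432 (dense FFN width)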
@@ -339,6 +364,13 @@ def main(args: argparse.Namespace):
     dtype = config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
+    block_shape = None
+    if (
+        hasattr(config, "quantization_config")
+        and "weight_block_size" in config.quantization_config
+    ):
+        block_shape = config.quantization_config["weight_block_size"]
+        assert len(block_shape) == 2

     if args.batch_size is None:
         batch_sizes = [
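The detection above relies on block-quantized FP8 checkpoints publishing their block size under quantization_config in config.json, which the script reads as a plain dict. As a point of reference (not part of this diff), a DeepSeek-V3-style entry looks roughly like this, shown here as a Python dict:

# Abridged, illustrative excerpt of a block-FP8 checkpoint's config.json.
quantization_config = {
    "quant_method": "fp8",
    "fmt": "e4m3",
    "activation_scheme": "dynamic",
    "weight_block_size": [128, 128],  # picked up as block_shape above
}
assert len(quantization_config["weight_block_size"]) == 2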
@@ -381,6 +413,14 @@ def main(args: argparse.Namespace):

     if args.tune:
         search_space = get_configs_compute_bound()
+        if block_shape is not None:
+            block_n, block_k = block_shape[0], block_shape[1]
+            search_space = [
+                config
+                for config in search_space
+                if block_n % config["BLOCK_SIZE_N"] == 0
+                and block_k % config["BLOCK_SIZE_K"] == 0
+            ]
         print(f"Start tuning over {len(search_space)} configurations...")

         start = time.time()
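The filter keeps only Triton tile sizes that evenly divide the quantization block, presumably so a compute tile never straddles two weight-scale blocks. A small self-contained sketch of the effect; the candidate values are made up for illustration, the real search space comes from get_configs_compute_bound():

# Illustrative candidate tile sizes.
search_space = [
    {"BLOCK_SIZE_N": n, "BLOCK_SIZE_K": k}
    for n in (16, 32, 64, 128, 256)
    for k in (32, 64, 128, 256)
]
block_n, block_k = 128, 128  # weight_block_size

kept = [
    c
    for c in search_space
    if block_n % c["BLOCK_SIZE_N"] == 0 and block_k % c["BLOCK_SIZE_K"] == 0
]
print(f"{len(search_space)} candidates -> {len(kept)} compatible")  # 20 -> 12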
@@ -396,6 +436,7 @@ def main(args: argparse.Namespace):
                     dtype,
                     use_fp8_w8a8,
                     use_int8_w8a16,
+                    block_shape,
                     search_space,
                 )
                 for batch_size in batch_sizes
@@ -413,6 +454,7 @@ def main(args: argparse.Namespace):
             dtype,
             use_fp8_w8a8,
             use_int8_w8a16,
+            block_shape,
         )
         end = time.time()
         print(f"Tuning took {end - start:.2f} seconds")
@@ -429,6 +471,7 @@ def main(args: argparse.Namespace):
                     dtype,
                     use_fp8_w8a8,
                     use_int8_w8a16,
+                    block_shape,
                 )
                 for batch_size in batch_sizes
             ],