From 849f58d617e773d1a1edd49dacb2fac8257240b0 Mon Sep 17 00:00:00 2001
From: GaoYuYang
Date: Sat, 8 Feb 2025 21:58:21 +0800
Subject: [PATCH] Update fused_moe's benchmark (#3346)

---
 ...nchmark_vllm_vs_sglang_fused_moe_triton.py | 97 ++++++++++++++-----
 1 file changed, 75 insertions(+), 22 deletions(-)

diff --git a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
index faf5c6b4e..4edb2dff8 100644
--- a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
+++ b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
@@ -2,6 +2,7 @@ import argparse
 
 import torch
 import triton
+import vllm
 from transformers import AutoConfig
 
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
@@ -29,11 +30,11 @@ def get_model_config(model_name: str, tp_size: int):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
-    elif config.architectures[0] == "DeepseekV2ForCausalLM":
+    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
-        intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        intermediate_size = config.moe_intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // tp_size
     else:
         # Default: Mixtral
         E = config.num_local_experts
@@ -41,12 +42,27 @@ def get_model_config(model_name: str, tp_size: int):
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
 
+    vllm_version_num = (
+        vllm.__version_tuple__[0] * 100
+        + vllm.__version_tuple__[1] * 10
+        + vllm.__version_tuple__[2]
+    )
+    block_shape = None
+    if (
+        hasattr(config, "quantization_config")
+        and "weight_block_size" in config.quantization_config
+    ):
+        block_shape = config.quantization_config["weight_block_size"]
+        assert len(block_shape) == 2
+        assert vllm_version_num >= 66, "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
+
     shape_configs = {
         "num_experts": E,
         "topk": topk,
         "hidden_size": config.hidden_size,
         "shard_intermediate_size": shard_intermediate_size,
         "dtype": config.torch_dtype,
+        "block_shape": block_shape,
     }
     print(f"{shape_configs=}")
     return shape_configs
@@ -63,21 +79,39 @@ def fused_moe_vllm_api(
     w2_scale=None,
     a1_scale=None,
     a2_scale=None,
+    block_shape=None,
 ):
-    return fused_moe_vllm(
-        x,
-        w1,
-        w2,
-        input_gating,
-        topk,
-        renormalize=True,
-        inplace=True,
-        use_fp8_w8a8=use_fp8_w8a8,
-        w1_scale=w1_scale,
-        w2_scale=w2_scale,
-        a1_scale=a1_scale,
-        a2_scale=a2_scale,
-    )
+    if block_shape is not None:
+        return fused_moe_vllm(
+            x,
+            w1,
+            w2,
+            input_gating,
+            topk,
+            renormalize=True,
+            inplace=True,
+            use_fp8_w8a8=use_fp8_w8a8,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            block_shape=block_shape,
+        )
+    else:
+        return fused_moe_vllm(
+            x,
+            w1,
+            w2,
+            input_gating,
+            topk,
+            renormalize=True,
+            inplace=True,
+            use_fp8_w8a8=use_fp8_w8a8,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+        )
 
 
 def fused_moe_sglang_api(
@@ -91,6 +125,7 @@ def fused_moe_sglang_api(
     w2_scale=None,
     a1_scale=None,
     a2_scale=None,
+    block_shape=None,
 ):
     return fused_moe_sglang(
         x,
@@ -105,6 +140,7 @@ def fused_moe_sglang_api(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
+        block_shape=block_shape,
     )
 
 
@@ -141,8 +177,10 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     shard_intermediate_size = model_config["shard_intermediate_size"]
     topk = model_config["topk"]
     dtype = model_config["dtype"]
+    block_shape = getattr(model_config, "block_shape", None)
 
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
+    w1_scale = w2_scale = a1_scale = a2_scale = None
 
     if use_fp8:
         init_dtype = dtype
@@ -154,16 +192,29 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
         )
         w1 = w1.to(torch.float8_e4m3fn)
         w2 = w2.to(torch.float8_e4m3fn)
-        w1_scale = torch.randn(num_experts, dtype=torch.float32)
-        w2_scale = torch.randn(num_experts, dtype=torch.float32)
-        a1_scale = torch.randn(1, dtype=torch.float32)
-        a2_scale = torch.randn(1, dtype=torch.float32)
+
+        if block_shape is None:
+            w1_scale = torch.randn(num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+            a1_scale = torch.randn(1, dtype=torch.float32)
+            a2_scale = torch.randn(1, dtype=torch.float32)
+        else:
+            block_n, block_k = block_shape[0], block_shape[1]
+            n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
+            n_tiles_w2 = (hidden_size + block_n - 1) // block_n
+            k_tiles_w1 = (hidden_size + block_k - 1) // block_k
+            k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
+            w1_scale = torch.rand(
+                (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
+            )
+            w2_scale = torch.rand(
+                (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
+            )
     else:
         w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
         w2 = torch.randn(
             num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
         )
-        w1_scale = w2_scale = a1_scale = a2_scale = None
 
     input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
 
@@ -185,6 +236,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            block_shape=block_shape,
         )
     torch.cuda.synchronize()
 
@@ -201,6 +253,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
             w2_scale=w2_scale,
            a1_scale=a1_scale,
             a2_scale=a2_scale,
+            block_shape=block_shape,
         )[0],
         quantiles=quantiles,
     )
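
For reference, a minimal standalone sketch of the scale-tensor shapes that the block-wise FP8 branch above allocates: one FP32 scale per [block_n, block_k] weight tile, computed with ceil division over the w1/w2 dimensions. The helper name moe_block_scale_shapes and the DeepSeek-V3-style sizes in the example are illustrative assumptions, not part of the patch:

# Illustrative sketch (not from the patch): shapes of the block-wise FP8 scale
# tensors that the benchmark allocates in its else-branch above.
def moe_block_scale_shapes(num_experts, hidden_size, shard_intermediate_size, block_shape):
    block_n, block_k = block_shape

    def ceil_div(a, b):
        return (a + b - 1) // b

    # w1 is [E, 2 * I // tp, H]: one scale per [block_n, block_k] tile per expert.
    w1_scale_shape = (
        num_experts,
        ceil_div(shard_intermediate_size, block_n),
        ceil_div(hidden_size, block_k),
    )
    # w2 is [E, H, I // tp].
    w2_scale_shape = (
        num_experts,
        ceil_div(hidden_size, block_n),
        ceil_div(shard_intermediate_size // 2, block_k),
    )
    return w1_scale_shape, w2_scale_shape


if __name__ == "__main__":
    # Example numbers only, roughly DeepSeek-V3-like with tp_size=8:
    # hidden_size=7168, moe_intermediate_size=2048, weight_block_size=[128, 128].
    shard_intermediate_size = 2 * 2048 // 8  # 512
    print(moe_block_scale_shapes(256, 7168, shard_intermediate_size, [128, 128]))
    # -> ((256, 4, 56), (256, 56, 2))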