From 076103535c933f5ac3505d5c887b8073a9044c38 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Wed, 28 May 2025 15:20:01 +0800 Subject: [PATCH] fix log_info_on_rank0 error when run benchmark (#6260) --- benchmark/kernels/fused_moe_triton/README.md | 13 +++- .../benchmark_torch_compile_fused_moe.py | 18 ++--- ...nchmark_vllm_vs_sglang_fused_moe_triton.py | 72 ++++++++++++++----- 3 files changed, 76 insertions(+), 27 deletions(-) diff --git a/benchmark/kernels/fused_moe_triton/README.md b/benchmark/kernels/fused_moe_triton/README.md index a0a7ca9c8..d9a1d1af3 100644 --- a/benchmark/kernels/fused_moe_triton/README.md +++ b/benchmark/kernels/fused_moe_triton/README.md @@ -58,15 +58,22 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri # Compare with FP8 mode for Qwen2-57B python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \ --model Qwen/Qwen2-57B-A14B-Instruct \ - --use-fp8 + --use-fp8-w8a8 # Compare with custom TP size python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \ - --tp-size 4 + --model deepseek-ai/DeepSeek-V3-0324 \ + --tp-size 8 + +# Compare with custom TP size and n_share_experts_fusion +python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \ + --model deepseek-ai/DeepSeek-V3-0324 \ + --tp-size 8 \ + --n-share-experts-fusion 8 ``` The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`). - `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel. -Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`. +Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`. Note that `torch.compile` does not support the `fp8_w8a8` and `int8_w8a8` fused MoE kernels. 
diff --git a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py index 13c83726c..bf3f80bb5 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py @@ -1,3 +1,4 @@ +# python3 benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8 import argparse import torch @@ -31,11 +32,12 @@ def get_model_config(model_name: str, tp_size: int): intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // tp_size elif config.architectures[0] == "Qwen3MoeForCausalLM": - E = config.num_experts + E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // tp_size elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: + E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size @@ -99,7 +101,7 @@ def fused_moe_torch( a1_scale=None, a2_scale=None, ) -> torch.Tensor: - assert not use_fp8_w8a8, "Not supported" + assert not use_fp8_w8a8, "Fp8_w8a8 fused_moe is not supported for torch compile" topk_weights, topk_ids = fused_topk_native( hidden_states=x, @@ -193,7 +195,7 @@ def fused_moe_sglang_api( args={}, ) ) -def benchmark(batch_size, provider, model_config, use_fp8=False): +def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False): print(f"benchmark {provider} with batch_size={batch_size}") torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) @@ -208,7 +210,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False): x = torch.randn(num_tokens, hidden_size, dtype=dtype) - if use_fp8: + if use_fp8_w8a8: init_dtype = dtype w1 = torch.randn( num_experts, 
shard_intermediate_size, hidden_size, dtype=init_dtype @@ -244,7 +246,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False): w2, input_gating, topk, - use_fp8_w8a8=use_fp8, + use_fp8_w8a8=use_fp8_w8a8, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, @@ -260,7 +262,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False): w2, input_gating, topk, - use_fp8_w8a8=use_fp8, + use_fp8_w8a8=use_fp8_w8a8, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, @@ -277,7 +279,7 @@ def main(): "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" ) parser.add_argument("--tp-size", type=int, default=2) - parser.add_argument("--use-fp8", action="store_true") + parser.add_argument("--use-fp8-w8a8", action="store_true") parser.add_argument( "--save-path", type=str, @@ -291,7 +293,7 @@ def main(): print_data=True, save_path=args.save_path, model_config=model_config, - use_fp8=args.use_fp8, + use_fp8_w8a8=args.use_fp8_w8a8, ) diff --git a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py index 85517e4e7..98ebcd728 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py @@ -1,3 +1,4 @@ +# python3 benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8 import argparse import torch @@ -6,12 +7,18 @@ import vllm from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm +from sglang.srt.distributed.parallel_state import ( + destroy_distributed_environment, + destroy_model_parallel, + init_distributed_environment, + initialize_model_parallel, +) from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( fused_moe as fused_moe_sglang, ) -def 
get_model_config(model_name: str, tp_size: int): +def get_model_config(model_name: str, tp_size: int, n_share_experts_fusion: int = 0): """Get model configuration parameters""" config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) @@ -36,7 +43,12 @@ def get_model_config(model_name: str, tp_size: int): intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // tp_size elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: - E = config.n_routed_experts + n_share_fusion_experts = n_share_experts_fusion + E = ( + config.n_routed_experts + n_share_fusion_experts + if config.architectures[0] in ["DeepseekV3ForCausalLM"] + else config.n_routed_experts + ) topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // tp_size @@ -182,7 +194,7 @@ def fused_moe_sglang_api( args={}, ) ) -def benchmark(batch_size, provider, model_config, use_fp8=False): +def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False): print(f"benchmark {provider} with batch_size={batch_size}") torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) @@ -193,12 +205,12 @@ def benchmark(batch_size, provider, model_config, use_fp8=False): shard_intermediate_size = model_config["shard_intermediate_size"] topk = model_config["topk"] dtype = model_config["dtype"] - block_shape = getattr(model_config, "block_shape", None) + block_shape = model_config["block_shape"] x = torch.randn(num_tokens, hidden_size, dtype=dtype) w1_scale = w2_scale = a1_scale = a2_scale = None - if use_fp8: + if use_fp8_w8a8: init_dtype = dtype w1 = torch.randn( num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype @@ -247,7 +259,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False): w2, input_gating, topk, - use_fp8_w8a8=use_fp8, + use_fp8_w8a8=use_fp8_w8a8, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, @@ -264,7 
+276,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False): w2, input_gating, topk, - use_fp8_w8a8=use_fp8, + use_fp8_w8a8=use_fp8_w8a8, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, @@ -282,7 +294,8 @@ def main(): "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" ) parser.add_argument("--tp-size", type=int, default=2) - parser.add_argument("--use-fp8", action="store_true") + parser.add_argument("--n-share-experts-fusion", type=int, default=0) + parser.add_argument("--use-fp8-w8a8", action="store_true") parser.add_argument( "--save-path", type=str, @@ -290,14 +303,41 @@ def main(): ) args = parser.parse_args() - model_config = get_model_config(args.model, args.tp_size) - benchmark.run( - show_plots=True, - print_data=True, - save_path=args.save_path, - model_config=model_config, - use_fp8=args.use_fp8, - ) + try: + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + backend="nccl" if torch.cuda.is_available() else "gloo", + init_method="tcp://127.0.0.1:23456", + world_size=1, + rank=0, + ) + + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method="tcp://127.0.0.1:23456", + local_rank=0, + backend="nccl" if torch.cuda.is_available() else "gloo", + ) + + initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + ) + + model_config = get_model_config( + args.model, args.tp_size, args.n_share_experts_fusion + ) + benchmark.run( + show_plots=True, + print_data=True, + save_path=args.save_path, + model_config=model_config, + use_fp8_w8a8=args.use_fp8_w8a8, + ) + finally: + destroy_model_parallel() + destroy_distributed_environment() if __name__ == "__main__":