From 7130a7cea9cec6161d458c2eea718ea4ce2c1324 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Wed, 12 Mar 2025 13:48:38 +0800 Subject: [PATCH] refine sgl_moe_align_block_size_benchmark (#4327) --- .../benchmark/bench_moe_align_block_size.py | 107 +++++++++++++----- 1 file changed, 77 insertions(+), 30 deletions(-) rename benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py => sgl-kernel/benchmark/bench_moe_align_block_size.py (73%) diff --git a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py b/sgl-kernel/benchmark/bench_moe_align_block_size.py similarity index 73% rename from benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py rename to sgl-kernel/benchmark/bench_moe_align_block_size.py index 1d9504d0a..1fb0bd342 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py +++ b/sgl-kernel/benchmark/bench_moe_align_block_size.py @@ -4,7 +4,8 @@ import itertools import torch import triton import triton.language as tl -from sgl_kernel import moe_align_block_size +from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size +from vllm import _custom_ops as ops USE_RANDOM_PERM = False @@ -139,15 +140,11 @@ def moe_align_block_size_triton( ) -def calculate_diff(batch_size, seq_len): - num_experts = 256 - block_size = 128 - topk = 8 - +def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8): topk_ids = torch.stack( [ torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] - for _ in range(batch_size * seq_len) + for _ in range(num_tokens) ] ) @@ -175,8 +172,13 @@ def calculate_diff(batch_size, seq_len): expert_ids_triton = torch.zeros_like(expert_ids_cuda) num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda) - # compare the performance of cuda and triton implementation - moe_align_block_size( + sorted_ids_vllm = torch.empty_like(sorted_ids_cuda) + sorted_ids_vllm.fill_(topk_ids.numel()) + expert_ids_vllm = torch.zeros_like(expert_ids_cuda) + num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_cuda) + + # compare the performance of cuda, triton and vllm implementation + sgl_moe_align_block_size( topk_ids, num_experts, block_size, @@ -194,22 +196,43 @@ def calculate_diff(batch_size, seq_len): expert_ids_triton, num_tokens_post_pad_triton, ) + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids_vllm, + expert_ids_vllm, + num_tokens_post_pad_vllm, + ) if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose( num_tokens_post_pad_cuda, num_tokens_post_pad_triton ): - print("✅ CUDA and Triton implementations match") + print("✅ SGL and Triton implementations match") else: - print("❌ CUDA and Triton implementations do not match") - print("CUDA expert_ids:", expert_ids_cuda) + print("❌ SGL and Triton implementations do not match") + print("SGL expert_ids:", expert_ids_cuda) print("Triton expert_ids:", expert_ids_triton) - print("CUDA num_tokens_post_pad:", num_tokens_post_pad_cuda) + print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda) print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton) + if torch.allclose(expert_ids_cuda, expert_ids_vllm) and torch.allclose( + num_tokens_post_pad_cuda, num_tokens_post_pad_vllm + ): + print("✅ SGL and VLLM implementations match") + else: + print("❌ SGL and VLLM implementations do not match") + print("SGL expert_ids:", expert_ids_cuda) + print("VLLM expert_ids:", expert_ids_vllm) + print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda) + print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm) -batch_size_range = [2**i for i in range(0, 8)] -seq_length_range = [2**i for i in range(0, 16)] -configs = list(itertools.product(batch_size_range, seq_length_range)) + +num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] +num_experts_range = [32, 64, 128, 256] +topk_range = [2, 4, 8] + +configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range)) def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: @@ -223,29 +246,27 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: @triton.testing.perf_report( triton.testing.Benchmark( - x_names=["batch_size", "seq_len"], - x_vals=[list(_) for _ in configs], + x_names=["num_tokens", "num_experts", "topk"], + x_vals=configs, line_arg="provider", - line_vals=["cuda", "triton"], - line_names=["CUDA", "Triton"], - styles=[("blue", "-"), ("red", "-")], + line_vals=["sgl", "triton", "vllm"], + line_names=["SGL", "Triton", "VLLM"], + styles=[("blue", "-"), ("red", "-"), ("green", "-")], ylabel="us", plot_name="moe-align-block-size-performance", args={}, ) ) -def benchmark(batch_size, seq_len, provider): - num_experts = 256 +def benchmark(num_tokens, num_experts, topk, provider): block_size = 128 - topk = 8 if USE_RANDOM_PERM: - topk_ids = get_topk_ids(batch_size * seq_len, num_experts, topk) + topk_ids = get_topk_ids(num_tokens, num_experts, topk) else: topk_ids = torch.randint( 0, num_experts, - (batch_size * seq_len, topk), + (num_tokens, topk), dtype=torch.int32, device="cuda", ) @@ -268,9 +289,9 @@ def benchmark(batch_size, seq_len, provider): ) quantiles = [0.5, 0.2, 0.8] - if provider == "cuda": + if provider == "sgl": ms, min_ms, max_ms = triton.testing.do_bench( - lambda: moe_align_block_size( + lambda: sgl_moe_align_block_size( topk_ids, num_experts, block_size, @@ -282,7 +303,7 @@ def benchmark(batch_size, seq_len, provider): ), quantiles=quantiles, ) - else: + elif provider == "triton": ms, min_ms, max_ms = triton.testing.do_bench( lambda: moe_align_block_size_triton( topk_ids, @@ -294,6 +315,18 @@ def benchmark(batch_size, seq_len, provider): ), quantiles=quantiles, ) + else: # vllm + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids.clone(), + expert_ids.clone(), + num_tokens_post_pad.clone(), + ), + quantiles=quantiles, + ) return 1000 * ms, 1000 * max_ms, 1000 * min_ms @@ -306,8 +339,22 @@ if __name__ == "__main__": default="./configs/benchmark_ops/moe_align_blocks/", help="Path to save moe align benchmark results", ) + parser.add_argument( + "--num_experts", + type=int, + default=256, + choices=[8, 64, 128, 256], + help="Number of experts for benchmark", + ) + parser.add_argument( + "--topk", + type=int, + default=8, + choices=[2, 4, 8], + help="Top-k value for benchmark", + ) args = parser.parse_args() - calculate_diff(batch_size=4, seq_len=1024) + calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk) benchmark.run(print_data=True)