From 10bfce71b35300b61cb9016a544eb79d61352f77 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team <89890040+yiakwy-xpu-ml-framework-team@users.noreply.github.com> Date: Mon, 20 Jan 2025 19:33:29 +0800 Subject: [PATCH] fix moe align blocks benchmark (#3003) --- .../benchmark_deepseekv3_moe_align_blocks.py | 36 +++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py index 0a6049a12..d00f4985a 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py @@ -7,6 +7,8 @@ import triton import triton.language as tl from sgl_kernel import moe_align_block_size +USE_RANDOM_PERM = False + def ceil_div(a, b): return (a + b - 1) // b @@ -141,8 +143,13 @@ def moe_align_block_size_triton( def calculate_diff(batch_size, seq_len): num_experts = 256 block_size = 128 - topk_ids = torch.randint( - 0, num_experts, (batch_size, seq_len), dtype=torch.int32, device="cuda" + topk = 8 + + topk_ids = torch.stack( + [ + torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] + for _ in range(batch_size * seq_len) + ] ) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) @@ -169,7 +176,7 @@ def calculate_diff(batch_size, seq_len): expert_ids_triton = torch.empty_like(expert_ids_cuda) num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda) - # 运行两个实现 + # compare the performance of cuda and triton implementation moe_align_block_size( topk_ids, num_experts, @@ -206,6 +213,15 @@ seq_length_range = [2**i for i in range(0, 16)] configs = list(itertools.product(batch_size_range, seq_length_range)) +def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: + topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda") + for i in range(num_tokens): + topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="cuda")[ + :topk + ] + return topk_ids + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size", "seq_len"], @@ -223,9 +239,17 @@ def benchmark(batch_size, seq_len, provider): num_experts = 256 block_size = 128 topk = 8 - topk_ids = torch.randint( - 0, num_experts, (batch_size * seq_len, topk), dtype=torch.int32, device="cuda" - ) + + if USE_RANDOM_PERM: + topk_ids = get_topk_ids(batch_size * seq_len, num_experts, topk) + else: + topk_ids = torch.randint( + 0, + num_experts, + (batch_size * seq_len, topk), + dtype=torch.int32, + device="cuda", + ) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) sorted_ids = torch.empty(