diff --git a/benchmark/kernels/fused_moe_triton/benchmark_ep_pre_reorder_triton.py b/benchmark/kernels/fused_moe_triton/benchmark_ep_pre_reorder_triton.py new file mode 100644 index 000000000..c62424357 --- /dev/null +++ b/benchmark/kernels/fused_moe_triton/benchmark_ep_pre_reorder_triton.py @@ -0,0 +1,100 @@ +import argparse +import itertools + +import pandas as pd +import torch +import triton + +from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel + + +def benchmark_pre_reorder(batch_size, topk, model_config): + hidden_size = model_config["hidden_size"] + block_size = model_config["block_size"] + expert_range = model_config["expert_range"] + + input_ptr = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda") + gateup_input_ptr = torch.zeros( + batch_size * topk, hidden_size, dtype=torch.float16, device="cuda" + ) + src2dst_ptr = torch.randint( + 0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device="cuda" + ) + topk_ids_ptr = torch.randint( + expert_range[0], + expert_range[1] + 1, + (batch_size, topk), + dtype=torch.int32, + device="cuda", + ) + a1_scales_ptr = torch.rand( + expert_range[1] - expert_range[0] + 1, dtype=torch.float32, device="cuda" + ) + + input_ptr = input_ptr.view(-1) + gateup_input_ptr = gateup_input_ptr.view(-1) + src2dst_ptr = src2dst_ptr.view(-1) + topk_ids_ptr = topk_ids_ptr.view(-1) + + def run_kernel(): + pre_reorder_triton_kernel[(batch_size,)]( + input_ptr, + gateup_input_ptr, + src2dst_ptr, + topk_ids_ptr, + a1_scales_ptr, + expert_range[0], + expert_range[1], + topk, + hidden_size, + block_size, + ) + + for _ in range(10): + run_kernel() + torch.cuda.synchronize() + + ms, _, _ = triton.testing.do_bench(run_kernel, quantiles=[0.5, 0.2, 0.8]) + return ms + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--hidden-size", type=int, required=True) + parser.add_argument("--block-size", type=int, default=512) + args = parser.parse_args() + + model_config = { + "hidden_size": args.hidden_size, + "block_size": args.block_size, + "expert_range": (0, 255), + } + + batch_sizes = [64, 128, 256, 512, 640, 768, 1024] + topks = [2, 4, 8] + configs = list(itertools.product(batch_sizes, topks)) + + # Prepare results dict: keys = topk, each row is indexed by batch_size + results_dict = {topk: {} for topk in topks} + + for batch_size, topk in configs: + ms = benchmark_pre_reorder(batch_size, topk, model_config) + results_dict[topk][batch_size] = ms + + # Build dataframe + df = pd.DataFrame( + { + "batch_size": batch_sizes, + **{ + f"TopK={topk}": [results_dict[topk].get(bs, None) for bs in batch_sizes] + for topk in topks + }, + } + ) + + print("\npre-reorder-performance:") + print(df.to_string(index=False, float_format="%.6f")) + + +if __name__ == "__main__": + main() diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 8c005527a..56c6c7db7 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -184,8 +184,10 @@ def pre_reorder_triton_kernel( src_idx = tl.program_id(0) src2dst_ptr = src2dst_ptr + src_idx * topk topk_ids_ptr = topk_ids_ptr + src_idx * topk - src_ptr = input_ptr + src_idx * hidden_size + + vec = tl.arange(0, BLOCK_SIZE) + for idx in range(topk): expert_id = tl.load(topk_ids_ptr + idx) if expert_id >= start_expert_id and expert_id <= end_expert_id: @@ -197,7 +199,7 @@ def pre_reorder_triton_kernel( dst_idx = tl.load(src2dst_ptr + idx) dst_ptr = gateup_input_ptr + dst_idx * hidden_size for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): - offset = start_offset + tl.arange(0, BLOCK_SIZE) + offset = start_offset + vec mask = offset < hidden_size in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32) out_data = (in_data * scale).to(OutDtype) @@ -481,8 +483,11 @@ def post_reorder_triton_kernel( computed = False store_ptr = output_ptr + src_idx * hidden_size + + vec = tl.arange(0, BLOCK_SIZE) + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): - offset = start_offset + tl.arange(0, BLOCK_SIZE) + offset = start_offset + vec mask = offset < hidden_size sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype) @@ -499,7 +504,7 @@ def post_reorder_triton_kernel( if computed == False: for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): - offset = start_offset + tl.arange(0, BLOCK_SIZE) + offset = start_offset + vec mask = offset < hidden_size tl.store( store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask