diff --git a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py
new file mode 100644
index 000000000..1f54f9f9f
--- /dev/null
+++ b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py
@@ -0,0 +1,275 @@
+import argparse
+
+import torch
+import triton
+from torch.nn import functional as F
+from transformers import AutoConfig
+
+from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
+
+
+def get_model_config(model_name: str, tp_size: int):
+    """Get model configuration parameters"""
+    config = AutoConfig.from_pretrained(model_name)
+
+    if config.architectures[0] == "DbrxForCausalLM":
+        E = config.ffn_config.moe_num_experts
+        topk = config.ffn_config.moe_top_k
+        intermediate_size = config.ffn_config.ffn_hidden_size
+        shard_intermediate_size = 2 * intermediate_size // tp_size
+    elif config.architectures[0] == "JambaForCausalLM":
+        E = config.num_experts
+        topk = config.num_experts_per_tok
+        intermediate_size = config.intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // tp_size
+    elif config.architectures[0] == "Qwen2MoeForCausalLM":
+        E = config.num_experts
+        topk = config.num_experts_per_tok
+        intermediate_size = config.moe_intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // tp_size
+    else:
+        # Default: Mixtral
+        E = config.num_local_experts
+        topk = config.num_experts_per_tok
+        intermediate_size = config.intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // tp_size
+
+    shape_configs = {
+        "num_experts": E,
+        "topk": topk,
+        "hidden_size": config.hidden_size,
+        "shard_intermediate_size": shard_intermediate_size,
+        "dtype": config.torch_dtype,
+    }
+    print(f"{shape_configs=}")
+    return shape_configs
+
+
+def fused_topk_native(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    M, _ = hidden_states.shape
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    topk_weights = F.softmax(gating_output.float(), dim=-1)
+    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+@torch.compile
+def fused_moe_torch(
+    x,
+    w1,
+    w2,
+    input_gating,
+    topk,
+    use_fp8_w8a8=False,
+    w1_scale=None,
+    w2_scale=None,
+    a1_scale=None,
+    a2_scale=None,
+) -> torch.Tensor:
+    assert not use_fp8_w8a8, "Not supported"
+
+    topk_weights, topk_ids = fused_topk_native(
+        hidden_states=x,
+        gating_output=input_gating,
+        topk=topk,
+        renormalize=True,
+    )
+    w13_weights = w1[topk_ids]
+    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
+    w2_weights = w2[topk_ids]
+    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
+    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
+    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
+
+
+def fused_moe_torch_compile(
+    x,
+    w1,
+    w2,
+    input_gating,
+    topk,
+    use_fp8_w8a8=False,
+    w1_scale=None,
+    w2_scale=None,
+    a1_scale=None,
+    a2_scale=None,
+):
+    return fused_moe_torch(
+        x,
+        w1,
+        w2,
+        input_gating,
+        topk,
+        use_fp8_w8a8=use_fp8_w8a8,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+    )
+
+
+def fused_moe_sglang_api(
+    x,
+    w1,
+    w2,
+    input_gating,
+    topk,
+    use_fp8_w8a8=False,
+    w1_scale=None,
+    w2_scale=None,
+    a1_scale=None,
+    a2_scale=None,
+):
+    return fused_moe_triton(
+        x,
+        w1,
+        w2,
+        input_gating,
+        topk,
+        renormalize=True,
+        inplace=True,
+        use_fp8_w8a8=use_fp8_w8a8,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+    )
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["batch_size"],
+        x_vals=list(range(1, 5)),
+        line_arg="provider",
+        line_vals=[
+            "fused_moe_triton",
+            "fused_moe_torch_compile",
+        ],
+        line_names=[
+            "fused_moe_triton",
+            "fused_moe_torch_compile",
+        ],
+        styles=[
+            ("blue", "-"),
+            ("green", "-"),
+        ],
+        ylabel="Time (ms)",
+        plot_name="fused-moe-performance",
+        args={},
+    )
+)
+def benchmark(batch_size, provider, model_config, use_fp8=False):
+    print(f"benchmark {provider} with batch_size={batch_size}")
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(0)
+
+    num_tokens = batch_size
+    num_experts = model_config["num_experts"]
+    hidden_size = model_config["hidden_size"]
+    shard_intermediate_size = model_config["shard_intermediate_size"]
+    topk = model_config["topk"]
+    dtype = model_config["dtype"]
+
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
+
+    if use_fp8:
+        init_dtype = dtype
+        w1 = torch.randn(
+            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
+        )
+        w2 = torch.randn(
+            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
+        )
+        w1 = w1.to(torch.float8_e4m3fn)
+        w2 = w2.to(torch.float8_e4m3fn)
+        w1_scale = torch.randn(num_experts, dtype=torch.float32)
+        w2_scale = torch.randn(num_experts, dtype=torch.float32)
+        a1_scale = torch.randn(1, dtype=torch.float32)
+        a2_scale = torch.randn(1, dtype=torch.float32)
+    else:
+        w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
+        w2 = torch.randn(
+            num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
+        )
+        w1_scale = w2_scale = a1_scale = a2_scale = None
+
+    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
+
+    # Warmup
+    api_func = (
+        fused_moe_torch_compile
+        if provider == "fused_moe_torch_compile"
+        else fused_moe_sglang_api
+    )
+    for _ in range(10):
+        y = api_func(
+            x,
+            w1,
+            w2,
+            input_gating,
+            topk,
+            use_fp8_w8a8=use_fp8,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+        )
+    torch.cuda.synchronize()
+
+    quantiles = [0.5, 0.2, 0.8]
+    ms, min_ms, max_ms = triton.testing.do_bench(
+        lambda: api_func(
+            x,
+            w1,
+            w2,
+            input_gating,
+            topk,
+            use_fp8_w8a8=use_fp8,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+        )[0],
+        quantiles=quantiles,
+    )
+    return ms, min_ms, max_ms
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
+    )
+    parser.add_argument("--tp-size", type=int, default=2)
+    parser.add_argument("--use-fp8", action="store_true")
+    parser.add_argument(
+        "--save-path",
+        type=str,
+        default="./configs/benchmark_ops/fused_moe_torch_compile/",
+    )
+    args = parser.parse_args()
+
+    model_config = get_model_config(args.model, args.tp_size)
+    benchmark.run(
+        show_plots=True,
+        print_data=True,
+        save_path=args.save_path,
+        model_config=model_config,
+        use_fp8=args.use_fp8,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
index 33b85f40e..7bfb2731b 100644
--- a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
+++ b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
@@ -1,22 +1,11 @@
 import argparse
-import numbers
-from typing import Optional
 
 import torch
 import triton
-from torch.nn import init
-from torch.nn.parameter import Parameter
 from transformers import AutoConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    get_moe_configs as get_moe_configs_vllm,
-)
-from vllm.utils import FlexibleArgumentParser
 
 from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
-from sglang.srt.layers.fused_moe_triton.fused_moe import (
-    get_moe_configs as get_moe_configs_sglang,
-)
 
 
 def get_model_config(model_name: str, tp_size: int):
@@ -39,19 +28,21 @@ def get_model_config(model_name: str, tp_size: int):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
     else:
-        # Default: Mixtral, Grok1, etc.
+        # Default: Mixtral
         E = config.num_local_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // tp_size
 
-    return {
+    shape_configs = {
         "num_experts": E,
         "topk": topk,
         "hidden_size": config.hidden_size,
         "shard_intermediate_size": shard_intermediate_size,
         "dtype": config.torch_dtype,
     }
+    print(f"{shape_configs=}")
+    return shape_configs
 
 
 def fused_moe_vllm_api(
@@ -133,7 +124,7 @@ def fused_moe_sglang_api(
     )
 )
 def benchmark(batch_size, provider, model_config, use_fp8=False):
-    print(f"benchmark for batch_size={batch_size}")
+    print(f"benchmark {provider} with batch_size={batch_size}")
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
 
@@ -210,7 +201,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
 
 
 def main():
-    parser = FlexibleArgumentParser()
+    parser = argparse.ArgumentParser()
     parser.add_argument(
         "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
     )
diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
index 9b232264a..6f6a57be1 100644
--- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
+++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
 import argparse
+import json
 import time
 from datetime import datetime
 from typing import Any, Dict, List, Tuple, TypedDict
@@ -9,10 +10,14 @@
 import torch
 import triton
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig
-from vllm.platforms import current_platform
-from vllm.utils import FlexibleArgumentParser
-from sglang.srt.layers.fused_moe_triton.fused_moe import *
+from sglang.srt.layers.fused_moe_triton.fused_moe import (
+    fused_moe,
+    get_config_dtype_str,
+    get_config_file_name,
+    get_default_config,
+    get_moe_configs,
+)
 
 
 class BenchmarkConfig(TypedDict):
@@ -92,7 +97,7 @@ def benchmark_config(
         input_gating.copy_(gating_output[i])
 
     def run():
-        from sglang.srt.layers.fused_moe_triton.fused_moe import override_config
+        from sglang.srt.layers.fused_moe_triton import override_config
 
         with override_config(config):
             fused_moe(
@@ -174,7 +179,7 @@ class BenchmarkWorker:
 
     def __init__(self, seed: int) -> None:
         torch.set_default_device("cuda")
-        current_platform.seed_everything(seed)
+        torch.cuda.manual_seed_all(seed)
        self.seed = seed
 
     def benchmark(
@@ -188,7 +193,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
     ) -> Tuple[Dict[str, int], float]:
-        current_platform.seed_everything(self.seed)
+        torch.cuda.manual_seed_all(self.seed)
         dtype_str = get_config_dtype_str(
             dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
         )
@@ -319,7 +324,7 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
-        # Default: Mixtral.
+        # Default: Mixtral
         E = config.num_local_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
@@ -430,7 +435,7 @@ def main(args: argparse.Namespace):
 
 
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser()
+    parser = argparse.ArgumentParser()
     parser.add_argument(
         "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
     )
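
A minimal companion smoke test for the new benchmark (a sketch, not part of the patch): it runs the torch.compile reference and the SGLang Triton kernel on the same small random inputs and prints their largest elementwise difference. The shapes, the scaling of the weights, and the assumption that `benchmark_torch_compile_fused_moe.py` is importable from the working directory with a CUDA device available are illustrative choices, not taken from the diff.

```python
# Hypothetical sanity check (assumes a CUDA device and that
# benchmark_torch_compile_fused_moe.py is on the import path).
import torch

from benchmark_torch_compile_fused_moe import (
    fused_moe_sglang_api,
    fused_moe_torch_compile,
)

torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)

num_tokens, num_experts, hidden_size, intermediate_size, topk = 4, 8, 128, 256, 2
dtype = torch.bfloat16

x = torch.randn(num_tokens, hidden_size, dtype=dtype)
# w1 packs the gate and up projections, matching the fused kernel's layout.
w1 = torch.randn(num_experts, 2 * intermediate_size, hidden_size, dtype=dtype) / 10
w2 = torch.randn(num_experts, hidden_size, intermediate_size, dtype=dtype) / 10
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

ref = fused_moe_torch_compile(x, w1, w2, input_gating, topk)
# fused_moe_sglang_api calls the Triton kernel with inplace=True, so pass a clone.
out = fused_moe_sglang_api(x.clone(), w1, w2, input_gating, topk)

# The two paths compute the same SwiGLU MoE, so the gap should be small
# (bf16 rounding only); print it rather than assert a fixed tolerance.
print("max abs diff:", (out - ref).abs().max().item())
```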