From 253454de9b53c8797076f0ad9e9118c0c9186a56 Mon Sep 17 00:00:00 2001 From: Yuan Luo Date: Mon, 7 Jul 2025 11:05:49 +0800 Subject: [PATCH] Integrate triton moe kernel (#7689) Co-authored-by: luoyuan.luo --- .../benchmark_sglang_fused_moe_triton.py | 271 ++++++++++++++++++ .../layers/moe/fused_moe_triton/fused_moe.py | 2 + .../srt/layers/moe/fused_moe_triton/layer.py | 149 ++++++---- .../fused_moe_triton/triton_kernels_moe.py | 176 ++++++++++++ python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/server_args.py | 6 + test/srt/test_triton_fused_moe.py | 146 ++++++++++ 7 files changed, 697 insertions(+), 54 deletions(-) create mode 100644 benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py create mode 100644 test/srt/test_triton_fused_moe.py diff --git a/benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py new file mode 100644 index 000000000..c392f8e77 --- /dev/null +++ b/benchmark/kernels/fused_moe_triton/benchmark_sglang_fused_moe_triton.py @@ -0,0 +1,271 @@ +# python3 benchmark/kernels/fused_moe_triton/sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8 +import argparse + +import torch +import triton +from transformers import AutoConfig + +from sglang.srt.distributed.parallel_state import ( + destroy_distributed_environment, + destroy_model_parallel, + init_distributed_environment, + initialize_model_parallel, +) +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + fused_moe as fused_moe_sglang, +) +from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import ( + triton_kernel_moe_forward, +) + + +def get_model_config(model_name: str, tp_size: int): + """Get model configuration parameters""" + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + + if config.architectures[0] == "Qwen2MoeForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size + elif config.architectures[0] == "Qwen3MoeForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size + elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: + E = ( + config.n_routed_experts + 1 + if config.architectures[0] in ["DeepseekV3ForCausalLM"] + else config.n_routed_experts + ) + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size + else: + # Default: Mixtral + E = config.num_local_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // tp_size + + block_shape = None + if ( + hasattr(config, "quantization_config") + and "weight_block_size" in config.quantization_config + ): + block_shape = config.quantization_config["weight_block_size"] + assert len(block_shape) == 2 + + shape_configs = { + "num_experts": E, + "topk": topk, + "hidden_size": config.hidden_size, + "shard_intermediate_size": shard_intermediate_size, + "dtype": config.torch_dtype, + "block_shape": block_shape, + } + print(f"{shape_configs=}") + return shape_configs + + +def fused_moe_triton_api( + x, + w1, + w2, + input_gating, + topk, +): + return triton_kernel_moe_forward( + x, + w1, + w2, + input_gating, + topk, + renormalize=False, + ) + + +def fused_moe_sglang_api( + x, + w1, + w2, + input_gating, + topk, + use_fp8_w8a8=False, + w1_scale=None, + w2_scale=None, + a1_scale=None, + a2_scale=None, + block_shape=None, +): + return fused_moe_sglang( + x, + w1, + w2, + input_gating, + topk, + renormalize=False, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + ) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=list([128, 256, 512, 1024, 2048, 4096, 8192]), + line_arg="provider", + line_vals=[ + "sglang_fused_moe_triton_v340", + "sglang_fused_moe_triton", + ], + line_names=[ + "sglang_fused_moe_triton_v340", + "sglang_fused_moe_triton", + ], + styles=[ + ("blue", "-"), + ("green", "-"), + ], + ylabel="Time (ms)", + plot_name="fused-moe-performance", + args={}, + ) +) +def benchmark( + batch_size, + provider, + model_config, + use_fp8_w8a8=False, + use_cuda_graph: bool = False, +): + print(f"benchmark {provider} with batch_size={batch_size}") + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + + num_tokens = batch_size + num_experts = model_config["num_experts"] + hidden_size = model_config["hidden_size"] + shard_intermediate_size = model_config["shard_intermediate_size"] + topk = model_config["topk"] + dtype = model_config["dtype"] + block_shape = model_config["block_shape"] + + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + + w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype) + w2 = torch.randn( + num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype + ) + + w1_tri = w1.clone() + w2_tri = w2.clone() + w1_tri = w1_tri.transpose(-2, -1).contiguous() + w2_tri = w2_tri.transpose(-2, -1).contiguous() + + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + + if provider == "sglang_fused_moe_triton_v340": + api_func = fused_moe_triton_api + api_kwargs = { + "x": x, + "w1": w1_tri, + "w2": w2_tri, + "input_gating": input_gating, + "topk": topk, + } + else: + api_func = fused_moe_sglang_api + api_kwargs = { + "x": x, + "w1": w1, + "w2": w2, + "input_gating": input_gating, + "topk": topk, + "use_fp8_w8a8": use_fp8_w8a8, + "block_shape": block_shape, + } + + # Warmup + for _ in range(10): + _ = api_func(**api_kwargs) + torch.cuda.synchronize() + + if use_cuda_graph: + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + api_func(**api_kwargs) + torch.cuda.synchronize() + + bench_lambda = lambda: graph.replay() + else: + bench_lambda = lambda: api_func(**api_kwargs) + + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench(bench_lambda, quantiles=quantiles) + return ms, min_ms, max_ms + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + parser.add_argument("--tp-size", type=int, default=2) + parser.add_argument("--use-fp8-w8a8", action="store_true") + parser.add_argument( + "--use-cuda-graph", action="store_true", help="Enable CUDA Graph capture/replay" + ) + parser.add_argument( + "--save-path", + type=str, + default="./configs/benchmark_ops/sglang_fused_moe/", + ) + parser.add_argument("--trust-remote-code", action="store_true") + args = parser.parse_args() + + try: + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + backend="nccl" if torch.cuda.is_available() else "gloo", + init_method="tcp://127.0.0.1:23456", + world_size=1, + rank=0, + ) + + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method="tcp://127.0.0.1:23456", + local_rank=0, + backend="nccl" if torch.cuda.is_available() else "gloo", + ) + + initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + ) + + model_config = get_model_config(args.model, args.tp_size) + benchmark.run( + show_plots=True, + print_data=True, + save_path=args.save_path, + model_config=model_config, + use_fp8_w8a8=args.use_fp8_w8a8, + use_cuda_graph=args.use_cuda_graph, + ) + finally: + destroy_model_parallel() + destroy_distributed_environment() + + +if __name__ == "__main__": + main() diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index e6deeeae7..baf8f5c87 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -1737,6 +1737,7 @@ def fused_moe( renormalize: bool, inplace: bool = False, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, num_fused_shared_experts: int = 0, @@ -1822,6 +1823,7 @@ def fused_moe( topk_ids, inplace=inplace, activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 997297be6..5445b4f23 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1,5 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py +import importlib from abc import abstractmethod from enum import Enum from typing import Callable, List, Optional, Tuple @@ -19,6 +20,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight from sglang.srt.utils import ( cpu_has_amx_support, @@ -29,8 +31,15 @@ from sglang.srt.utils import ( use_intel_amx_backend, ) +has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None + if torch.cuda.is_available(): from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + + if has_triton_kernels: + from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import ( + triton_kernel_moe_forward, + ) else: fused_experts = None # type: ignore @@ -87,6 +96,10 @@ class FusedMoEMethodBase(QuantizeMethodBase): class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" + def __init__(self, use_triton_kernels: bool = False): + super().__init__() + self.use_triton_kernels = use_triton_kernels + def create_weights( self, layer: torch.nn.Module, @@ -97,20 +110,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): **extra_weight_attrs, ): # Fused gate_up_proj (column parallel) + w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size + if self.use_triton_kernels: + w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype - ), + torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) # down_proj (row parallel) + w2_weight_n, w2_weight_k = ( + hidden_size, + intermediate_size, + ) + if self.use_triton_kernels: + w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, hidden_size, intermediate_size, dtype=params_dtype - ), + torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) @@ -192,59 +210,72 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): no_combine: bool = False, routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) - if _use_aiter: - assert not no_combine, "unsupported" - if apply_router_weight_on_input: - assert ( - topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - x = x * topk_weights.to(x.dtype) - topk_weights = torch.ones_like( - topk_weights, dtype=torch.float32 - ) # topk_weights must be FP32 (float32) - - return fused_moe( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=( - ActivationType.Silu if activation == "silu" else ActivationType.Gelu - ), - ) - else: - return fused_experts( + if self.use_triton_kernels: + return triton_kernel_moe_forward( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=inplace and not no_combine, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - no_combine=no_combine, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + else: + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + num_fused_shared_experts=num_fused_shared_experts, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, routed_scaling_factor=routed_scaling_factor, ) + if _use_aiter: + assert not no_combine, "unsupported" + if apply_router_weight_on_input: + assert ( + topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + x = x * topk_weights.to(x.dtype) + topk_weights = torch.ones_like( + topk_weights, dtype=torch.float32 + ) # topk_weights must be FP32 (float32) + + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=( + ActivationType.Silu + if activation == "silu" + else ActivationType.Gelu + ), + ) + else: + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace and not no_combine, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + no_combine=no_combine, + routed_scaling_factor=routed_scaling_factor, + ) + def forward_cpu( self, layer: torch.nn.Module, @@ -475,9 +506,13 @@ class FusedMoE(torch.nn.Module): self.inplace = inplace self.no_combine = no_combine + self.use_triton_kernels = ( + not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"] + ) + if quant_config is None: - self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod() + self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod( + self.use_triton_kernels ) else: self.quant_method = quant_config.get_quant_method(self, prefix) @@ -597,6 +632,8 @@ class FusedMoE(torch.nn.Module): ) else: if not self.use_presharded_weights: + if self.use_triton_kernels: + loaded_weight = loaded_weight.transpose(-2, -1) loaded_weight = loaded_weight.narrow( shard_dim, shard_size * tp_rank, shard_size ) @@ -630,6 +667,8 @@ class FusedMoE(torch.nn.Module): ) else: if not self.use_presharded_weights: + if self.use_triton_kernels: + loaded_weight = loaded_weight.transpose(-2, -1) loaded_weight = loaded_weight.narrow( shard_dim, shard_size * tp_rank, shard_size ) @@ -716,6 +755,8 @@ class FusedMoE(torch.nn.Module): # should be whatever dimension intermediate_size is is_transposed = getattr(param, "is_transposed", False) shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] + if self.use_triton_kernels: + is_transposed = True if is_transposed: shard_dim = int(not shard_dim) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py new file mode 100644 index 000000000..57b7f20f0 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py @@ -0,0 +1,176 @@ +# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035 +from typing import Optional + +import torch +from sgl_kernel import gelu_and_mul, silu_and_mul +from triton_kernels.matmul_ogs import matmul_ogs +from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing + +from sglang.srt.utils import direct_register_custom_op + + +def triton_kernel_moe_forward( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + + if not renormalize: + gating_output = torch.softmax(gating_output, dim=-1) + routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize) + + return triton_kernel_fused_experts( + hidden_states, + w1, + w2, + routing_data, + gather_idx, + scatter_idx, + inplace=inplace, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=use_fp8_w8a8, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + ) + + +# This is a triton implementation of the fused_experts function +def triton_kernel_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + routing_data: RoutingData, + gather_indx: GatherIndx, + scatter_indx: ScatterIndx, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + + assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported" + assert per_channel_quant == False, "per_channel_quant is not supported" + assert expert_map == None, "expert_map is not supported" + assert w1_scale == None, "w1_scale is not supported" + assert w2_scale == None, "w2_scale is not supported" + assert a1_scale == None, "a1_scale is not supported" + assert a2_scale == None, "a2_scale is not supported" + assert block_shape == None, "block_shape is not supported" + + # type check + assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16" + assert w1.dtype == torch.bfloat16, "w1 must be bfloat16" + assert w2.dtype == torch.bfloat16, "w2 must be bfloat16" + + # Shape check + assert hidden_states.ndim == 2, "hidden_states must be 2D" + assert ( + hidden_states.shape[-1] == w1.shape[-2] + ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}" + assert ( + w2.shape[-1] == w1.shape[1] + ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}" + + # feature check + assert inplace == False, "Inplace is not supported in new triton MoE kernel" + + M, K = hidden_states.shape + E, _, N = w1.shape + n_expts_act = routing_data.n_expts_act + dtype = hidden_states.dtype + + if global_num_experts == -1: + global_num_experts = E + + # consistent with default implementation + intermediate_cache2 = torch.empty( + (M * n_expts_act, N // 2), device="cuda", dtype=dtype + ) + + intermediate_cache1 = matmul_ogs( + hidden_states, + w1, + None, + routing_data, + gather_indx=gather_indx, + gammas=routing_data.gate_scal if apply_router_weight_on_input else None, + ) + + if activation == "silu": + silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) + elif activation == "gelu": + gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + intermediate_cache3 = matmul_ogs( + intermediate_cache2, + w2, + None, + routing_data, + scatter_indx=scatter_indx, + gammas=None if apply_router_weight_on_input else routing_data.gate_scal, + ) + + return intermediate_cache3 + + +def triton_kernel_moe_forward_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="forward_cuda_triton", + op_func=triton_kernel_moe_forward, + mutates_args=[], + fake_impl=triton_kernel_moe_forward_fake, +) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 7f6a641ef..2c03f8f67 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -101,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "triton_attention_reduce_in_fp32", "num_reserved_decode_tokens", "weight_loader_disable_mmap", + "enable_triton_kernel_moe", ] # Put some global args for easy access diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ce9710985..51e5ecc8b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -222,6 +222,7 @@ class ServerArgs: disable_chunked_prefix_cache: bool = False disable_fast_image_processor: bool = False enable_return_hidden_states: bool = False + enable_triton_kernel_moe: bool = False warmups: Optional[str] = None # Debug tensor dumps @@ -1554,6 +1555,11 @@ class ServerArgs: action="store_true", help="Enable returning hidden states with responses.", ) + parser.add_argument( + "--enable-triton-kernel-moe", + action="store_true", + help="Use triton moe grouped gemm kernel.", + ) parser.add_argument( "--warmups", type=str, diff --git a/test/srt/test_triton_fused_moe.py b/test/srt/test_triton_fused_moe.py new file mode 100644 index 000000000..8d014f6c7 --- /dev/null +++ b/test/srt/test_triton_fused_moe.py @@ -0,0 +1,146 @@ +import unittest + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import ( + triton_kernel_moe_forward, +) +from sglang.test.test_utils import CustomTestCase + + +class TestFusedMOE(CustomTestCase): + NUM_EXPERTS = [8, 64] + TOP_KS = [2, 4] + + @staticmethod + def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01): + """Create a random CUDA tensor + + Args: + shape: Tensor shape + dtype: Data type + mean: Mean value + std: Standard deviation + + Returns: + torch.Tensor: Randomly initialized CUDA tensor + """ + return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std) + + def get_tolerance(self, dtype): + """Get tolerance values for different data types + + Args: + dtype: Data type + + Returns: + tuple: (relative tolerance, absolute tolerance) + """ + if dtype == torch.float32: + return 1e-5, 1e-5 + elif dtype in [torch.float16, torch.bfloat16]: + return 1e-5, 1e-5 + else: + return 1e-2, 1e-2 # Default values for other types + + def torch_naive_moe( + self, + a, + w1, + w2, + score, + topk, + ): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + if w1.dtype == torch.float8_e4m3fn: + w1_compute = w1.to(a.dtype) + w2_compute = w2.to(a.dtype) + else: + w1_compute = w1 + w2_compute = w2 + + for i in range(w1_compute.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1_compute[i].transpose(0, 1) + ) @ w2_compute[i].transpose(0, 1) + + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + def _test_case(self, m, n, k, e, topk, dtype): + rtol, atol = self.get_tolerance(dtype) + + a = self.create_random_cuda_tensor((m, k), dtype) + w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype) + w2 = self.create_random_cuda_tensor((e, k, n), dtype) + w1_tri = w1.clone() + w2_tri = w2.clone() + w1_tri = w1_tri.transpose(-2, -1).contiguous() + w2_tri = w2_tri.transpose(-2, -1).contiguous() + score = self.create_random_cuda_tensor((m, e), dtype) + + triton_output = triton_kernel_moe_forward( + a, w1_tri, w2_tri, score, topk, renormalize=False + ) + torch_output = self.torch_naive_moe(a, w1, w2, score, topk) + torch.testing.assert_close(triton_output, torch_output, rtol=rtol, atol=atol) + + def test_various_configurations(self): + m_values = [1, 32, 64, 256] + n_values = [128, 1024] + k_values = [128, 512, 1024] + dtypes = [torch.bfloat16] + + # Calculate total number of tests + total_tests = ( + len(m_values) + * len(n_values) + * len(k_values) + * len(self.NUM_EXPERTS) + * len(self.TOP_KS) + * len(dtypes) + ) + + # Create progress bar + with tqdm(total=total_tests, desc="Running MoE tests") as pbar: + for m in m_values: + for n in n_values: + for k in k_values: + for e in self.NUM_EXPERTS: + for topk in self.TOP_KS: + for dtype in dtypes: + with self.subTest( + m=m, + n=n, + k=k, + e=e, + topk=topk, + dtype=dtype, + ): + self._test_case( + m, + n, + k, + e, + topk, + dtype, + ) + torch.cuda.empty_cache() + pbar.update(1) + + +if __name__ == "__main__": + unittest.main()