From 1d24db834803994739df6eaed139083972b656a1 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Fri, 8 Aug 2025 00:46:42 -0700 Subject: [PATCH] Expert Parallelism for GPT-OSS (#8944) --- python/sglang/srt/layers/moe/ep_moe/layer.py | 6 + .../layers/moe/fused_moe_triton/fused_moe.py | 113 +++++++++++++-- .../srt/layers/moe/fused_moe_triton/layer.py | 6 +- .../sglang/srt/layers/quantization/mxfp4.py | 132 +++++++++++------- .../sglang/srt/layers/quantization/unquant.py | 11 +- python/sglang/srt/models/gpt_oss.py | 101 +++++++------- python/sglang/srt/server_args.py | 14 +- python/sglang/srt/utils.py | 5 + 8 files changed, 269 insertions(+), 119 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 862561804..1dd097b4e 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -76,6 +76,9 @@ class EPMoE(FusedMoE): prefix: str = "", activation: str = "silu", routed_scaling_factor: Optional[float] = None, + activation_alpha: Optional[float] = None, + swiglu_limit: Optional[float] = None, + with_bias: bool = False, ): super().__init__( num_experts=num_experts, @@ -91,6 +94,9 @@ class EPMoE(FusedMoE): activation=activation, # apply_router_weight_on_input=apply_router_weight_on_input, routed_scaling_factor=routed_scaling_factor, + activation_alpha=activation_alpha, + swiglu_limit=swiglu_limit, + with_bias=with_bias, ) self.start_expert_id = self.moe_ep_rank * self.num_local_experts diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index d2c65d973..2cd0099b4 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -319,6 +319,7 @@ def fused_moe_kernel( # Pointers to matrices a_ptr, b_ptr, + bias_ptr, c_ptr, a_scale_ptr, b_scale_ptr, @@ -340,6 +341,8 @@ def fused_moe_kernel( stride_be, stride_bk, stride_bn, + stride_bias_e, + stride_bias_n, stride_cm, stride_cn, stride_asm, @@ -449,6 +452,10 @@ def fused_moe_kernel( + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) ) + if bias_ptr is not None: + bias = tl.load( + bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n + ) if use_int8_w8a16: b_scale_ptrs = ( b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn @@ -526,18 +533,20 @@ def fused_moe_kernel( a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk + if use_int8_w8a16: + accumulator *= b_scale + elif use_fp8_w8a8 or use_int8_w8a8: + if group_k == 0 or group_n == 0: + accumulator *= a_scale * b_scale + + if bias_ptr is not None: + accumulator += bias + if MUL_ROUTED_WEIGHT: moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) - accumulator = accumulator * moe_weight[:, None] - if use_int8_w8a16: - accumulator = (accumulator * b_scale).to(compute_type) - elif use_fp8_w8a8 or use_int8_w8a8: - if group_k > 0 and group_n > 0: - accumulator = accumulator.to(compute_type) - else: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) - else: - accumulator = accumulator.to(compute_type) + accumulator *= moe_weight[:, None] + + accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -622,6 +631,7 @@ def moe_align_block_size( def invoke_fused_moe_kernel( A: torch.Tensor, B: torch.Tensor, + bias: Optional[torch.Tensor], C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], @@ -711,6 +721,7 @@ def invoke_fused_moe_kernel( ): assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 + assert bias is None fused_moe_kernel_gptq_awq[grid]( A, B, @@ -754,6 +765,7 @@ def invoke_fused_moe_kernel( fused_moe_kernel[grid]( A, B, + bias, C, A_scale, B_scale, @@ -770,6 +782,8 @@ def invoke_fused_moe_kernel( B.stride(0), B.stride(2), B.stride(1), + bias.stride(0) if bias is not None else 0, + bias.stride(1) if bias is not None else 0, C.stride(1), C.stride(2), A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, @@ -994,6 +1008,8 @@ def inplace_fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + b1: Optional[torch.Tensor] = None, + b2: Optional[torch.Tensor] = None, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, @@ -1009,6 +1025,8 @@ def inplace_fused_experts( a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, routed_scaling_factor: Optional[float] = None, + activation_alpha: Optional[float] = None, + swiglu_limit: Optional[float] = None, ) -> None: fused_experts_impl( hidden_states, @@ -1016,6 +1034,8 @@ def inplace_fused_experts( w2, topk_weights, topk_ids, + b1, + b2, True, activation, apply_router_weight_on_input, @@ -1033,6 +1053,8 @@ def inplace_fused_experts( block_shape, False, routed_scaling_factor, + activation_alpha, + swiglu_limit, ) @@ -1042,6 +1064,8 @@ def inplace_fused_experts_fake( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + b1: Optional[torch.Tensor] = None, + b2: Optional[torch.Tensor] = None, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, @@ -1057,6 +1081,8 @@ def inplace_fused_experts_fake( a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, routed_scaling_factor: Optional[float] = None, + activation_alpha: Optional[float] = None, + swiglu_limit: Optional[float] = None, ) -> None: pass @@ -1075,6 +1101,8 @@ def outplace_fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + b1: Optional[torch.Tensor] = None, + b2: Optional[torch.Tensor] = None, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, @@ -1091,6 +1119,8 @@ def outplace_fused_experts( block_shape: Optional[List[int]] = None, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, + activation_alpha: Optional[float] = None, + swiglu_limit: Optional[float] = None, ) -> torch.Tensor: return fused_experts_impl( hidden_states, @@ -1098,6 +1128,8 @@ def outplace_fused_experts( w2, topk_weights, topk_ids, + b1, + b2, False, activation, apply_router_weight_on_input, @@ -1115,6 +1147,8 @@ def outplace_fused_experts( block_shape, no_combine=no_combine, routed_scaling_factor=routed_scaling_factor, + activation_alpha=activation_alpha, + swiglu_limit=swiglu_limit, ) @@ -1124,6 +1158,8 @@ def outplace_fused_experts_fake( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + b1: Optional[torch.Tensor] = None, + b2: Optional[torch.Tensor] = None, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, @@ -1140,6 +1176,8 @@ def outplace_fused_experts_fake( block_shape: Optional[List[int]] = None, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, + activation_alpha: Optional[float] = None, + swiglu_limit: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1157,6 +1195,8 @@ def fused_experts( w1: torch.Tensor, w2: torch.Tensor, topk_output: TopKOutput, + b1: Optional[torch.Tensor] = None, + b2: Optional[torch.Tensor] = None, inplace: bool = False, activation: str = "silu", apply_router_weight_on_input: bool = False, @@ -1174,6 +1214,8 @@ def fused_experts( block_shape: Optional[List[int]] = None, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, + activation_alpha: Optional[float] = None, + swiglu_limit: Optional[float] = None, ): topk_weights, topk_ids, _ = topk_output if inplace: @@ -1184,6 +1226,8 @@ def fused_experts( w2, topk_weights, topk_ids, + b1, + b2, activation, apply_router_weight_on_input, use_fp8_w8a8, @@ -1199,6 +1243,8 @@ def fused_experts( a2_scale, block_shape, routed_scaling_factor, + activation_alpha, + swiglu_limit, ) return hidden_states else: @@ -1208,6 +1254,8 @@ def fused_experts( w2, topk_weights, topk_ids, + b1, + b2, activation, apply_router_weight_on_input, use_fp8_w8a8, @@ -1224,6 +1272,8 @@ def fused_experts( block_shape, no_combine=no_combine, routed_scaling_factor=routed_scaling_factor, + activation_alpha=activation_alpha, + swiglu_limit=swiglu_limit, ) @@ -1319,12 +1369,22 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor): out.mul_(routed_scaling_factor) +@torch.compile +def swiglu_with_alpha_and_limit(x, alpha, limit): + gate, up = x[..., ::2], x[..., 1::2] + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + return gate * torch.sigmoid(gate * alpha) * (up + 1) + + def fused_experts_impl( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + b1: Optional[torch.Tensor] = None, + b2: Optional[torch.Tensor] = None, inplace: bool = False, activation: str = "silu", apply_router_weight_on_input: bool = False, @@ -1342,6 +1402,8 @@ def fused_experts_impl( block_shape: Optional[List[int]] = None, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, + activation_alpha: Optional[float] = None, + swiglu_limit: Optional[float] = None, ): padded_size = padding_size if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter: @@ -1353,7 +1415,7 @@ def fused_experts_impl( else: assert ( hidden_states.shape[1] == w1.shape[2] - padded_size - ), "Hidden size mismatch" + ), f"Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -1449,6 +1511,7 @@ def fused_experts_impl( invoke_fused_moe_kernel( curr_hidden_states, w1, + b1, intermediate_cache1, a1_scale, w1_scale, @@ -1470,13 +1533,24 @@ def fused_experts_impl( block_shape=block_shape, ) if activation == "silu": - if _is_cuda: + if activation_alpha is not None: + assert swiglu_limit is not None + intermediate_cache2 = swiglu_with_alpha_and_limit( + intermediate_cache1.view(-1, N), + activation_alpha, + swiglu_limit, + ) + elif _is_cuda: silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) else: vllm_ops.silu_and_mul( intermediate_cache2, intermediate_cache1.view(-1, N) ) elif activation == "gelu": + assert ( + activation_alpha is None + ), "activation_alpha is not supported for gelu" + assert swiglu_limit is None, "swiglu_limit is not supported for gelu" if _is_cuda: gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) else: @@ -1489,6 +1563,7 @@ def fused_experts_impl( invoke_fused_moe_kernel( intermediate_cache2, w2, + b2, ( intermediate_cache3 if not no_combine and topk_ids.shape[1] != 1 @@ -1567,6 +1642,8 @@ def fused_moe( w1: torch.Tensor, w2: torch.Tensor, topk_output: TopKOutput, + b1: Optional[torch.Tensor] = None, + b2: Optional[torch.Tensor] = None, inplace: bool = False, activation: str = "silu", apply_router_weight_on_input: bool = False, @@ -1584,6 +1661,8 @@ def fused_moe( block_shape: Optional[List[int]] = None, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, + activation_alpha: Optional[float] = None, + swiglu_limit: Optional[float] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -1594,6 +1673,8 @@ def fused_moe( - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - topk_output (TopKOutput): The top-k output of the experts. + - b1 (Optional[torch.Tensor]): Optional bias for w1. + - b2 (Optional[torch.Tensor]): Optional bias for w2. - inplace (bool): If True, perform the operation in-place. Defaults to False. - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner @@ -1615,6 +1696,10 @@ def fused_moe( a2. - block_shape: (Optional[List[int]]): Optional block size for block-wise quantization. + - activation_alpha (Optional[float]): Optional alpha for the activation + function. + - swiglu_limit (Optional[float]): Optional limit for the swiglu activation + function. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -1625,6 +1710,8 @@ def fused_moe( w1, w2, topk_output, + b1=b1, + b2=b2, inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, @@ -1642,4 +1729,6 @@ def fused_moe( block_shape=block_shape, no_combine=no_combine, routed_scaling_factor=routed_scaling_factor, + activation_alpha=activation_alpha, + swiglu_limit=swiglu_limit, ) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index ec702ddb9..9bf97b690 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -199,7 +199,7 @@ class FusedMoE(torch.nn.Module): if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod( - self.use_triton_kernels, with_bias=with_bias + self.use_triton_kernels ) else: self.quant_method = quant_config.get_quant_method(self, prefix) @@ -809,7 +809,9 @@ class FusedMoE(torch.nn.Module): # If we are in EP mode, we need to move the expert map to GPU. self.expert_map_gpu = self.expert_map_cpu.to(device="cuda") - if self.expert_map_gpu is not None: + if self.expert_map_gpu is not None and isinstance( + topk_output, StandardTopKOutput + ): topk_output = topk_output._replace( topk_ids=self.expert_map_gpu[topk_output.topk_ids] ) diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index 619f0bfc9..62bfaf887 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -8,6 +8,7 @@ import logging from typing import TYPE_CHECKING, List, Optional import torch +import triton.language as tl from torch.nn.parameter import Parameter from sglang.srt.layers.quantization.base_config import ( @@ -24,6 +25,7 @@ from sglang.srt.utils import ( is_cuda, is_flashinfer_available, is_hip, + is_triton_kernels_available, log_info_on_rank0, next_power_of_2, round_up, @@ -31,7 +33,7 @@ from sglang.srt.utils import ( ) _is_sm100_supported = is_cuda() and is_sm100_supported() -has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None +has_triton_kernels = is_triton_kernels_available() if is_flashinfer_available(): @@ -188,12 +190,7 @@ class Mxfp4Config(QuantizationConfig): ): return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): - use_flashinfer = global_server_args_dict.get( - "enable_flashinfer_mxfp4_moe", False - ) - return Mxfp4MoEMethod( - use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer - ) + return Mxfp4MoEMethod(prefix) else: raise NotImplementedError("Mxfp4 attention layer is not implemented") return None @@ -206,15 +203,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__( self, - use_triton_kernels: bool = True, - with_bias: bool = True, - use_flashinfer: bool = False, + prefix: str, ): + from sglang.srt.managers.schedule_batch import global_server_args_dict + super().__init__() + self.topk_indices_dtype = None - self.use_triton_kernels = use_triton_kernels - self.with_bias = with_bias - self.use_flashinfer = use_flashinfer + self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"] + self.with_bias = False + self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"] self.triton_kernel_moe_forward = None self.triton_kernel_moe_with_bias_forward = None @@ -236,12 +234,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, + with_bias: bool = False, **extra_weight_attrs, ): - # print(f"hi {self=} create_weights {layer=}") self.num_experts = num_experts weight_dtype = torch.uint8 scale_dtype = torch.uint8 + self.with_bias = with_bias mxfp4_block = 32 # pad the intermediate size to be a multiple of 2 * mxfp4_block @@ -264,7 +263,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter( torch.zeros( - num_experts, + layer.num_local_experts, 2 * intermediate_size_per_partition_after_pad, hidden_size // 2, dtype=weight_dtype, @@ -276,7 +275,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): w13_weight_scale = torch.nn.Parameter( torch.zeros( - num_experts, + layer.num_local_experts, 2 * intermediate_size_per_partition_after_pad, hidden_size // mxfp4_block, dtype=scale_dtype, @@ -288,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): w13_weight_bias = torch.nn.Parameter( torch.zeros( - num_experts, + layer.num_local_experts, 2 * intermediate_size_per_partition_after_pad, dtype=torch.bfloat16, ), @@ -300,7 +299,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): # down_proj (row parallel) w2_weight = torch.nn.Parameter( torch.zeros( - num_experts, + layer.num_local_experts, hidden_size, intermediate_size_per_partition_after_pad // 2, dtype=weight_dtype, @@ -312,7 +311,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): w2_weight_scale = torch.nn.Parameter( torch.zeros( - num_experts, + layer.num_local_experts, hidden_size, intermediate_size_per_partition_after_pad // mxfp4_block, dtype=scale_dtype, @@ -323,7 +322,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): set_weight_attrs(w2_weight_scale, extra_weight_attrs) w2_weight_bias = torch.nn.Parameter( - torch.zeros(num_experts, hidden_size, dtype=torch.bfloat16), + torch.zeros(layer.num_local_experts, hidden_size, dtype=torch.bfloat16), requires_grad=False, ) layer.register_parameter("w2_weight_bias", w2_weight_bias) @@ -484,38 +483,51 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ) return - from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + if self.use_triton_kernels: - w13_weight_bias = layer.w13_weight_bias.to(torch.float32) - w2_weight_bias = layer.w2_weight_bias.to(torch.float32) + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig - layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False) - layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False) + w13_weight_bias = layer.w13_weight_bias.to(torch.float32) + w2_weight_bias = layer.w2_weight_bias.to(torch.float32) - num_warps = 8 + layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False) + layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False) - w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( - layer.w13_weight, layer.w13_weight_scale, num_warps - ) - w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( - layer.w2_weight, layer.w2_weight_scale, num_warps - ) + num_warps = 8 - self.w13_precision_config = PrecisionConfig( - weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex) - ) - self.w2_precision_config = PrecisionConfig( - weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex) - ) + w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( + layer.w13_weight, layer.w13_weight_scale, num_warps + ) + w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( + layer.w2_weight, layer.w2_weight_scale, num_warps + ) - self.w13_weight_triton_tensor = w13_weight - self.w2_weight_triton_tensor = w2_weight + self.w13_precision_config = PrecisionConfig( + weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex) + ) + self.w2_precision_config = PrecisionConfig( + weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex) + ) - # need to delete the original weights to save memory on single GPU - del layer.w13_weight - del layer.w2_weight - layer.w13_weight = None - layer.w2_weight = None + self.w13_weight_triton_tensor = w13_weight + self.w2_weight_triton_tensor = w2_weight + del layer.w13_weight + del layer.w2_weight + else: + from triton_kernels.numerics_details.mxfp import upcast_from_mxfp + + w13_weight = upcast_from_mxfp( + layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1 + ) + w2_weight = upcast_from_mxfp( + layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1 + ) + del layer.w13_weight + del layer.w2_weight + del layer.w13_weight_scale + del layer.w2_weight_scale + layer.w13_weight = Parameter(w13_weight.data, requires_grad=False) + layer.w2_weight = Parameter(w2_weight.data, requires_grad=False) torch.cuda.empty_cache() def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int): @@ -580,13 +592,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): None, # output1_scale_scalar None, # output1_scale_gate_scalar None, # output2_scale_scalar - self.num_experts, + layer.num_experts, top_k, None, # n_group None, # topk_group self.intermediate_size, # padded to multiple of 256 - 0, # local_expert_offset - self.num_experts, # local num experts + layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset + layer.num_local_experts, # local num experts None, self._get_tile_tokens_dim(x, top_k), 1, # routing_method_type, renormalize @@ -595,10 +607,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): return trtllm_gen_output if self.use_triton_kernels: + assert ( + layer.moe_ep_size == 1 + ), "Expert parallel is not supported when using triton kernels" if self.with_bias: - # TODO why we do not put weights on layer? - assert layer.w13_weight is None - assert layer.w2_weight is None return self.triton_kernel_moe_with_bias_forward( hidden_states=x, w1=self.w13_weight_triton_tensor, @@ -620,4 +632,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): topk_output=topk_output, ) else: - raise NotImplementedError() + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_output=topk_output, + b1=layer.w13_weight_bias, + b2=layer.w2_weight_bias, + inplace=inplace, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + no_combine=no_combine, + routed_scaling_factor=routed_scaling_factor, + activation_alpha=activation_alpha, + swiglu_limit=swiglu_limit, + ) diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index c5558e3c1..9c33e3173 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -126,10 +126,10 @@ class UnquantizedLinearMethod(LinearMethodBase): class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" - def __init__(self, use_triton_kernels: bool = False, with_bias: bool = False): + def __init__(self, use_triton_kernels: bool = False): super().__init__() self.use_triton_kernels = use_triton_kernels - self.with_bias = with_bias + self.with_bias = False self.triton_kernel_moe_forward = None self.triton_kernel_moe_with_bias_forward = None @@ -151,8 +151,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, + with_bias: bool = False, **extra_weight_attrs, ): + self.with_bias = with_bias + # Fused gate_up_proj (column parallel) w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size if self.use_triton_kernels: @@ -319,12 +322,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, + b1=getattr(layer, "w13_weight_bias", None), + b2=getattr(layer, "w2_weight_bias", None), topk_output=topk_output, inplace=inplace and not no_combine, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, no_combine=no_combine, routed_scaling_factor=routed_scaling_factor, + activation_alpha=activation_alpha, + swiglu_limit=swiglu_limit, ) def forward_cpu( diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index fd9d9441c..98ef0775e 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -28,6 +28,7 @@ from sglang.srt.distributed import ( get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size, get_moe_tensor_parallel_rank, + get_moe_tensor_parallel_world_size, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -96,11 +97,6 @@ class GptOssSparseMoeBlock(nn.Module): self.activation = config.hidden_act self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702) self.swiglu_limit = config.swiglu_limit - if self.tp_size > config.num_local_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_local_experts}." - ) if global_server_args_dict["enable_flashinfer_mxfp4_moe"]: self.topk = None @@ -708,22 +704,26 @@ class GptOssForCausalLM(nn.Module): loaded_params: set[str] = set() mxfp4_block = 32 - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() + moe_tp_rank = get_moe_tensor_parallel_rank() + moe_tp_size = get_moe_tensor_parallel_world_size() + moe_ep_rank = get_moe_expert_parallel_rank() + moe_ep_size = get_moe_expert_parallel_world_size() + intermediate_size = self.config.intermediate_size intermediate_size_block = intermediate_size // mxfp4_block - per_rank_intermediate_size_block = intermediate_size_block // tp_size + per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block # Calculate common slicing bounds for current rank - tp_rank_start = tp_rank * per_rank_intermediate_size - tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size) - - # Attention heads per rank - heads_per_rank = self.config.num_attention_heads // tp_size - head_start = tp_rank * heads_per_rank - - num_experts = self.config.num_local_experts + assert self.config.num_local_experts % moe_ep_size == 0 + moe_num_global_experts = self.config.num_local_experts + moe_num_local_experts = self.config.num_local_experts // moe_ep_size + moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size + moe_tp_rank_end = min( + (moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size + ) + moe_ep_rank_start = moe_ep_rank * moe_num_local_experts + moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts for name, weight in weights: weight = weight.cuda() @@ -735,10 +735,14 @@ class GptOssForCausalLM(nn.Module): # flat weight from (E, 2 * N, block_size, entry_per_block) # to (E, 2 * N, -1), shouldn't trigger copy for contiguous weight = weight.view( - num_experts, 2 * intermediate_size, -1 + moe_num_global_experts, 2 * intermediate_size, -1 ).contiguous() - narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...] + narrow_weight = weight[ + moe_ep_rank_start:moe_ep_rank_end, + 2 * moe_tp_rank_start : 2 * moe_tp_rank_end, + ..., + ] param = params_dict[new_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -757,9 +761,13 @@ class GptOssForCausalLM(nn.Module): # same flatten here, but since 2 mx4 value are packed in 1 # uint8, divide by 2 weight = weight.view( - num_experts, -1, intermediate_size // 2 + moe_num_global_experts, -1, intermediate_size // 2 ).contiguous() - narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2] + narrow_weight = weight[ + moe_ep_rank_start:moe_ep_rank_end, + ..., + moe_tp_rank_start // 2 : moe_tp_rank_end // 2, + ] param = params_dict[new_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -775,7 +783,11 @@ class GptOssForCausalLM(nn.Module): elif "gate_up_proj_scales" in name: # Handle MLP gate and up projection weights scale new_name = name.replace("gate_up_proj_scales", "w13_weight_scale") - narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...] + narrow_weight = weight[ + moe_ep_rank_start:moe_ep_rank_end, + 2 * moe_tp_rank_start : 2 * moe_tp_rank_end, + ..., + ] param = params_dict[new_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -792,7 +804,9 @@ class GptOssForCausalLM(nn.Module): # Handle MLP down projection weights new_name = name.replace("down_proj_scales", "w2_weight_scale") narrow_weight = weight[ - ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block + moe_ep_rank_start:moe_ep_rank_end, + ..., + moe_tp_rank_start // mxfp4_block : moe_tp_rank_end // mxfp4_block, ] param = params_dict[new_name] @@ -809,7 +823,10 @@ class GptOssForCausalLM(nn.Module): # Handle MLP gate and up projection biases new_name = name.replace("gate_up_proj_bias", "w13_weight_bias") - narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end] + narrow_weight = weight[ + moe_ep_rank_start:moe_ep_rank_end, + 2 * moe_tp_rank_start : 2 * moe_tp_rank_end, + ] param = params_dict[new_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -823,15 +840,20 @@ class GptOssForCausalLM(nn.Module): loaded_params.add(new_name) elif "down_proj_bias" in name: - if get_moe_tensor_parallel_rank() != 0: - weight = torch.zeros_like(weight) + narrow_weight = weight[moe_ep_rank_start:moe_ep_rank_end, ...] + if moe_tp_rank != 0: + narrow_weight = torch.zeros_like(narrow_weight) # Handle MLP down projection bias new_name = name.replace("down_proj_bias", "w2_weight_bias") param = params_dict[new_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader( - param, weight, weight_name=new_name, shard_id=None, expert_id=None + param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None, ) loaded_params.add(new_name) @@ -910,27 +932,12 @@ class GptOssForCausalLM(nn.Module): ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ] - - if self.quant_config is not None and (self.quant_config.get_name() == "mxfp4"): - expert_params_mapping = ( - get_moe_impl_class().make_expert_params_mapping_fused_mxfp4( - ckpt_gate_up_proj_name="gate_up_proj_blocks", - ckpt_down_proj_name="down_proj_blocks", - ckpt_gate_up_proj_bias_name="gate_up_proj_bias", - ckpt_down_proj_bias_name="down_proj_bias", - ckpt_gate_up_proj_scale_name="gate_up_proj_scales", - ckpt_down_proj_scale_name="down_proj_scales", - ) - ) - else: - expert_params_mapping = ( - get_moe_impl_class().make_expert_params_mapping_fused( - ckpt_gate_up_proj_name="gate_up_proj", - ckpt_down_proj_name="down_proj", - ckpt_gate_up_proj_bias_name="gate_up_proj_bias", - ckpt_down_proj_bias_name="down_proj_bias", - ) - ) + expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused( + ckpt_gate_up_proj_name="gate_up_proj", + ckpt_down_proj_name="down_proj", + ckpt_gate_up_proj_bias_name="gate_up_proj_bias", + ckpt_down_proj_bias_name="down_proj_bias", + ) params_dict = dict(self.named_parameters()) params_checker = {k: False for k, v in params_dict.items()} diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9d1839ff4..3a5c22fb3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -37,6 +37,7 @@ from sglang.srt.utils import ( is_hip, is_port_available, is_remote_url, + is_triton_kernels_available, is_valid_ipv6_address, nullable_str, ) @@ -492,10 +493,15 @@ class ServerArgs: "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel." ) else: - self.enable_triton_kernel_moe = True - logger.info( - "Detected GPT-OSS model, enabling triton_kernels MOE kernel." - ) + if self.enable_triton_kernel_moe: + assert ( + self.ep_size == 1 + ), "Triton kernel MoE is only supported when ep_size == 1" + if not self.enable_triton_kernel_moe and self.ep_size == 1: + self.enable_triton_kernel_moe = True + logger.info( + "Detected GPT-OSS model, enabling triton_kernels MOE kernel." + ) self.disable_hybrid_swa_memory = True diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 1e07a4136..edf441945 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2961,3 +2961,8 @@ class ConcurrentCounter: other tasks to run while waiting. When the counter becomes zero, the coroutine resumes. """ self.wait_for(lambda count: count == 0) + + +@lru_cache(maxsize=1) +def is_triton_kernels_available() -> bool: + return importlib.util.find_spec("triton_kernels") is not None