diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 2c02a7463..ca0c2c5f0 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -38,6 +38,7 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     next_power_of_2,
+    round_up,
 )
 
 if is_flashinfer_available():
@@ -146,7 +147,6 @@ class FusedMoE(torch.nn.Module):
 
         self.layer_id = layer_id
         self.top_k = top_k
-        self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_fused_shared_experts = num_fused_shared_experts
         self.expert_map_cpu = None
@@ -206,6 +206,16 @@ class FusedMoE(torch.nn.Module):
         assert self.quant_method is not None
         self.quant_config = quant_config
 
+        if (
+            self.quant_config is not None
+            and self.quant_config.get_name() == "mxfp4"
+            and (
+                get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE")
+                or get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE")
+            )
+        ):
+            hidden_size = round_up(hidden_size, 256)
+        self.hidden_size = hidden_size
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
@@ -784,6 +794,14 @@ class FusedMoE(torch.nn.Module):
         )
 
     def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
+        origin_hidden_states_dim = hidden_states.shape[-1]
+        if self.hidden_size != origin_hidden_states_dim:
+            hidden_states = torch.nn.functional.pad(
+                hidden_states,
+                (0, self.hidden_size - origin_hidden_states_dim),
+                mode="constant",
+                value=0.0,
+            )
         assert self.quant_method is not None
 
         if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
@@ -829,7 +847,7 @@ class FusedMoE(torch.nn.Module):
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
-        return final_hidden_states
+        return final_hidden_states[..., :origin_hidden_states_dim].contiguous()
 
     @classmethod
     def make_expert_params_mapping(
diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py
index 7103cb8be..db5d23acc 100644
--- a/python/sglang/srt/layers/quantization/mxfp4.py
+++ b/python/sglang/srt/layers/quantization/mxfp4.py
@@ -21,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import (
     direct_register_custom_op,
+    get_bool_env_var,
     is_cuda,
     is_flashinfer_available,
     is_hip,
@@ -31,6 +32,12 @@ from sglang.srt.utils import (
 
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
 
+# Environment variables for FlashInfer MXFP4 MoE backend
+USE_FLASHINFER_MXFP4_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false")
+USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
+    "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
+)
+
 if is_flashinfer_available():
     # from flashinfer.fused_moe import cutlass_fused_moe
     from flashinfer import (
@@ -228,16 +235,28 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.num_experts = num_experts
         weight_dtype = torch.uint8
         scale_dtype = torch.uint8
-
-        intermediate_size *= 2
         mxfp4_block = 32
 
-        self.intermediate_size = intermediate_size
+        # pad the intermediate size to be a multiple of 2 * mxfp4_block
+        # to hold non-uniformly sharded tensors as well as for swizzling
+        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+            intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
+            hidden_size = round_up(hidden_size, 256)
+        elif is_hip():
+            intermediate_size_per_partition_after_pad = round_up(intermediate_size, 128)
+        else:
+            intermediate_size_per_partition_after_pad = round_up(intermediate_size, 64)
+
+        self.intermediate_size = intermediate_size_per_partition_after_pad
+        self.hidden_size = hidden_size
 
         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(
             torch.zeros(
-                num_experts, 2 * intermediate_size, hidden_size // 2, dtype=weight_dtype
+                num_experts,
+                2 * intermediate_size_per_partition_after_pad,
+                hidden_size // 2,
+                dtype=weight_dtype,
             ),
             requires_grad=False,
         )
@@ -247,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
-                2 * intermediate_size,
+                2 * intermediate_size_per_partition_after_pad,
                 hidden_size // mxfp4_block,
                 dtype=scale_dtype,
             ),
@@ -257,7 +276,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
 
         w13_weight_bias = torch.nn.Parameter(
-            torch.zeros(num_experts, 2 * intermediate_size, dtype=torch.bfloat16),
+            torch.zeros(
+                num_experts,
+                2 * intermediate_size_per_partition_after_pad,
+                dtype=torch.bfloat16,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_bias", w13_weight_bias)
@@ -266,7 +289,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # down_proj (row parallel)
         w2_weight = torch.nn.Parameter(
             torch.zeros(
-                num_experts, hidden_size, intermediate_size // 2, dtype=weight_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition_after_pad // 2,
+                dtype=weight_dtype,
             ),
             requires_grad=False,
         )
@@ -277,7 +303,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             torch.zeros(
                 num_experts,
                 hidden_size,
-                intermediate_size // mxfp4_block,
+                intermediate_size_per_partition_after_pad // mxfp4_block,
                 dtype=scale_dtype,
             ),
             requires_grad=False,
@@ -293,6 +319,158 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_bias, extra_weight_attrs)
 
     def process_weights_after_loading(self, layer):
+        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+            logger.info(
+                "Shuffling MoE weights for FlashInfer, it might take a while..."
+            )
+            layer.gemm1_alpha = Parameter(
+                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
+                requires_grad=False,
+            )
+            layer.gemm1_beta = Parameter(
+                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
+                requires_grad=False,
+            )
+            layer.gemm1_clamp_limit = Parameter(
+                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
+                requires_grad=False,
+            )
+            sf_block_size = 32  # mxfp4 block size
+
+            assert (
+                layer.w13_weight.dim() == 3
+                and layer.w13_weight.shape[0] == self.num_experts
+                and layer.w13_weight.shape[1] == self.intermediate_size * 2
+                and layer.w13_weight.shape[2] == self.hidden_size // 2
+            )
+            assert (
+                layer.w13_weight_scale.dim() == 3
+                and layer.w13_weight_scale.shape[0] == self.num_experts
+                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
+                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
+            )
+            assert (
+                layer.w2_weight.dim() == 3
+                and layer.w2_weight.shape[0] == self.num_experts
+                and layer.w2_weight.shape[1] == self.hidden_size
+                and layer.w2_weight.shape[2] == self.intermediate_size // 2
+            )
+            assert (
+                layer.w2_weight_scale.dim() == 3
+                and layer.w2_weight_scale.shape[1] == self.hidden_size
+                and layer.w2_weight_scale.shape[2]
+                == self.intermediate_size // sf_block_size
+            )
+            assert (
+                layer.w13_weight_bias.dim() == 2
+                and layer.w13_weight_bias.shape[0] == self.num_experts
+                and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
+            )
+            assert (
+                layer.w2_weight_bias.dim() == 2
+                and layer.w2_weight_bias.shape[0] == self.num_experts
+                and layer.w2_weight_bias.shape[1] == self.hidden_size
+            )
+
+            w13_weight_scale = layer.w13_weight_scale.data
+            w2_weight_scale = layer.w2_weight_scale.data
+            w13_weight = layer.w13_weight.data
+            w2_weight = layer.w2_weight.data
+            w13_bias = layer.w13_weight_bias.data.to(torch.float32)
+            w2_bias = layer.w2_weight_bias.data.to(torch.float32)
+
+            # Swap w1 and w3, as the definition of
+            # swiglu is different in trtllm-gen
+            def swap_every_two_rows(x, axis=-1):
+                shape = x.shape
+                if axis < 0:
+                    axis = len(shape) + axis
+
+                # Create a new shape with pairs swapped along specified axis
+                new_shape = list(shape)
+                new_shape[axis] = shape[axis] // 2
+                new_shape.insert(axis + 1, 2)
+
+                # Reshape to expose pairs, swap them, and reshape back
+                x = x.reshape(*new_shape)
+                x = x.flip(axis + 1)
+                new_shape = list(shape)
+                return x.reshape(*new_shape)
+
+            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
+            w13_weight = swap_every_two_rows(w13_weight, -2)
+            w13_bias = swap_every_two_rows(w13_bias, -1)
+
+            # Shuffle weights and scaling factors for transposed mma output
+            gemm1_weights_mxfp4_shuffled = []
+            gemm1_scales_mxfp4_shuffled = []
+            gemm2_weights_mxfp4_shuffled = []
+            gemm2_scales_mxfp4_shuffled = []
+            gemm1_bias_shuffled = []
+            gemm2_bias_shuffled = []
+            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
+            for i in range(self.num_experts):
+                gemm1_weights_mxfp4_shuffled.append(
+                    shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
+                )
+                gemm1_scales_mxfp4_shuffled.append(
+                    shuffle_matrix_sf_a(
+                        w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
+                    )
+                )
+                gemm1_bias_shuffled.append(
+                    shuffle_matrix_a(
+                        w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m
+                    )
+                )
+
+                gemm2_weights_mxfp4_shuffled.append(
+                    shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
+                )
+                gemm2_scales_mxfp4_shuffled.append(
+                    shuffle_matrix_sf_a(
+                        w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
+                    )
+                )
+                gemm2_bias_shuffled.append(
+                    shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m)
+                )
+
+            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
+            w13_weight_scale = (
+                torch.stack(gemm1_scales_mxfp4_shuffled)
+                .reshape(
+                    self.num_experts,
+                    2 * self.intermediate_size,
+                    self.hidden_size // sf_block_size,
+                )
+                .view(torch.float8_e4m3fn)
+            )
+
+            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
+            w2_weight_scale = (
+                torch.stack(gemm2_scales_mxfp4_shuffled)
+                .reshape(
+                    self.num_experts,
+                    self.hidden_size,
+                    self.intermediate_size // sf_block_size,
+                )
+                .view(torch.float8_e4m3fn)
+            )
+
+            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
+            layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
+            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
+            layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
+            layer.w13_weight_bias = Parameter(
+                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
+                requires_grad=False,
+            )
+            layer.w2_weight_bias = Parameter(
+                torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
+                requires_grad=False,
+            )
+            return
 
         from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
@@ -366,22 +544,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         activation_alpha: Optional[float] = None,
         swiglu_limit: Optional[float] = None,
     ) -> torch.Tensor:
-        # avoid import error when triton_kernel is not installed
-        # from vllm.model_executor.layers.fused_moe.triton_kernels_moe import (
-        #     triton_kernel_moe_forward)
-
-        """
-        if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE
-                or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE):
-            assert not self.moe.use_ep, (
-                "EP is not supported for flashinfer mxfp4 moe backend yet.")
-            if envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE:
+        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+            # When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize the input;
+            # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it
+            # with GEMM operations, which can theoretically improve performance.
+            if USE_FLASHINFER_MXFP4_BF16_MOE:
                 assert x.dtype == torch.bfloat16
                 x_quant = x
                 x_scale = None
             else:
                 x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                 x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+
+            topk_weights, topk_ids, router_logits = topk_output
+            top_k = topk_weights.shape[-1]
+
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
                 None,  # routing_bias
@@ -412,7 +589,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 True,  # do finalize
             )[0]
             return trtllm_gen_output
-        """
 
         if self.use_triton_kernels:
             if self.with_bias:
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 2623a1027..6412398bb 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -464,7 +464,21 @@ class ServerArgs:
         model_arch = self.get_hf_config().architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
             self.attention_backend = "triton"
-            self.enable_triton_kernel_moe = True
+
+            # Check if FlashInfer MXFP4 MoE is enabled
+            from sglang.srt.utils import get_bool_env_var
+
+            USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
+                "SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
+            )
+            USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
+                "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
+            )
+
+            # Only enable Triton kernel MoE if FlashInfer is not enabled
+            if not (USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE):
+                self.enable_triton_kernel_moe = True
+            self.disable_hybrid_swa_memory = True
 
         quantization_config = getattr(
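
Note (illustrative sketch, not part of the patch): when SGLANG_USE_FLASHINFER_MXFP4_MOE or SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE is set, FusedMoE pads hidden_size up to a multiple of 256 so the activations line up with the padded mxfp4 weights created above, then slices the padded tail off before returning. A minimal standalone illustration of that round-trip, assuming round_up behaves like sglang.srt.utils.round_up (round x up to a multiple of y) and using a no-op stand-in for the quantized MoE kernel:

    import torch
    import torch.nn.functional as F

    def round_up(x: int, y: int) -> int:
        # Smallest multiple of y that is >= x.
        return ((x + y - 1) // y) * y

    def padded_moe_roundtrip(hidden_states: torch.Tensor, alignment: int = 256) -> torch.Tensor:
        orig_dim = hidden_states.shape[-1]
        padded_dim = round_up(orig_dim, alignment)  # e.g. 2880 -> 3072
        if padded_dim != orig_dim:
            # Zero-pad the last dim so it matches the padded w13/w2 weights.
            hidden_states = F.pad(hidden_states, (0, padded_dim - orig_dim), value=0.0)
        out = hidden_states  # placeholder for the quantized MoE kernel call
        # Drop the padded tail so callers still see the original hidden size.
        return out[..., :orig_dim].contiguous()

Both environment variables are read with get_bool_env_var, so the backend is enabled with e.g. SGLANG_USE_FLASHINFER_MXFP4_MOE=1 (or =true) in the server environment.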