diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index ac1a831ac..38f123247 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -1295,6 +1295,9 @@ class DeepEPMoE(EPMoE): def get_moe_impl_class(): if global_server_args_dict["enable_deepep_moe"]: return DeepEPMoE + if global_server_args_dict["enable_flashinfer_moe"]: + # Must come before EPMoE because FusedMoE also supports enable_ep_moe + return FusedMoE if global_server_args_dict["enable_ep_moe"]: return EPMoE return FusedMoE diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 7cf8de28a..6a82db210 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -314,6 +314,8 @@ class FusedMoE(torch.nn.Module): inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, + enable_flashinfer_moe: Optional[bool] = False, + enable_ep_moe: Optional[bool] = False, ): super().__init__() @@ -324,9 +326,34 @@ class FusedMoE(torch.nn.Module): self.tp_size = ( tp_size if tp_size is not None else get_tensor_model_parallel_world_size() ) + self.tp_rank = get_tensor_model_parallel_rank() + self.num_experts = num_experts + self.expert_map = None + self.enable_flashinfer_moe = enable_flashinfer_moe + if enable_ep_moe: + assert ( + self.enable_flashinfer_moe + ), "FusedMoE only supports EP with --enable-flashinfer-moe" + self.ep_size = self.tp_size + self.ep_rank = self.tp_rank + self.tp_size = 1 + self.tp_rank = 0 + # Create a tensor of size num_experts filled with -1 + self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32) + # Create a expert map for the local experts + assert num_experts % self.ep_size == 0 + self.local_num_experts = num_experts // self.ep_size + self.expert_map[ + self.ep_rank + * self.local_num_experts : (self.ep_rank + 1) + * self.local_num_experts + ] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu") + else: + self.ep_size = 1 + self.ep_rank = 0 + self.local_num_experts = num_experts self.routed_scaling_factor = routed_scaling_factor self.top_k = top_k - self.num_experts = num_experts assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results @@ -344,7 +371,6 @@ class FusedMoE(torch.nn.Module): self.use_presharded_weights = use_presharded_weights self.inplace = inplace self.no_combine = no_combine - self.local_num_experts = num_experts if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( @@ -352,11 +378,13 @@ class FusedMoE(torch.nn.Module): ) else: self.quant_method = quant_config.get_quant_method(self, prefix) + if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod": + self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe assert self.quant_method is not None self.quant_method.create_weights( layer=self, - num_experts=num_experts, + num_experts=self.local_num_experts, hidden_size=hidden_size, # FIXME: figure out which intermediate_size to use intermediate_size=self.intermediate_size_per_partition, @@ -450,12 +478,15 @@ class FusedMoE(torch.nn.Module): # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. - if shard_id == "w1": - expert_data = expert_data.narrow(shard_dim, 0, shard_size) # w3, up_proj: Load into second logical weight of w13. + # trtllm cutlass kernel assumes differently + assert shard_id in ("w1", "w3") + switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False) + if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"): + start = shard_size else: - assert shard_id == "w3" - expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + start = 0 + expert_data = expert_data.narrow(shard_dim, start, shard_size) expert_data.copy_(loaded_weight) def _load_w2( @@ -509,6 +540,11 @@ class FusedMoE(torch.nn.Module): assert shard_id in ("w1", "w3") expert_data.copy_(loaded_weight) + def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: + if self.expert_map is None: + return expert_id + return self.expert_map[expert_id].item() + def weight_loader( self, param: torch.nn.Parameter, @@ -517,6 +553,13 @@ class FusedMoE(torch.nn.Module): shard_id: str, expert_id: int, ) -> None: + expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) + if expert_id == -1: + return + + # TP rank is set to 0 if EP is enabled + tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank() + # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality @@ -541,7 +584,6 @@ class FusedMoE(torch.nn.Module): SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} expert_data = param.data[expert_id] - tp_rank = get_tensor_model_parallel_rank() # is_transposed: if the dim to shard the weight # should be flipped. Required by GPTQ, compressed-tensors @@ -549,7 +591,7 @@ class FusedMoE(torch.nn.Module): is_transposed = getattr(param, "is_transposed", False) shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] if is_transposed: - shard_dim = ~shard_dim + shard_dim = int(not shard_dim) # Case input scale: input_scale loading is only supported for fp8 if "input_scale" in weight_name: @@ -690,9 +732,19 @@ class FusedMoE(torch.nn.Module): activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, routed_scaling_factor=self.routed_scaling_factor, + **( + dict( + tp_rank=self.tp_rank, + tp_size=self.tp_size, + ep_rank=self.ep_rank, + ep_size=self.ep_size, + ) + if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod" + else {} + ), ) - if self.reduce_results and self.tp_size > 1: + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index fed4d52dc..913a5bb99 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -29,11 +29,17 @@ from sglang.srt.layers.quantization.utils import ( requantize_with_max_scale, ) from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.utils import is_cuda +from sglang.srt.utils import is_cuda, next_power_of_2 if is_cuda(): from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant +try: + from flashinfer import fp4_quantize as fp4_quantize + from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe +except ImportError: + flashinfer_cutlass_fused_moe = None + # Initialize logger for the module logger = logging.getLogger(__name__) @@ -429,6 +435,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase): layer.alpha = Parameter( layer.input_scale * layer.weight_scale_2, requires_grad=False ) + layer.input_scale_inv = Parameter( + (1 / input_scale_2).to(torch.float32), requires_grad=False + ) # Pad and blockwise interleave weight_scale scales = layer.weight_scale @@ -467,7 +476,7 @@ class ModelOptFp4LinearMethod(LinearMethodBase): output_shape = [x_m, w_n] # Quantize BF16 or FP16 to (FP4 and interleaved block scale) - x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale) + x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv) assert x_fp4.dtype == torch.uint8 assert x_scale_interleaved.dtype == torch.float8_e4m3fn @@ -521,6 +530,7 @@ class ModelOptNvFp4FusedMoEMethod: " quantization. Please use Blackwell and" " above." ) + self.enable_flashinfer_moe = False def create_weights( self, @@ -674,7 +684,10 @@ class ModelOptNvFp4FusedMoEMethod: w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) - w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) + if self.enable_flashinfer_moe: + w13_input_scale = layer.w13_input_scale.max().to(torch.float32) + else: + w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) layer.g1_alphas = Parameter( (w13_input_scale * w13_weight_scale_2).to(torch.float32), requires_grad=False, @@ -700,14 +713,19 @@ class ModelOptNvFp4FusedMoEMethod: layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) # GEMM 2 + if self.enable_flashinfer_moe: + w2_input_scale = layer.w2_input_scale.max().to(torch.float32) + else: + w2_input_scale = layer.w2_input_scale + layer.g2_alphas = Parameter( - (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), + (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), requires_grad=False, ) # This is for quantization, so we need to invert it. layer.w2_input_scale_quant = Parameter( - (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False + (1 / w2_input_scale).to(torch.float32), requires_grad=False ) assert ( @@ -727,11 +745,16 @@ class ModelOptNvFp4FusedMoEMethod: layer.cutlass_moe_params = CutlassMoEParams( CutlassMoEType.BlockscaledFP4, device, - num_experts=layer.num_experts, + num_experts=layer.num_experts, # global num experts intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n hidden_size=layer.w13_weight.shape[2] * 2, ) # k + @property + def load_up_proj_weight_first(self) -> bool: + # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13 + return self.enable_flashinfer_moe + def apply( self, layer: torch.nn.Module, @@ -750,11 +773,13 @@ class ModelOptNvFp4FusedMoEMethod: inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, + ep_rank: Optional[int] = None, + ep_size: Optional[int] = None, + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." - - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.topk import select_experts topk_weights, topk_ids = select_experts( @@ -771,6 +796,35 @@ class ModelOptNvFp4FusedMoEMethod: routed_scaling_factor=routed_scaling_factor, ) + if self.enable_flashinfer_moe: + assert ( + not apply_router_weight_on_input + ), "apply_router_weight_on_input is not supported for Flashinfer" + # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision + # and fp4 quantized weights loaded from the checkpoint + output = flashinfer_cutlass_fused_moe( + x, + topk_ids.to(torch.int), + topk_weights, + layer.w13_weight.view(torch.long), + layer.w2_weight.view(torch.long), + x.dtype, + quant_scales=[ + layer.w13_input_scale_quant, + layer.w13_blockscale_swizzled.view(torch.int32), + layer.g1_alphas, + layer.w2_input_scale_quant, + layer.w2_blockscale_swizzled.view(torch.int32), + layer.g2_alphas, + ], + ep_size=ep_size, + ep_rank=ep_rank, + tp_size=tp_size, + tp_rank=tp_rank, + tune_max_num_tokens=next_power_of_2(x.shape[0]), + ) + return output[0] + from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 return cutlass_moe_fp4( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3e9039e8c..035b11ca4 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -86,6 +86,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "enable_deepep_moe", "deepep_mode", "enable_ep_moe", + "enable_flashinfer_moe", "moe_dense_tp_size", "ep_dispatch_algorithm", "deepep_config", diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0993d9682..864d27cc7 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -226,6 +226,7 @@ class DeepseekV2MoE(nn.Module): layer_id: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -238,6 +239,7 @@ class DeepseekV2MoE(nn.Module): ) self.config = config self.layer_id = layer_id + self.alt_stream = alt_stream if self.tp_size > config.n_routed_experts: raise ValueError( @@ -275,6 +277,15 @@ class DeepseekV2MoE(nn.Module): if global_server_args_dict["enable_deepep_moe"] else {} ), + # Additional args for FusedMoE + **( + dict( + enable_flashinfer_moe=True, + enable_ep_moe=global_server_args_dict["enable_ep_moe"], + ) + if global_server_args_dict["enable_flashinfer_moe"] + else {} + ), ) if config.n_shared_experts is not None and self.num_fused_shared_experts == 0: @@ -338,10 +349,36 @@ class DeepseekV2MoE(nn.Module): self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None ) -> torch.Tensor: if not self._enable_deepep_moe: - return self.forward_normal(hidden_states) + DUAL_STREAM_TOKEN_THRESHOLD = 1024 + if ( + self.alt_stream is not None + and self.num_fused_shared_experts == 0 + and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD + ): + return self.forward_normal_dual_stream(hidden_states) + else: + return self.forward_normal(hidden_states) else: return self.forward_deepep(hidden_states, forward_batch) + def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor: + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + shared_output = self._forward_shared_experts(hidden_states) + with torch.cuda.stream(self.alt_stream): + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + if not _is_cuda: + final_hidden_states *= self.routed_scaling_factor + current_stream.wait_stream(self.alt_stream) + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + return final_hidden_states + def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self._forward_shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) @@ -1446,6 +1483,7 @@ class DeepseekV2DecoderLayer(nn.Module): quant_config=quant_config, prefix=add_prefix("mlp", prefix), layer_id=self.layer_id, + alt_stream=alt_stream, ) else: if enable_moe_dense_fully_dp(): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f4d110d16..a6a2a0b90 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -152,6 +152,7 @@ class ServerArgs: ep_size: int = 1 enable_ep_moe: bool = False enable_deepep_moe: bool = False + enable_flashinfer_moe: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" ep_num_redundant_experts: int = 0 ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None @@ -244,7 +245,15 @@ class ServerArgs: logger.warning( f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." ) - + if self.enable_flashinfer_moe: + assert ( + self.quantization == "modelopt_fp4" + ), "modelopt_fp4 quantization is required for Flashinfer MOE" + os.environ["TRTLLM_ENABLE_PDL"] = "1" + self.disable_shared_experts_fusion = True + logger.warning( + f"Flashinfer MoE is enabled. Shared expert fusion is disabled." + ) # Set missing default values if self.tokenizer_path is None: self.tokenizer_path = self.model_path @@ -1162,6 +1171,11 @@ class ServerArgs: action="store_true", help="Enabling expert parallelism for moe. The ep size is equal to the tp size.", ) + parser.add_argument( + "--enable-flashinfer-moe", + action="store_true", + help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe", + ) parser.add_argument( "--enable-deepep-moe", action="store_true",