diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 83f74fb27..80fbadd57 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -47,12 +47,17 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
+    next_power_of_2,
 )
 
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+use_flashinfer_trtllm_moe = (
+    global_server_args_dict["enable_flashinfer_trtllm_moe"]
+    and global_server_args_dict["enable_ep_moe"]
+)
 
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
@@ -64,6 +69,13 @@ if _use_aiter:
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
 
+if use_flashinfer_trtllm_moe:
+    try:
+        import flashinfer.fused_moe as fi_fused_moe
+    except ImportError:
+        fi_fused_moe = None
+        use_flashinfer_trtllm_moe = False
+
 logger = logging.getLogger(__name__)
 
 
@@ -140,6 +152,16 @@ class GroupedGemmRunner(torch.nn.Module):
         return c
 
 
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
 class EPMoE(torch.nn.Module):
     """
     MoE Expert Parallel Impl
@@ -776,14 +798,20 @@ class EPMoE(torch.nn.Module):
             )
             return
 
-        if shard_id == "w2":
+        # Flashinfer assumes w31 format for w13_weight. Same for the scales.
+        if use_flashinfer_trtllm_moe:
+            actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
+        else:
+            actual_shard_id = shard_id
+
+        if actual_shard_id == "w2":
             param.data[expert_id] = loaded_weight
-        elif shard_id == "w1":
+        elif actual_shard_id == "w1":
             param.data[expert_id][: self.intermediate_size, :] = loaded_weight
-        elif shard_id == "w3":
+        elif actual_shard_id == "w3":
             param.data[expert_id][self.intermediate_size :, :] = loaded_weight
         else:
-            raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")
+            raise ValueError(f"Expected shard_id w1,w2 or w3 but got {actual_shard_id}")
 
     def _load_fp8_scale(
         self,
@@ -820,12 +848,18 @@ class EPMoE(torch.nn.Module):
         # Weight scales
         elif "weight_scale" in weight_name:
             if self.use_block_quant:
+                if use_flashinfer_trtllm_moe:
+                    actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
+                else:
+                    actual_shard_id = shard_id
+
                 block_n, block_k = self.block_shape[0], self.block_shape[1]
-                if shard_id == "w1":
+
+                if actual_shard_id == "w1":
                     param_data[expert_id][
                         : (self.intermediate_size + block_n - 1) // block_n, :
                     ] = loaded_weight
-                elif shard_id == "w3":
+                elif actual_shard_id == "w3":
                     param_data[expert_id][
                         (self.intermediate_size + block_n - 1) // block_n :, :
                     ] = loaded_weight
@@ -1315,12 +1349,73 @@ class DeepEPMoE(EPMoE):
         return down_output
 
 
+class FlashInferEPMoE(EPMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
+        # NOTE: scales of hidden states have to be transposed!
+        a_sf_t = a_sf.t().contiguous()
+        assert fi_fused_moe is not None
+        return fi_fused_moe.trtllm_fp8_block_scale_moe(
+            routing_logits=router_logits.to(torch.float32),
+            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            hidden_states=a_q,
+            hidden_states_scale=a_sf_t,
+            gemm1_weights=self.w13_weight,
+            gemm1_weights_scale=self.w13_weight_scale_inv,
+            gemm2_weights=self.w2_weight,
+            gemm2_weights_scale=self.w2_weight_scale_inv,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            n_group=self.num_expert_group,
+            topk_group=self.topk_group,
+            intermediate_size=self.w2_weight.shape[2],
+            local_expert_offset=self.start_expert_id,
+            local_num_experts=self.num_experts_per_partition,
+            routed_scaling_factor=self.routed_scaling_factor,
+            tile_tokens_dim=_get_tile_tokens_dim(
+                hidden_states.shape[0], self.top_k, self.num_experts
+            ),
+            routing_method_type=2,  # DeepSeek-styled routing method
+            use_shuffled_weight=False,
+        )
+
+
 def get_moe_impl_class():
     if global_server_args_dict["enable_deepep_moe"]:
         return DeepEPMoE
-    if global_server_args_dict["enable_flashinfer_moe"]:
+    if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
+    if use_flashinfer_trtllm_moe:
+        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+        return FlashInferEPMoE
     if global_server_args_dict["enable_ep_moe"]:
         return EPMoE
     return FusedMoE
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 0c3cb0422..5983a6beb 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -75,7 +75,7 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
-        enable_flashinfer_moe: Optional[bool] = False,
+        enable_flashinfer_cutlass_moe: Optional[bool] = False,
         enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()
@@ -92,16 +92,16 @@ class FusedMoE(torch.nn.Module):
         self.num_experts = num_experts
         self.expert_map = None
 
-        if enable_flashinfer_moe and quant_config is None:
+        if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
-            enable_flashinfer_moe = False
+            enable_flashinfer_cutlass_moe = False
             enable_ep_moe = False
 
-        self.enable_flashinfer_moe = enable_flashinfer_moe
+        self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
         if enable_ep_moe:
             assert (
-                self.enable_flashinfer_moe
-            ), "FusedMoE only supports EP with --enable-flashinfer-moe"
+                self.enable_flashinfer_cutlass_moe
+            ), "FusedMoE only supports EP with --enable-flashinfer-cutlass-moe"
             self.ep_size = self.tp_size
             self.ep_rank = self.tp_rank
             self.tp_size = 1
@@ -141,7 +141,9 @@ class FusedMoE(torch.nn.Module):
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
             if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
+                self.quant_method.enable_flashinfer_cutlass_moe = (
+                    self.enable_flashinfer_cutlass_moe
+                )
         assert self.quant_method is not None
 
         self.quant_config = quant_config
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index 9087f79b0..223d7d43f 100644
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -711,7 +711,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " quantization. Please use Blackwell and"
                 " above."
             )
-        self.enable_flashinfer_moe = False
+        self.enable_flashinfer_cutlass_moe = False
 
     def create_weights(
         self,
@@ -865,7 +865,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
 
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
@@ -894,7 +894,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
 
         # GEMM 2
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
         else:
             w2_input_scale = layer.w2_input_scale
@@ -934,7 +934,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     @property
    def load_up_proj_weight_first(self) -> bool:
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
-        return self.enable_flashinfer_moe
+        return self.enable_flashinfer_cutlass_moe
 
     def apply(
         self,
@@ -954,7 +954,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
 
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             assert (
                 not apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 283da3394..5d174db77 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -88,7 +88,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_deepep_moe",
     "deepep_mode",
     "enable_ep_moe",
-    "enable_flashinfer_moe",
+    "enable_flashinfer_cutlass_moe",
+    "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 7c627bc09..be6ef9bf3 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -56,7 +56,11 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import (
+    DeepEPMoE,
+    get_moe_impl_class,
+    use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
@@ -302,15 +306,19 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )
 
-        self.topk = TopK(
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-            renormalize=config.norm_topk_prob,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            routed_scaling_factor=self.routed_scaling_factor,
+        self.topk = (
+            TopK(
+                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+                renormalize=config.norm_topk_prob,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
+            )
+            if not use_flashinfer_trtllm_moe
+            else None
         )
 
         self.experts = get_moe_impl_class()(
@@ -332,10 +340,22 @@ class DeepseekV2MoE(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-                    enable_flashinfer_moe=True,
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["enable_flashinfer_moe"]
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
+                else {}
+            ),
+            **(
+                dict(
+                    renormalize=config.norm_topk_prob,
+                    use_grouped_topk=True,
+                    num_expert_group=config.n_group,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    topk_group=config.topk_group,
+                    correction_bias=self.gate.e_score_correction_bias,
+                )
+                if use_flashinfer_trtllm_moe
                 else {}
             ),
         )
@@ -455,10 +475,12 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            topk_output = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(
-                hidden_states=hidden_states, topk_output=topk_output
-            )
+            kwargs = {"hidden_states": hidden_states}
+            if self.topk is not None:
+                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+            else:
+                kwargs["router_logits"] = router_logits
+            final_hidden_states = self.experts(**kwargs)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
@@ -478,10 +500,12 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, topk_output=topk_output
-        )
+        kwargs = {"hidden_states": hidden_states}
+        if self.topk is not None:
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        else:
+            kwargs["router_logits"] = router_logits
+        final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 291678652..716e6c096 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -147,10 +147,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-                    enable_flashinfer_moe=True,
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["enable_flashinfer_moe"]
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
             ),
         )
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index 6b8655459..01235f7ac 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -120,10 +120,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             # Additional args for FusedMoE
             **(
                 dict(
-                    enable_flashinfer_moe=True,
+                    enable_flashinfer_cutlass_moe=True,
                     enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
-                if global_server_args_dict["enable_flashinfer_moe"]
+                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
             ),
         )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 27091dc23..f1497d2a6 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -169,7 +169,8 @@ class ServerArgs:
     ep_size: int = 1
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
-    enable_flashinfer_moe: bool = False
+    enable_flashinfer_cutlass_moe: bool = False
+    enable_flashinfer_trtllm_moe: bool = False
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
@@ -428,12 +429,16 @@ class ServerArgs:
             ), "Please enable dp attention when setting enable_dp_lm_head. "
 
         # MoE kernel
-        if self.enable_flashinfer_moe:
+        if self.enable_flashinfer_cutlass_moe:
             assert (
                 self.quantization == "modelopt_fp4"
             ), "modelopt_fp4 quantization is required for Flashinfer MOE"
             os.environ["TRTLLM_ENABLE_PDL"] = "1"
 
+        if self.enable_flashinfer_trtllm_moe:
+            assert self.enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MOE"
+            logger.warning(f"Flashinfer TRTLLM MoE is enabled.")
+
         # DeepEP MoE
         if self.enable_deepep_moe:
             if self.deepep_mode == "normal":
@@ -1293,10 +1298,15 @@ class ServerArgs:
            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
         )
         parser.add_argument(
-            "--enable-flashinfer-moe",
+            "--enable-flashinfer-cutlass-moe",
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-trtllm-moe",
+            action="store_true",
+            help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
+        )
         parser.add_argument(
             "--enable-flashinfer-allreduce-fusion",
             action="store_true",
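
For reference, a standalone sketch of the tile-size heuristic that _get_tile_tokens_dim (added above) implements. next_power_of_2 is inlined here for illustration only; the patch imports it from sglang.srt.utils. The batch shapes in the asserts are made-up examples, not values taken from the patch.

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n; returns 1 for n <= 1 (illustrative inline version).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Tokens per expert assuming perfectly uniform routing.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Round up to a power of two, then clamp to the 8-64 range the kernel supports.
    return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)


# Small decode-like batch: 16 tokens, top_k=8, 256 experts -> 0 tokens/expert -> clamped up to 8.
assert get_tile_tokens_dim(16, 8, 256) == 8
# Large prefill-like batch: 8192 tokens, top_k=8, 256 experts -> 256 tokens/expert -> clamped down to 64.
assert get_tile_tokens_dim(8192, 8, 256) == 64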