From eec9e471cad4aaf047f9f996fb069521ab9f50c7 Mon Sep 17 00:00:00 2001
From: jiahanc <173873397+jiahanc@users.noreply.github.com>
Date: Wed, 22 Oct 2025 13:11:16 -0700
Subject: [PATCH] [NVIDIA] Update to leverage flashinfer trtllm FP4 MOE
 throughput kernel (#11563)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
---
 .../srt/layers/moe/fused_moe_triton/layer.py | 15 +----------
 .../sglang/srt/layers/quantization/mxfp4.py  | 27 +------------------
 scripts/ci/ci_install_dependency.sh          |  4 +--
 3 files changed, 4 insertions(+), 42 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 8ddad6096..702c1a6d9 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -49,7 +49,6 @@ from sglang.srt.utils import (
     is_cpu,
     is_flashinfer_available,
     is_hip,
-    next_power_of_2,
     round_up,
 )
 
@@ -72,16 +71,6 @@ if should_use_flashinfer_trtllm_moe():
 logger = logging.getLogger(__name__)
 
 
-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
     a2a_backend = get_moe_a2a_backend()
     if a2a_backend.is_none():
@@ -1080,9 +1069,7 @@ class FlashInferFP4MoE(FusedMoE):
             local_expert_offset=self.moe_ep_rank * self.num_local_experts,
             local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
-            tile_tokens_dim=_get_tile_tokens_dim(
-                hidden_states.shape[0], topk_config.top_k, self.num_local_experts
-            ),
+            tile_tokens_dim=None,
             routing_method_type=RoutingMethodType.DeepSeekV3,
             do_finalize=True,
         )[0]
diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py
index 76757e501..014500d93 100644
--- a/python/sglang/srt/layers/quantization/mxfp4.py
+++ b/python/sglang/srt/layers/quantization/mxfp4.py
@@ -41,7 +41,6 @@ from sglang.srt.utils import (
     is_triton_kernels_available,
     log_info_on_rank0,
     mxfp_supported,
-    next_power_of_2,
     round_up,
     set_weight_attrs,
 )
@@ -597,30 +596,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
         torch.cuda.empty_cache()
 
-    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
-        # Number of tokens in the input tensor.
-        num_tokens = x.shape[0]
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # - 1.0 means perfect expert distribution.
-        # - > 1.0 means some experts have more
-        #   tokens than the perfect distribution.
-        # - < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert
-        # assuming perfect distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile
-        # as it's the range supported by the kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-
-        return tile_tokens_dim
-
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
@@ -696,7 +671,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
             layer.num_local_experts,  # local num experts
             None,
-            self._get_tile_tokens_dim(x, top_k),
+            None,  # tile_tokens_dim
             1,  # routing_method_type, renormalize
             True,  # do finalize
         )[0]
diff --git a/scripts/ci/ci_install_dependency.sh b/scripts/ci/ci_install_dependency.sh
index ee72f77d9..8c85a68f8 100755
--- a/scripts/ci/ci_install_dependency.sh
+++ b/scripts/ci/ci_install_dependency.sh
@@ -45,8 +45,8 @@ else
     # Install the main package without deps
     $PIP_CMD install -e "python[dev]" --no-deps $PIP_INSTALL_SUFFIX --force-reinstall
 
-    # Install flashinfer-python 0.4.0 dependency that requires prerelease (This should be removed when flashinfer fixes this issue)
-    $PIP_CMD install flashinfer-python==0.4.0 --prerelease=allow $PIP_INSTALL_SUFFIX
+    # Install flashinfer-python 0.4.1 dependency that requires prerelease (This should be removed when flashinfer fixes this issue)
+    $PIP_CMD install flashinfer-python==0.4.1 --prerelease=allow $PIP_INSTALL_SUFFIX
 
     # Install the main package
     $PIP_CMD install -e "python[dev]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX --upgrade
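
For reference, the tile-size heuristic this patch deletes (both helpers computed the same clamp; only the mxfp4 variant applied an imbalance factor), consolidated into a standalone sketch. The next_power_of_2 stand-in below is an assumption written to match the semantics implied by the deleted comments, not the actual sglang.srt.utils implementation; after this change the flashinfer trtllm kernels receive tile_tokens_dim=None and are left to pick the tile size themselves.

def next_power_of_2(n: int) -> int:
    # Stand-in assumed to match sglang.srt.utils.next_power_of_2:
    # smallest power of two >= n (1 for n <= 0).
    return 1 if n <= 0 else 1 << (n - 1).bit_length()


def get_tile_tokens_dim(
    num_tokens: int, top_k: int, num_experts: int, imbalance_factor: float = 1.0
) -> int:
    # Tokens per expert assuming a perfect routing distribution.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Allow for routing imbalance (1.0 = perfect; the mxfp4 variant used 1.3).
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    # Pad to the next power of 2, then clamp to the 8-64 range the kernel supports.
    return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)


# Example: 512 tokens, top_k=8, 128 experts -> 32 tokens/expert -> tile of 32;
# with the mxfp4 imbalance factor of 1.3 the estimate becomes 41 -> tile of 64.
assert get_tile_tokens_dim(512, 8, 128) == 32
assert get_tile_tokens_dim(512, 8, 128, imbalance_factor=1.3) == 64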