diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index d2faf12cf..6f896e297 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -25,14 +25,22 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_triton_kernel,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
+    FlashInferFusedMoE,
+    FusedMoE,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8 import (
+    Fp8Config,
+    Fp8MoEMethod,
+    get_tile_tokens_dim,
+)
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
@@ -49,7 +57,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
-    next_power_of_2,
 )

 if TYPE_CHECKING:
@@ -63,10 +70,7 @@ _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-use_flashinfer_trtllm_moe = (
-    global_server_args_dict["enable_flashinfer_trtllm_moe"]
-    and global_server_args_dict["enable_ep_moe"]
-)
+

 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
@@ -76,26 +80,9 @@ if _use_aiter:
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight

-if use_flashinfer_trtllm_moe:
-    try:
-        import flashinfer.fused_moe as fi_fused_moe
-    except ImportError:
-        fi_fused_moe = None
-        use_flashinfer_trtllm_moe = False
-
 logger = logging.getLogger(__name__)


-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
@@ -731,10 +718,10 @@ class FlashInferEPMoE(EPMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe = use_flashinfer_trtllm_moe
+        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()

     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert use_flashinfer_trtllm_moe
+        assert self.use_flashinfer_trtllm_moe
         assert (
             self.activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
@@ -747,8 +734,9 @@ class FlashInferEPMoE(EPMoE):
         a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
         # NOTE: scales of hidden states have to be transposed!
         a_sf_t = a_sf.t().contiguous()
-        assert fi_fused_moe is not None
-        return fi_fused_moe.trtllm_fp8_block_scale_moe(
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
             routing_bias=self.correction_bias.to(hidden_states.dtype),
             hidden_states=a_q,
@@ -765,7 +753,7 @@ class FlashInferEPMoE(EPMoE):
             local_expert_offset=self.start_expert_id,
             local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=_get_tile_tokens_dim(
+            tile_tokens_dim=get_tile_tokens_dim(
                 hidden_states.shape[0], self.top_k, self.num_experts
             ),
             routing_method_type=2,  # DeepSeek-styled routing method
@@ -779,9 +767,6 @@ def get_moe_impl_class():
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
-    if use_flashinfer_trtllm_moe:
-        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
-        return FlashInferEPMoE
     if global_server_args_dict["enable_ep_moe"]:
-        return EPMoE
-    return FusedMoE
+        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
+    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 88e150e4d..e3a16669b 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -1,10 +1,13 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

+import importlib.util
 import logging
 from enum import Enum
+from functools import lru_cache
 from typing import List, Optional, Tuple

 import torch
+from packaging import version as pkg_version

 from sglang.srt.distributed import (
     get_moe_expert_parallel_rank,
@@ -33,6 +36,15 @@ _is_cpu = is_cpu()
 logger = logging.getLogger(__name__)


+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+
+
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
@@ -455,7 +467,7 @@ class FusedMoE(torch.nn.Module):
             )

         # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-        if getattr(self, "use_flashinfer_trtllm_moe", False):
+        if should_use_flashinfer_trtllm_moe():
             shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]

         WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
@@ -687,3 +699,44 @@ class FusedMoE(torch.nn.Module):
             for expert_id in range(num_experts)
             for shard_id in ["w1", "w2", "w3"]
         ]
+
+
+class FlashInferFusedMoE(FusedMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply_with_router_logits(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            activation=self.activation,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )
+
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 49a3af57f..0578ee60c 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -72,6 +72,7 @@ from sglang.srt.utils import (
     is_hip,
     is_npu,
     log_info_on_rank0,
+    next_power_of_2,
     print_warning_once,
     set_weight_attrs,
     use_intel_amx_backend,
@@ -490,6 +491,16 @@ class Fp8LinearMethod(LinearMethodBase):
         )


+def get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
 class Fp8MoEMethod(FusedMoEMethodBase):
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
@@ -1076,6 +1087,47 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             routed_scaling_factor=routed_scaling_factor,
         )

+    def apply_with_router_logits(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        assert (
+            activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
+        # NOTE: scales of hidden states have to be transposed!
+        a_sf_t = a_sf.t().contiguous()
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        return trtllm_fp8_block_scale_moe(
+            routing_logits=router_logits.to(torch.float32),
+            routing_bias=layer.correction_bias.to(x.dtype),
+            hidden_states=a_q,
+            hidden_states_scale=a_sf_t,
+            gemm1_weights=layer.w13_weight,
+            gemm1_weights_scale=layer.w13_weight_scale_inv,
+            gemm2_weights=layer.w2_weight,
+            gemm2_weights_scale=layer.w2_weight_scale_inv,
+            num_experts=layer.num_experts,
+            top_k=layer.top_k,
+            n_group=layer.num_expert_group,
+            topk_group=layer.topk_group,
+            intermediate_size=layer.w2_weight.shape[2],
+            local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
+            local_num_experts=layer.num_local_experts,
+            routed_scaling_factor=routed_scaling_factor,
+            tile_tokens_dim=get_tile_tokens_dim(
+                x.shape[0], layer.top_k, layer.num_experts
+            ),
+            routing_method_type=2,  # DeepSeek-styled routing method
+            use_shuffled_weight=False,
+        )
+
     def maybe_apply_hip_fused_experts(
         self,
         layer: torch.nn.Module,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index bd0e35a2e..5ed19ed86 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -59,7 +59,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import (
     DeepEPMoE,
     get_moe_impl_class,
-    use_flashinfer_trtllm_moe,
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
@@ -317,7 +317,7 @@ class DeepseekV2MoE(nn.Module):
                 correction_bias=self.gate.e_score_correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
             )
-            if not use_flashinfer_trtllm_moe
+            if not should_use_flashinfer_trtllm_moe()
             else None
         )

@@ -352,11 +352,10 @@ class DeepseekV2MoE(nn.Module):
                     renormalize=config.norm_topk_prob,
                     use_grouped_topk=True,
                     num_expert_group=config.n_group,
-                    num_fused_shared_experts=self.num_fused_shared_experts,
                     topk_group=config.topk_group,
                     correction_bias=self.gate.e_score_correction_bias,
                 )
-                if use_flashinfer_trtllm_moe
+                if should_use_flashinfer_trtllm_moe()
                 else {}
             ),
         )
diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py
index 6031e7600..645ecf344 100644
--- a/python/sglang/srt/models/glm4_moe.py
+++ b/python/sglang/srt/models/glm4_moe.py
@@ -52,7 +52,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import (
     DeepEPMoE,
     get_moe_impl_class,
-    use_flashinfer_trtllm_moe,
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -426,7 +426,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 correction_bias=self.gate.e_score_correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
             )
-            if not use_flashinfer_trtllm_moe
+            if not should_use_flashinfer_trtllm_moe()
             else None
         )

@@ -465,7 +465,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                     topk_group=config.topk_group,
                     correction_bias=self.gate.e_score_correction_bias,
                 )
-                if use_flashinfer_trtllm_moe
+                if should_use_flashinfer_trtllm_moe()
                 else {}
             ),
         )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index c4a520f1c..2927a7071 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -460,10 +460,6 @@ class ServerArgs:
                 f"Flashinfer cutlass MoE and EP MoE are enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )

-        if self.enable_flashinfer_trtllm_moe:
-            assert self.enable_ep_moe, "EP MoE is required for Flashinfer TRTLLM MOE"
-            logger.warning(f"Flashinfer TRTLLM MoE is enabled.")
-
         # DeepEP MoE
         if self.enable_deepep_moe:
             if self.deepep_mode == "normal":
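For reference, the tile-size heuristic that this diff moves from ep_moe/layer.py (_get_tile_tokens_dim) into fp8.py (get_tile_tokens_dim) sizes the kernel's CTA tile from the expected per-expert token count. Below is a minimal standalone sketch of the same arithmetic; next_power_of_2 here is a local stand-in for sglang.srt.utils.next_power_of_2, and the example numbers are illustrative only.

    def next_power_of_2(n: int) -> int:
        # Smallest power of two >= n (stand-in for sglang.srt.utils.next_power_of_2).
        return 1 if n <= 1 else 1 << (n - 1).bit_length()


    def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
        # Guess tokens per expert assuming a perfectly balanced router.
        num_tokens_per_expert = (num_tokens * top_k) // num_experts
        # Pad to the next power of 2, then clamp to the 8-64 range the kernel supports.
        return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)


    if __name__ == "__main__":
        print(get_tile_tokens_dim(512, 8, 256))    # 16: 512 * 8 / 256 = 16 tokens per expert
        print(get_tile_tokens_dim(16, 8, 256))     # 8: tiny batches are clamped up to 8
        print(get_tile_tokens_dim(65536, 8, 256))  # 64: large batches are capped at 64

Hosting the helper in fp8.py lets both call sites in the diff (FlashInferEPMoE.forward and Fp8MoEMethod.apply_with_router_logits) share it, and ep_moe/layer.py no longer needs to import next_power_of_2.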
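The new should_use_flashinfer_trtllm_moe() gate replaces the module-level use_flashinfer_trtllm_moe flag: the TRTLLM path is taken when enable_flashinfer_trtllm_moe is set and flashinfer is either absent (the failure then surfaces at the deferred import) or at least version 0.2.9rc1. A standalone sketch of the same check, with a plain dict standing in for global_server_args_dict:

    import importlib.util
    from functools import lru_cache

    from packaging import version as pkg_version

    # Stand-in for sglang's global_server_args_dict in this sketch.
    server_args = {"enable_flashinfer_trtllm_moe": True}


    @lru_cache(maxsize=1)
    def should_use_flashinfer_trtllm_moe() -> bool:
        if not server_args["enable_flashinfer_trtllm_moe"]:
            return False
        if importlib.util.find_spec("flashinfer") is None:
            # Same behavior as the diff: report True and let the later
            # `from flashinfer.fused_moe import ...` raise if the package is missing.
            return True
        import flashinfer

        return pkg_version.parse(flashinfer.__version__) >= pkg_version.parse("0.2.9rc1")

Because of lru_cache(maxsize=1), the decision is made once per process; changing the server arg afterwards has no effect unless the cache is cleared, which mirrors the old behavior of computing the flag at import time.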
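Both call sites quantize the activations per token group and transpose the scale tensor before calling trtllm_fp8_block_scale_moe, per the "scales of hidden states have to be transposed" NOTE. Assuming the quantizer returns scales shaped (num_tokens, hidden_size // group_size), which is what the transpose implies, the .contiguous() matters because .t() alone only produces a strided view:

    import torch

    num_tokens, hidden_size, group_size = 6, 512, 128
    # Stand-in for the scales returned by per_token_group_quant_fp8 /
    # sglang_per_token_group_quant_fp8; real values come from the quantizer.
    a_sf = torch.rand(num_tokens, hidden_size // group_size, dtype=torch.float32)

    a_sf_t = a_sf.t().contiguous()
    print(a_sf.shape)                # torch.Size([6, 4])
    print(a_sf_t.shape)              # torch.Size([4, 6])
    print(a_sf.t().is_contiguous())  # False: a bare .t() is just a strided view
    print(a_sf_t.is_contiguous())    # True: dense layout handed to the kernel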