From 915140fd18c9ff4193e994e6d756ea762a52240a Mon Sep 17 00:00:00 2001 From: azhurkevich <101208641+azhurkevich@users.noreply.github.com> Date: Mon, 4 Aug 2025 03:10:02 -0700 Subject: [PATCH] [NVIDIA] Add Low Latency NVFP4 decode kernels from Flashinfer (#8552) Co-authored-by: Cheng Wan --- python/sglang/srt/layers/moe/ep_moe/layer.py | 25 +- .../srt/layers/moe/fused_moe_triton/layer.py | 189 +++++++++- python/sglang/srt/layers/moe/utils.py | 16 + .../srt/layers/quantization/modelopt_quant.py | 327 ++++++++++++++---- python/sglang/srt/managers/schedule_batch.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 49 +-- python/sglang/srt/models/glm4_moe.py | 6 +- python/sglang/srt/server_args.py | 7 + 8 files changed, 504 insertions(+), 117 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 66fbb36ea..ac5371871 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -14,13 +14,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( silu_and_mul_masked_post_quant_fwd, tma_align_input_scale, ) -from sglang.srt.layers.moe.fused_moe_triton.layer import ( - FlashInferFusedMoE, - FusedMoE, - should_use_flashinfer_trtllm_moe, -) +from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE from sglang.srt.layers.moe.topk import TopKOutput -from sglang.srt.layers.moe.utils import DeepEPMode +from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8 import ( @@ -48,7 +44,6 @@ _is_npu = is_npu() _is_fp8_fnuz = is_fp8_fnuz() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip - if not (_is_npu or _is_hip): from sgl_kernel import silu_and_mul @@ -741,6 +736,22 @@ class FlashInferEPMoE(EPMoE): def get_moe_impl_class(): if global_server_args_dict["moe_a2a_backend"].is_deepep(): return DeepEPMoE + + # NEW: Direct FP4 detection (bypasses EP requirements) + # Check for FP4 quantization with TRTLLM flag, regardless of EP + if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False): + try: + # Check the quantization argument directly + quantization = global_server_args_dict.get("quantization") + if quantization == "modelopt_fp4": + from sglang.srt.layers.moe.fused_moe_triton.layer import ( + FlashInferFP4MoE, + ) + + return FlashInferFP4MoE + except: + pass + if global_server_args_dict["enable_flashinfer_cutlass_moe"]: return FusedMoE if get_moe_expert_parallel_world_size() > 1: diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index d0a9ed132..c30535d7f 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1,13 +1,14 @@ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py -import importlib.util +import datetime +import glob import logging +import os +import sys from enum import Enum -from functools import lru_cache from typing import List, Optional, Tuple import torch -from packaging import version as pkg_version from sglang.srt.distributed import ( get_moe_expert_parallel_rank, @@ -22,6 +23,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( ) from 
sglang.srt.eplb.expert_location import get_global_expert_location_metadata from sglang.srt.layers.moe.topk import StandardTopKOutput +from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -29,22 +31,58 @@ from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight -from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip +from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, + is_cpu, + is_flashinfer_available, + is_hip, + next_power_of_2, +) + +if is_flashinfer_available(): + from flashinfer import ( + RoutingMethodType, + fp4_quantize, + reorder_rows_for_gated_act_gemm, + shuffle_matrix_a, + shuffle_matrix_sf_a, + ) _is_hip = is_hip() _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() + +# Try to import FP4 TRTLLM function if flashinfer is available +trtllm_fp4_block_scale_moe = None +if should_use_flashinfer_trtllm_moe(): + try: + from flashinfer.fused_moe import trtllm_fp4_block_scale_moe + except ImportError: + trtllm_fp4_block_scale_moe = None + logger = logging.getLogger(__name__) -@lru_cache(maxsize=1) -def should_use_flashinfer_trtllm_moe(): - return global_server_args_dict["enable_flashinfer_trtllm_moe"] and ( - not importlib.util.find_spec("flashinfer") - or pkg_version.parse(__import__("flashinfer").__version__) - >= pkg_version.parse("0.2.9rc1") - ) +def _is_fp4_quantization_enabled(): + """Check if ModelOpt FP4 quantization is enabled.""" + try: + # Use the same simple check that works for class selection + quantization = global_server_args_dict.get("quantization") + return quantization == "modelopt_fp4" + except: + return False + + +def _get_tile_tokens_dim(num_tokens, top_k, num_experts): + # Guess tokens per expert assuming perfect expert distribution first. + num_tokens_per_expert = (num_tokens * top_k) // num_experts + # And pad the number to the next power of 2. + tile_tokens_dim = next_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. 
+ tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + return tile_tokens_dim class FusedMoeWeightScaleSupported(Enum): @@ -157,10 +195,6 @@ class FusedMoE(torch.nn.Module): ) else: self.quant_method = quant_config.get_quant_method(self, prefix) - if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod": - self.quant_method.enable_flashinfer_cutlass_moe = ( - self.enable_flashinfer_cutlass_moe - ) assert self.quant_method is not None self.quant_config = quant_config @@ -747,7 +781,130 @@ class FlashInferFusedMoE(FusedMoE): routed_scaling_factor=self.routed_scaling_factor, ) - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states + + +class FlashInferFP4MoE(FusedMoE): + """FP4 TRTLLM MoE implementation using FlashInfer.""" + + def __init__(self, *args, **kwargs): + # Extract DeepSeek-specific parameters + renormalize = kwargs.pop("renormalize", True) + num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0) + use_grouped_topk = kwargs.pop("use_grouped_topk", False) + num_expert_group = kwargs.pop("num_expert_group", None) + topk_group = kwargs.pop("topk_group", None) + correction_bias = kwargs.pop("correction_bias", None) + + # Extract additional TopK parameters that were previously extracted in forward + routed_scaling_factor = kwargs.pop("routed_scaling_factor", None) + + super().__init__(*args, **kwargs) + + # Store DeepSeek parameters + self.renormalize = renormalize + self.num_fused_shared_experts = num_fused_shared_experts + self.use_grouped_topk = use_grouped_topk + self.num_expert_group = num_expert_group + self.topk_group = topk_group + self.correction_bias = correction_bias + self.routed_scaling_factor = routed_scaling_factor + + # --------------------------------------------------------------------- + # Helper: quantize hidden states to FP4 each forward pass + # --------------------------------------------------------------------- + def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor): + """ + Quantize hidden states using global scale factor from quantization method. + + Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading. + Only block scales are computed at runtime for efficiency. + + Returns (packed_fp4_uint8, scale_float8_e4m3fn_runtime, global_scale_float32) + """ + + # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8) + # Only the block scales are computed at runtime + hs_fp4_bytes, hs_sf_bytes = fp4_quantize( + hidden_states, + self.w13_input_scale_quant, + 16, # sf_vec_size + False, # use_ue8m0 + False, # is_sf_swizzled_layout + ) + + hs_fp4 = hs_fp4_bytes.reshape( + hidden_states.shape[0], hidden_states.shape[1] // 2 + ) + hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1) + + return hs_fp4, hs_sf + + def forward(self, hidden_states: torch.Tensor, topk_output): + """Forward pass using FP4 TRTLLM kernel. 
+ + Args: + hidden_states: Input tensor + topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode + """ + + # TRTLLM mode expects (TopK_config, router_logits) tuple + if not isinstance(topk_output, tuple) or len(topk_output) != 2: + raise ValueError( + f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}" + ) + + _, router_logits = topk_output + + hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states) + + router_logits = router_logits.to(torch.float32) + + result = trtllm_fp4_block_scale_moe( + routing_logits=router_logits, + routing_bias=self.correction_bias.to(hidden_states.dtype), + hidden_states=hs_fp4, + hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(), + gemm1_weights=self.gemm1_weights_fp4_shuffled.data, + gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view( + torch.float8_e4m3fn + ), + gemm2_weights=self.gemm2_weights_fp4_shuffled.data, + gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view( + torch.float8_e4m3fn + ), + output1_scale_scalar=self.g1_scale_c.data, + output1_scale_gate_scalar=self.g1_alphas.data, + output2_scale_scalar=self.g2_alphas.data, + num_experts=self.num_experts, + top_k=self.top_k, + n_group=self.num_expert_group, + topk_group=self.topk_group, + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.moe_ep_rank * self.num_local_experts, + local_num_experts=self.num_local_experts, + routed_scaling_factor=self.routed_scaling_factor, + tile_tokens_dim=_get_tile_tokens_dim( + hidden_states.shape[0], self.top_k, self.num_local_experts + ), + routing_method_type=RoutingMethodType.DeepSeekV3, + do_finalize=True, + )[0] + + return result + + +def get_fused_moe_impl_class(): + """Factory function to get the appropriate FusedMoE implementation class.""" + if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled(): + # Use FP4 variant when FP4 quantization is enabled + return FlashInferFP4MoE + elif should_use_flashinfer_trtllm_moe(): + # Use regular FlashInfer variant for non-FP4 FlashInfer cases + return FlashInferFusedMoE + else: + # Default case + return FusedMoE diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index 06b174995..f08b34e40 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -1,4 +1,20 @@ +import importlib.util from enum import Enum +from functools import lru_cache + +from packaging import version as pkg_version + +from sglang.srt.managers.schedule_batch import global_server_args_dict + + +@lru_cache(maxsize=1) +def should_use_flashinfer_trtllm_moe(): + result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and ( + not importlib.util.find_spec("flashinfer") + or pkg_version.parse(__import__("flashinfer").__version__) + >= pkg_version.parse("0.2.9rc1") + ) + return result class MoeA2ABackend(Enum): diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index bf7ce8727..7073f6be5 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1,13 +1,15 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py from __future__ import annotations +import importlib.util import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, 
List, Optional, Union import torch from torch.nn.parameter import Parameter from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType +from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, @@ -29,6 +31,7 @@ from sglang.srt.layers.quantization.utils import ( requantize_with_max_scale, ) from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import is_cuda, next_power_of_2 if TYPE_CHECKING: @@ -39,6 +42,11 @@ if is_cuda(): try: from flashinfer import mm_fp4 as fp4_gemm + from flashinfer import ( + reorder_rows_for_gated_act_gemm, + shuffle_matrix_a, + shuffle_matrix_sf_a, + ) enable_flashinfer_fp4_gemm = True except ImportError: @@ -47,6 +55,9 @@ except ImportError: else: fp4_gemm = None enable_flashinfer_fp4_gemm = False + reorder_rows_for_gated_act_gemm = None + shuffle_matrix_a = None + shuffle_matrix_sf_a = None try: from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe @@ -527,6 +538,7 @@ class ModelOptFp4Config(QuantizationConfig): ) -> Optional[QuantizeMethodBase]: from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE if isinstance(layer, LinearBase): if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded( @@ -536,6 +548,9 @@ class ModelOptFp4Config(QuantizationConfig): return ModelOptFp4LinearMethod(self) if self.kv_cache_quant_algo and isinstance(layer, RadixAttention): return ModelOptFp8KVCacheMethod(self) + elif isinstance(layer, FlashInferFP4MoE): + # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling + return ModelOptNvFp4FusedMoEMethod(self) elif isinstance(layer, FusedMoE): return ModelOptNvFp4FusedMoEMethod(self) return None @@ -726,7 +741,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): " quantization. Please use Blackwell and" " above." ) - self.enable_flashinfer_cutlass_moe = False + self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe() + + @property + def enable_flashinfer_cutlass_moe(self) -> bool: + """Access the global enable_flashinfer_cutlass_moe setting.""" + return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False) def create_weights( self, @@ -743,16 +763,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): " dynamic quantization is not supported." 
) + # TODO(ch-wan): check if this is needed layer.num_experts = num_experts + layer.num_local_experts = num_experts + layer.intermediate_size_per_partition = intermediate_size_per_partition layer.params_dtype = params_dtype layer.quant_config = self.quant_config + weight_dtype = torch.uint8 weight_scale_dtype = torch.float8_e4m3fn weight_loader = extra_weight_attrs.get("weight_loader") # GEMM 1 w13_weight = ModelWeightParameter( data=torch.empty( - num_experts, + layer.local_num_experts, 2 * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // 2, @@ -767,7 +791,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): # GEMM 2 w2_weight = ModelWeightParameter( data=torch.empty( - num_experts, + layer.num_local_experts, hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // 2, @@ -781,7 +805,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): w13_weight_scale = ModelWeightParameter( data=torch.empty( - num_experts, + layer.num_local_experts, 2 * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // self.quant_config.group_size, @@ -795,7 +819,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): w2_weight_scale = ModelWeightParameter( data=torch.empty( - num_experts, + layer.num_local_experts, hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // self.quant_config.group_size, @@ -814,13 +838,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ) w13_weight_scale_2 = PerTensorScaleParameter( - data=torch.empty(num_experts, 2, dtype=torch.float32), + data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) w2_weight_scale_2 = PerTensorScaleParameter( - data=torch.empty(num_experts, dtype=torch.float32), + data=torch.empty(layer.num_local_experts, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) @@ -830,18 +854,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ) w13_input_scale = PerTensorScaleParameter( - data=torch.empty(num_experts, 2, dtype=torch.float32), + data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_input_scale", w13_input_scale) w2_input_scale = PerTensorScaleParameter( - data=torch.empty(num_experts, dtype=torch.float32), + data=torch.empty(layer.num_local_experts, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w2_input_scale", w2_input_scale) - def swizzle_blockscale(self, scale: torch.tensor): + def swizzle_blockscale(self, scale: torch.Tensor): assert scale.dtype == torch.float8_e4m3fn # Pad and blockwise interleave weight_scale scale_ndim = scale.ndim @@ -866,9 +890,125 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): else swizzled_scale.reshape(B, M, K) ) - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + def prepare_static_weights_for_kernel( + self, + # args_dequant, + # args, + gemm1_weights, + gemm2_weights, + gemm1_scales_linear_fp4_bytes, + gemm2_scales_linear_fp4_bytes, + hidden_size, + intermediate_size, + num_experts, + ): + from flashinfer import ( + RoutingMethodType, + e2m1_and_ufp8sf_scale_to_float, + fp4_quantize, + next_positive_power_of_2, + reorder_rows_for_gated_act_gemm, + shuffle_matrix_a, + shuffle_matrix_sf_a, + ) - # GEMM 1 + """Prepare 
quantized weights for kernel (done offline with weights).""" + epilogue_tile_m = 128 # FIXME: this depends on the kernel internals + + # Convert quantized weights to proper formats + gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape( + num_experts, 2 * intermediate_size, hidden_size // 2 + ) # packed fp4 + gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view( + torch.float8_e4m3fn + ).reshape( + num_experts, 2 * intermediate_size, hidden_size // 16 + ) # fp8 scaling factors + + gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape( + num_experts, hidden_size, intermediate_size // 2 + ) # packed fp4 + gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view( + torch.float8_e4m3fn + ).reshape( + num_experts, hidden_size, intermediate_size // 16 + ) # fp8 scaling factors + + # Reorder rows of W1 and scales for fused gated activation + gemm1_weights_fp4_interleaved = [] + gemm1_scales_fp4_interleaved = [] + for i in range(num_experts): + gemm1_weights_fp4_interleaved.append( + reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()) + ) + gemm1_scales_fp4_interleaved.append( + reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone()) + ) + + # Stack weights and scales for all experts + gemm1_weights_fp4_interleaved = torch.stack( + gemm1_weights_fp4_interleaved + ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2) + gemm1_scales_fp4_interleaved = torch.stack( + gemm1_scales_fp4_interleaved + ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16) + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_fp4_shuffled = [] + gemm1_scales_fp4_shuffled = [] + gemm2_weights_fp4_shuffled = [] + gemm2_scales_fp4_shuffled = [] + for i in range(num_experts): + gemm1_weights_fp4_shuffled.append( + shuffle_matrix_a( + gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m + ) + ) + gemm1_scales_fp4_shuffled.append( + shuffle_matrix_sf_a( + gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m + ) + ) + + gemm2_weights_fp4_shuffled.append( + shuffle_matrix_a( + gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m + ) + ) + gemm2_scales_fp4_shuffled.append( + shuffle_matrix_sf_a( + gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m + ) + ) + + # Stack weights for all experts + gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled) + gemm1_scales_fp4_shuffled = ( + torch.stack(gemm1_scales_fp4_shuffled) + .view(torch.float8_e4m3fn) + .reshape(num_experts, 2 * intermediate_size, hidden_size // 16) + ) + + gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled) + gemm2_scales_fp4_shuffled = ( + torch.stack(gemm2_scales_fp4_shuffled) + .view(torch.float8_e4m3fn) + .reshape(num_experts, hidden_size, intermediate_size // 16) + ) + return ( + gemm1_weights_fp4_shuffled, + gemm1_scales_fp4_shuffled, + gemm2_weights_fp4_shuffled, + gemm2_scales_fp4_shuffled, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Process FP4 MoE weights after loading from serialized checkpoint. + + Only supports pre-quantized checkpoints with FP8 weights and scales. 
+ """ + + # GEMM 1 scale processing if not torch.allclose( layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1] ): @@ -880,73 +1020,123 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) - if self.enable_flashinfer_cutlass_moe: + # Calculate input scales based on strategy + if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe: w13_input_scale = layer.w13_input_scale.max().to(torch.float32) + w2_input_scale = layer.w2_input_scale.max().to(torch.float32) else: w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) + w2_input_scale = layer.w2_input_scale + + # Create shared parameters layer.g1_alphas = Parameter( (w13_input_scale * w13_weight_scale_2).to(torch.float32), requires_grad=False, ) - - assert ( - layer.w13_weight_scale.shape[2] % 16 == 0 - ), "Expected weight_scale.dim(1) to be divisible by 16" - assert ( - layer.w13_weight_scale.dtype == torch.float8_e4m3fn - ), "Weight Blockscale must be represented as FP8-E4M3" - w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale) - - layer.w13_blockscale_swizzled = Parameter( - w13_blockscale_swizzled, requires_grad=False - ) - del layer.w13_weight_scale - - # This is for quantization, so we need to invert it. - layer.w13_input_scale_quant = Parameter( - (1 / w13_input_scale).to(torch.float32), requires_grad=False - ) - - layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) - - # GEMM 2 - if self.enable_flashinfer_cutlass_moe: - w2_input_scale = layer.w2_input_scale.max().to(torch.float32) - else: - w2_input_scale = layer.w2_input_scale - layer.g2_alphas = Parameter( (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), requires_grad=False, ) - - # This is for quantization, so we need to invert it. 
+ layer.w13_input_scale_quant = Parameter( + (1 / w13_input_scale).to(torch.float32), requires_grad=False + ) layer.w2_input_scale_quant = Parameter( (1 / w2_input_scale).to(torch.float32), requires_grad=False ) - assert ( - layer.w2_weight_scale.shape[2] % 16 == 0 - ), "Expected weight_scale.dim(1) to be divisible by 16" - assert ( - layer.w2_weight_scale.dtype == torch.float8_e4m3fn - ), "Weight Blockscale must be represented as FP8-E4M3" - w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) + # Validate weight scales + for name, weight_scale in [ + ("w13", layer.w13_weight_scale), + ("w2", layer.w2_weight_scale), + ]: + assert ( + weight_scale.shape[2] % 16 == 0 + ), f"Expected {name}_weight_scale.dim(2) to be divisible by 16" + assert ( + weight_scale.dtype == torch.float8_e4m3fn + ), f"{name} Weight Blockscale must be represented as FP8-E4M3" - layer.w2_blockscale_swizzled = Parameter( - w2_blockscale_swizzled, requires_grad=False - ) - del layer.w2_weight_scale - layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) + # Weight processing based on strategy + if ( + self.enable_flashinfer_trtllm_moe + and reorder_rows_for_gated_act_gemm is not None + and shuffle_matrix_sf_a is not None + ): + # FlashInfer TRTLLM processing - handles both w13 and w2 + ( + gemm1_weights_fp4_shuffled, + gemm1_scales_fp4_shuffled, + gemm2_weights_fp4_shuffled, + gemm2_scales_fp4_shuffled, + ) = self.prepare_static_weights_for_kernel( + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + layer.w2_weight.size(-2), # hidden_size + layer.w13_weight.size(-2) // 2, # intermediate_size + layer.w13_weight.size(0), # num_experts + ) - device = layer.w13_weight.device - layer.cutlass_moe_params = CutlassMoEParams( - CutlassMoEType.BlockscaledFP4, - device, - num_experts=layer.num_experts, # global num experts - intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n - hidden_size=layer.w13_weight.shape[2] * 2, - ) # k + # Set flashinfer parameters + layer.gemm1_weights_fp4_shuffled = Parameter( + gemm1_weights_fp4_shuffled, requires_grad=False + ) + layer.gemm2_weights_fp4_shuffled = Parameter( + gemm2_weights_fp4_shuffled, requires_grad=False + ) + layer.gemm1_scales_fp4_shuffled = Parameter( + gemm1_scales_fp4_shuffled, requires_grad=False + ) + layer.gemm2_scales_fp4_shuffled = Parameter( + gemm2_scales_fp4_shuffled, requires_grad=False + ) + + # Additional parameter needed for TRT-LLM + layer.g1_scale_c = Parameter( + (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), + requires_grad=False, + ) + + # Clean up weights that won't be used by TRT-LLM + del ( + layer.w2_weight, + layer.w2_weight_scale, + layer.w13_weight, + layer.w13_weight_scale, + ) + + print("Applied flashinfer weight processing for both w13 and w2") + + else: + # CUTLASS processing - handle w13 and w2 separately + + # Process w13 weights + w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale) + layer.w13_blockscale_swizzled = Parameter( + w13_blockscale_swizzled, requires_grad=False + ) + layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) + + # Process w2 weights + w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) + layer.w2_blockscale_swizzled = Parameter( + w2_blockscale_swizzled, requires_grad=False + ) + layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) + + # Both flashinfer cutlass and regular cutlass use same processing for w2 + print("Applied weight 
processing for both w13 and w2") + + # Set up CUTLASS MoE parameters + device = layer.w13_weight.device + layer.cutlass_moe_params = CutlassMoEParams( + CutlassMoEType.BlockscaledFP4, + device, + num_experts=layer.num_experts, # global num experts + intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n + hidden_size=layer.w13_weight.shape[2] * 2, + ) # k @property def load_up_proj_weight_first(self) -> bool: @@ -971,13 +1161,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." + # Check if this is a FlashInferFP4MoE layer that should handle its own forward + if hasattr(layer, "gemm1_weights_fp4_shuffled"): + # This layer was processed with flashinfer TRTLLM - delegate to its own forward + return layer.forward(x, topk_output) + if self.enable_flashinfer_cutlass_moe: assert ( not apply_router_weight_on_input ), "apply_router_weight_on_input is not supported for Flashinfer" # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision # and fp4 quantized weights loaded from the checkpoint - topk_weights, topk_ids, _ = topk_output + + topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids + output = flashinfer_cutlass_fused_moe( x, topk_ids.to(torch.int), @@ -1005,7 +1202,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 - topk_weights, topk_ids, _ = topk_output + topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids output = cutlass_moe_fp4( a=x, a1_gscale=layer.w13_input_scale_quant, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 03faea684..759bb6afa 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -51,7 +51,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import ( ScheduleBatchDisaggregationDecodeMixin, ) from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank -from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend from sglang.srt.mem_cache.allocator import ( BaseTokenToKVPoolAllocator, SWATokenToKVPoolAllocator, @@ -109,6 +108,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "enable_triton_kernel_moe", "enable_multimodal", "enable_symm_mem", + "quantization", ] # Put some global args for easy access diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index b5b13d9ac..009f926bf 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -60,12 +60,9 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.ep_moe.layer import ( - DeepEPMoE, - get_moe_impl_class, - should_use_flashinfer_trtllm_moe, -) +from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_kernel import ( @@ -307,19 +304,15 @@ class DeepseekV2MoE(nn.Module): config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn ) - self.topk = ( - TopK( - top_k=config.num_experts_per_tok + self.num_fused_shared_experts, - renormalize=config.norm_topk_prob, - 
use_grouped_topk=True, - num_expert_group=config.n_group, - num_fused_shared_experts=self.num_fused_shared_experts, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, - routed_scaling_factor=self.routed_scaling_factor, - ) - if not should_use_flashinfer_trtllm_moe() - else None + self.topk = TopK( + top_k=config.num_experts_per_tok + self.num_fused_shared_experts, + renormalize=config.norm_topk_prob, + use_grouped_topk=True, + num_expert_group=config.n_group, + num_fused_shared_experts=self.num_fused_shared_experts, + topk_group=config.topk_group, + correction_bias=self.gate.e_score_correction_bias, + routed_scaling_factor=self.routed_scaling_factor, ) self.experts = get_moe_impl_class()( @@ -476,10 +469,14 @@ class DeepseekV2MoE(nn.Module): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) kwargs = {"hidden_states": hidden_states} - if self.topk is not None: - kwargs["topk_output"] = self.topk(hidden_states, router_logits) + + # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple + # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput + if should_use_flashinfer_trtllm_moe(): + kwargs["topk_output"] = (self.topk, router_logits) else: - kwargs["router_logits"] = router_logits + kwargs["topk_output"] = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(**kwargs) if not _is_cuda: final_hidden_states *= self.routed_scaling_factor @@ -505,10 +502,14 @@ class DeepseekV2MoE(nn.Module): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) kwargs = {"hidden_states": hidden_states} - if self.topk is not None: - kwargs["topk_output"] = self.topk(hidden_states, router_logits) + + # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple + # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput + if should_use_flashinfer_trtllm_moe(): + kwargs["topk_output"] = (self.topk, router_logits) else: - kwargs["router_logits"] = router_logits + kwargs["topk_output"] = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(**kwargs) if not _is_cuda and not _use_aiter: # fused in biased_grouped_topk so we can skip here diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 76f954578..568f632f2 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -50,11 +50,9 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.ep_moe.layer import ( - get_moe_impl_class, - should_use_flashinfer_trtllm_moe, -) +from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_kernel import ( is_fp8_fnuz, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6c63de973..fb3f80f87 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -481,6 +481,13 @@ class ServerArgs: self.tp_size, ], "The expert parallel size must be 1 or the same as the tensor parallel size" + if self.enable_flashinfer_trtllm_moe: + if not self.disable_shared_experts_fusion: + self.disable_shared_experts_fusion = True + logger.warning( + "FlashInfer TRTLLM MoE is enabled. 
--disable-shared-experts-fusion is automatically set." + ) + # DeepEP MoE if self.moe_a2a_backend == "deepep": if self.deepep_mode == "normal":
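
The class selection added across ep_moe/layer.py, fused_moe_triton/layer.py, and moe/utils.py can be summarized with the standalone sketch below. This is an illustration, not part of the patch: it returns class names as strings so it runs without sglang installed, and the 0.2.9rc1 version gate is copied from should_use_flashinfer_trtllm_moe above.

import importlib.util
from packaging import version as pkg_version

def pick_moe_impl(server_args: dict) -> str:
    # Mirrors should_use_flashinfer_trtllm_moe(): the flag must be on, and if
    # flashinfer is installed its version must be >= 0.2.9rc1.
    trtllm_ok = server_args.get("enable_flashinfer_trtllm_moe", False) and (
        importlib.util.find_spec("flashinfer") is None
        or pkg_version.parse(__import__("flashinfer").__version__)
        >= pkg_version.parse("0.2.9rc1")
    )
    if trtllm_ok and server_args.get("quantization") == "modelopt_fp4":
        # DeepseekV2MoE then passes (self.topk, router_logits) as topk_output,
        # and FlashInferFP4MoE.forward calls trtllm_fp4_block_scale_moe.
        return "FlashInferFP4MoE"
    if trtllm_ok:
        return "FlashInferFusedMoE"
    return "FusedMoE"

# Example: the flags used to enable the new NVFP4 decode path.
print(pick_moe_impl({"quantization": "modelopt_fp4",
                     "enable_flashinfer_trtllm_moe": True}))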