From 168033d5fb1ea1744cd82d9f42f732d2327337fd Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Wed, 6 Aug 2025 00:05:25 -0700 Subject: [PATCH] Support mxfp4 for GPT-OSS (#8843) Co-authored-by: Co-author fzyzcjy Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Co-authored-by: zhuofan1123 Co-authored-by: liz-badada Co-authored-by: xutizhou Co-authored-by: linhu-nv --- .../srt/layers/moe/fused_moe_triton/layer.py | 64 ++- .../fused_moe_triton/triton_kernels_moe.py | 40 +- .../srt/layers/quantization/__init__.py | 14 +- python/sglang/srt/layers/quantization/fp4.py | 321 ++----------- .../sglang/srt/layers/quantization/mxfp4.py | 443 ++++++++++++++++++ .../sglang/srt/layers/quantization/unquant.py | 2 + python/sglang/srt/models/gpt_oss.py | 218 ++++++++- python/sglang/srt/server_args.py | 10 + python/sglang/srt/utils.py | 4 + 9 files changed, 791 insertions(+), 325 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/mxfp4.py diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 56ffe371b..35f06c6de 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -389,7 +389,7 @@ class FusedMoE(torch.nn.Module): # Narrow parameter and load. if is_bias: # this expert_data is a bias, not weight, - # for w2_bias in TP, it does not need to be sharded + # for w2_weight_bias in TP, it does not need to be sharded shard_size = expert_data.shape[-1] else: # this parameter is a weight matrix @@ -410,10 +410,6 @@ class FusedMoE(torch.nn.Module): if not is_bias and not self.use_presharded_weights: if self.use_triton_kernels: loaded_weight = loaded_weight.transpose(-2, -1) - if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]: - raise ValueError( - f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}" - ) loaded_weight = loaded_weight.narrow( shard_dim, shard_size * tp_rank, shard_size ) @@ -461,9 +457,25 @@ class FusedMoE(torch.nn.Module): loaded_weight: torch.Tensor, weight_name: str, shard_id: str, - expert_id: int, + expert_id: Optional[int], ) -> None: + # if expert_id is None, then + # all the experts are loaded at the same time + if ( + not expert_id + and self.quant_config is not None + and self.quant_config.get_name() == "mxfp4" + ): + if "bias" in weight_name: + dim1 = loaded_weight.shape[1] + param.data[:, :dim1].copy_(loaded_weight) + else: + dim1 = loaded_weight.shape[1] + dim2 = loaded_weight.shape[2] + param.data[:, :dim1, :dim2].copy_(loaded_weight) + return + global_expert_location_metadata = get_global_expert_location_metadata() if global_expert_location_metadata is None: self._weight_loader_impl( @@ -502,6 +514,7 @@ class FusedMoE(torch.nn.Module): shard_id: str, expert_id: int, ) -> None: + expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) if expert_id == -1: return @@ -705,6 +718,18 @@ class FusedMoE(torch.nn.Module): ) -> None: tp_rank = self.moe_tp_rank + if self.quant_config is not None and self.quant_config.get_name() == "mxfp4": + if "bias" in weight_name: + dim1 = loaded_weight.shape[1] + param.data[:, :dim1].copy_(loaded_weight) + elif "scale" in weight_name: + param.data.copy_(loaded_weight) + else: + dim1 = loaded_weight.shape[1] + dim2 = loaded_weight.shape[2] + param.data[:, :dim1, :dim2].copy_(loaded_weight) + return + # compressed-tensors checkpoints with packed weights are stored flipped # TODO: check 
self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality @@ -854,6 +879,33 @@ class FusedMoE(torch.nn.Module): ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"), ] + @classmethod + def make_expert_params_mapping_fused_mxfp4( + cls, + ckpt_gate_up_proj_name: str, + ckpt_down_proj_name: str, + ckpt_gate_up_proj_bias_name: str, + ckpt_down_proj_bias_name: str, + ckpt_gate_up_proj_scale_name: str, + ckpt_down_proj_scale_name: str, + ): + return [ + ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"), + ( + "experts.w13_weight_bias", + f"experts.{ckpt_gate_up_proj_bias_name}", + "w13", + ), + ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"), + ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"), + ( + "experts.w13_weight_scale", + f"experts.{ckpt_gate_up_proj_scale_name}", + "w13", + ), + ("experts.w2_weight_scale", f"experts.{ckpt_down_proj_scale_name}", "w2"), + ] + @classmethod def make_expert_input_scale_params_mapping( cls, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py index 36466661d..e99dc683a 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py @@ -186,8 +186,10 @@ def triton_kernel_fused_experts( def triton_kernel_moe_with_bias_forward( hidden_states: torch.Tensor, w1: torch.Tensor, + w1_pcg, b1: torch.Tensor, w2: torch.Tensor, + w2_pcg, b2: torch.Tensor, topk_output: TopKOutput, inplace: bool = False, @@ -209,13 +211,15 @@ def triton_kernel_moe_with_bias_forward( return triton_kernel_fused_experts_with_bias( hidden_states, - w1, - b1, - w2, - b2, - routing_data, - gather_idx, - scatter_idx, + w1=w1, + w1_pcg=w1_pcg, + b1=b1, + w2=w2, + w2_pcg=w2_pcg, + b2=b2, + routing_data=routing_data, + gather_indx=gather_idx, + scatter_indx=scatter_idx, inplace=inplace, activation=activation, use_fp8_w8a8=use_fp8_w8a8, @@ -235,8 +239,10 @@ def triton_kernel_moe_with_bias_forward( def triton_kernel_fused_experts_with_bias( hidden_states: torch.Tensor, w1: torch.Tensor, + w1_pcg, b1: torch.Tensor, w2: torch.Tensor, + w2_pcg, b2: torch.Tensor, routing_data: RoutingData, gather_indx: GatherIndx, @@ -267,8 +273,10 @@ def triton_kernel_fused_experts_with_bias( # type check assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16" - assert w1.dtype == torch.bfloat16, "w1 must be bfloat16" - assert w2.dtype == torch.bfloat16, "w2 must be bfloat16" + for w in (w1, w2): + # TODO assert bf16 or mxfp4 + # assert (w.dtype == torch.bfloat16) or check-is-mxfp4, f"w must be bfloat16 or mxfp4 {w1.dtype=}" + pass # Shape check assert hidden_states.ndim == 2, "hidden_states must be 2D" @@ -287,13 +295,15 @@ def triton_kernel_fused_experts_with_bias( if global_num_experts == -1: global_num_experts = E - device = "cuda" - optg = dict() - w1, w1_flex = quantize(w1, "bf16", device, **optg) - w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex)) + # TODO maybe completely remove this branch + if w1.dtype == torch.bfloat16: + device = "cuda" + optg = dict() + w1, w1_flex = quantize(w1, "bf16", device, **optg) + w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex)) - w2, w2_flex = quantize(w2, "bf16", device, **optg) - w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex)) + w2, w2_flex = quantize(w2, "bf16", device, **optg) + w2_pcg = 
PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex)) act = FusedActivation( FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 455e8ac8f..19977012a 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -47,7 +47,7 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( CompressedTensorsConfig, ) -from sglang.srt.utils import mxfp_supported +from sglang.srt.utils import is_cuda, is_hip, mxfp_supported is_mxfp_supported = mxfp_supported() if is_mxfp_supported: @@ -66,6 +66,7 @@ from sglang.srt.layers.quantization.modelopt_quant import ( ModelOptFp8Config, ) from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config +from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config from sglang.srt.layers.quantization.petit import PetitNvFp4Config from sglang.srt.layers.quantization.qoq import QoQConfig from sglang.srt.layers.quantization.utils import get_linear_quant_method @@ -90,7 +91,16 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "w4afp8": W4AFp8Config, "petit_nvfp4": PetitNvFp4Config, } -if is_mxfp_supported: + + +if is_cuda(): + BASE_QUANTIZATION_METHODS.update( + { + "quark": Mxfp4Config, + "mxfp4": Mxfp4Config, + } + ) +elif is_mxfp_supported and is_hip(): BASE_QUANTIZATION_METHODS.update( { "quark": MxFp4Config, diff --git a/python/sglang/srt/layers/quantization/fp4.py b/python/sglang/srt/layers/quantization/fp4.py index ad40ed142..68d463cc3 100644 --- a/python/sglang/srt/layers/quantization/fp4.py +++ b/python/sglang/srt/layers/quantization/fp4.py @@ -50,315 +50,50 @@ use_dynamic_mxfp4_linear = get_bool_env_var("SGLANG_USE_DYNAMIC_MXFP4_linear") OCP_MX_BLOCK_SIZE = 32 -class MxFp4Config(QuantizationConfig): +class Mxfp4Config(QuantizationConfig): - def __init__( - self, - is_checkpoint_fp4_serialized: bool = False, - quant_config: dict[str, Any] = None, - kv_cache_group: Optional[list[str]] = None, - kv_cache_config: Optional[dict[str, Any]] = None, - pack_method: str = "reorder", - ignored_layers: Optional[List[str]] = None, - ): + def __init__(self, ignored_layers: Optional[list[str]] = None): super().__init__() - if kv_cache_group is None: - kv_cache_group = [] + self.ignored_layers = ignored_layers - self.is_checkpoint_fp4_serialized = is_checkpoint_fp4_serialized - self.quant_config = quant_config - self.kv_cache_group = kv_cache_group - self.kv_cache_config = kv_cache_config - self.pack_method = pack_method - - self.packed_modules_mapping = ( - self.quant_config["packed_modules_mapping"] - if is_checkpoint_fp4_serialized - else None - ) - - self.ignored_layers = ignored_layers or [] - - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] + @classmethod + def from_config(cls, config): + return cls() @classmethod def get_min_capability(cls) -> int: - return 70 - - def get_name(self) -> str: - return "fp4" - - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: - - from sglang.srt.layers.moe.fused_moe_triton import FusedMoE - - # Check if the layer is skipped for quantization. 
- if len(self.ignored_layers) > 0 and should_ignore_layer( - prefix, - ignore=self.ignored_layers, - fused_mapping=self.packed_modules_mapping, - ): - return UnquantizedLinearMethod() - - if isinstance(layer, LinearBase): - if self.is_checkpoint_fp4_serialized: - scheme = self.get_scheme(layer=layer, layer_name=prefix) - layer.scheme = scheme - return MxFp4LinearMethod(self) - - elif use_dynamic_mxfp4_linear: - return MxFp4LinearMethod(self) - else: - return UnquantizedLinearMethod() - - if isinstance(layer, RadixAttention): - return MxFp4KVCacheMethod(self) - - if isinstance(layer, FusedMoE): - return MxFp4MoEMethod.get_moe_method(self, module=layer, layer_name=prefix) - - return None + return 80 @classmethod - def from_config(cls, config: dict[str, Any]) -> "MxFp4Config": - if not mxfp_supported(): - platform = torch.cuda.get_device_properties(0).gcnArchName - raise ValueError( - f"Current platform {platform} not support mxfp4 computation" - ) - quant_method = cls.get_from_keys(config, ["quant_method"]) - is_checkpoint_fp4_serialized = ( - True if quant_method else False - ) # "quark" in quant_method + def get_name(cls) -> QuantizationMethods: + return "mxfp4" - kv_cache_group = [] - pack_method = None - - if is_checkpoint_fp4_serialized: - export_config = config.get("export") - if export_config is None: - raise ValueError( - "The export key should be included in " - "the configurations of Quark quantized model" - ) - - kv_cache_group = cast(list[str], export_config.get("kv_cache_group")) - pack_method = cast(str, export_config.get("pack_method")) - - # In the export model of quark, the quantization configuration - # of kv_cache is stored in layer_quant_config. First, it is - # judged whether kv_cache_group exists, and then it is judged - # whether layer_quant_config has a quantization configuration - # that matches kv_cache. - if len(kv_cache_group) == 0: - kv_cache_config = None - else: - kv_cache_set = set(kv_cache_group) - layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config")) - layer_quant_names = list(layer_quant_config.keys()) - layer_quant_set = set(layer_quant_names) - - if not kv_cache_set.issubset(layer_quant_set): - raise ValueError( - "The Quark quantized model has the " - "kv_cache_group parameter setting, " - "but no kv_cache quantization settings " - "were found in the quantization " - "configuration." - ) - - q_configs = [ - cast(dict[str, Any], layer_quant_config.get(name)) - for name in kv_cache_group - ] - if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs): - raise ValueError( - "The quantization method used for kv_cache should " - "be the same, but the quantization method for the " - "kv_cache layer in the config is different." - ) - kv_cache_config = q_configs[0].get("output_tensors") - if kv_cache_config is None: - raise ValueError("The kv_cache quantization configuration is empty.") - - # Since we have already set kv_cache quantization configurations, - # we will remove the quantization configuration for the - # output_tensors corresponding to the kv_cache layer. - for q_config in q_configs: - q_config["output_tensors"] = None - - # In case q_proj output is also quantized, remove the configuration - # to keep qkv consistency. 
- q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj")) - if q_proj_q_config is not None: - q_proj_q_config["output_tensors"] = None - - ignored_layers = cls.get_from_keys_or(config, ["exclude"], None) - - return cls( - is_checkpoint_fp4_serialized=is_checkpoint_fp4_serialized, - quant_config=config, - kv_cache_group=kv_cache_group, - kv_cache_config=kv_cache_config, - pack_method=pack_method, - ignored_layers=ignored_layers, - ) + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] @classmethod def get_config_filenames(cls) -> list[str]: return [] - def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: - capability_tuple = get_device_capability() + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import - if capability_tuple is not None: - assert 0 <= capability_tuple[1] < 10 - capability = capability_tuple[0] * 10 + capability_tuple[1] - - supported = capability >= min_capability - if error and not supported: - raise RuntimeError( - "Quantization scheme is not supported for ", - f"the current GPU. Min capability: {min_capability}. ", - f"Current capability: {capability}.", - ) - return supported - else: - return False - - def _is_mx_fp4( - self, - weight_quant: Optional[dict[str, Any]], - input_quant: Optional[dict[str, Any]], - ) -> bool: - # Confirm weights and input quantized. - if weight_quant is None or input_quant is None: - logger.debug( - "Quark model is not in MX-FP4 format: " - "weight_quant or input_quant not set" - ) - return False - - # Input and weight dtype needs to be fp4. - if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4": - logger.debug("Quark model is not in MX-FP4 format: dtype not fp4") - return False - - # Input and weight qscheme needs to be per group. - if ( - weight_quant.get("qscheme") != "per_group" - or input_quant.get("qscheme") != "per_group" - ): - logger.debug("Quark model is not in MX-FP4 format: not per_group") - return False - - # Input and weight group size needs to be 32. - if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32: - logger.debug("Quark model is not in MX-FP4 format: not group_size=32") - return False - - # Weights need to use static quantization. - if weight_quant.get("is_dynamic") is True: - logger.debug("Quark model is not in MX-FP4 format: not weight static") - return False - - # Activations need to use dynamic quantization. - if input_quant.get("is_dynamic") is False: - logger.debug("Quark model is not in MX-FP4 format: not activation dynamic") - return False - - # Activations and weight scales need to be in e8m0 format. 
- if ( - weight_quant.get("scale_format") != "e8m0" - or input_quant.get("scale_format") != "e8m0" - ): - logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0") - return False - - return True - - def _find_matched_config( - self, layer_name: str, module: torch.nn.Module - ) -> dict[str, Any]: - - proj_name = layer_name.split(".")[-1] - if proj_name in self.packed_modules_mapping: - shard_proj_names = self.packed_modules_mapping[proj_name] - - # Convert fused_name --> [shard_names] - shard_names = [ - layer_name.replace(proj_name, shard_proj_name) - for shard_proj_name in shard_proj_names - ] - shard_configs = [ - self._find_matched_config(shard_name, module) - for shard_name in shard_names - ] - if not all( - deep_compare(q_config, shard_configs[0]) for q_config in shard_configs + if isinstance(layer, LinearBase): + if self.ignored_layers and is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, ): - raise ValueError( - f"Found a different quantization configuration for " - f"{shard_proj_names=} in {layer_name=}. vLLM " - "requires all to use the same scheme." - ) - return shard_configs[0] - else: - layer_quant_config = cast( - dict[str, Any], self.quant_config.get("layer_quant_config") - ) - for name_pattern in layer_quant_config: - if fnmatch.fnmatch(layer_name, name_pattern): - return layer_quant_config[name_pattern] - - layer_type = cast(str, type(module)) - layer_type_quant_config = cast( - dict[str, Any], self.quant_config.get("layer_type_quant_config") - ) - if layer_type in layer_type_quant_config: - return layer_type_quant_config[layer_type] - - global_quant_config = cast( - dict[str, Any], self.quant_config.get("global_quant_config") - ) - return global_quant_config - - def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": - if config.get("output_tensors") or config.get("bias"): - raise NotImplementedError( - "Currently, Quark models with output_tensors " - "and bias quantized are not supported" - ) - weight_config = cast(dict[str, Any], config.get("weight")) - input_config = cast(dict[str, Any], config.get("input_tensors")) - - if self._is_mx_fp4(weight_config, input_config): - return QuarkW4A4MXFP4(weight_config, input_config) - - raise NotImplementedError( - "No quark compatible scheme was found. " - f"{weight_config=}, " - f"{input_config=}" - ) - - def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme": - - layer_quant_config = self._find_matched_config(layer_name, layer) - - # Find the quant_scheme - scheme = self._get_scheme_from_config(layer_quant_config) - - # Raise error if device does not support the scheme - # (e.g. 
fp8 needs ada lovelace) - self._check_scheme_supported(scheme.get_min_capability()) - - return scheme - - def get_scaled_act_names(self) -> List[str]: - return [] + return UnquantizedLinearMethod() + raise NotImplementedError("Mxfp4 linear layer is not implemented") + elif isinstance(layer, FusedMoE): + return Mxfp4MoEMethod(layer.moe_config) + elif isinstance(layer, Attention): + raise NotImplementedError("Mxfp4 attention layer is not implemented") + return None class MxFp4LinearMethod(LinearMethodBase): diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py new file mode 100644 index 000000000..7103cb8be --- /dev/null +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -0,0 +1,443 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import importlib +import logging +from typing import TYPE_CHECKING, Callable, List, Optional + +import torch +from torch.nn.parameter import Parameter + +# from vllm.model_executor.layers.fused_moe import ( +# FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, +# FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) +from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.utils import is_layer_skipped +from sglang.srt.utils import ( + direct_register_custom_op, + is_cuda, + is_flashinfer_available, + is_hip, + next_power_of_2, + round_up, + set_weight_attrs, +) + +has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None + +if is_flashinfer_available(): + # from flashinfer.fused_moe import cutlass_fused_moe + from flashinfer import ( + mxfp8_quantize, + shuffle_matrix_a, + shuffle_matrix_sf_a, + trtllm_fp4_block_scale_moe, + ) + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + +OCP_MX_BLOCK_SIZE = 32 + + +def _swizzle_mxfp4(quant_tensor, scale, num_warps): + """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel""" + import triton_kernels.matmul_ogs_details.opt_flags as opt_flags + from triton_kernels.numerics import InFlexData + from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor + from triton_kernels.tensor_details import layout + + value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout( + mx_axis=1 + ) + scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout( + mx_axis=1, num_warps=num_warps + ) + if is_cuda() and torch.cuda.get_device_capability()[0] == 10: + constraints = { + "is_persistent": True, + "epilogue_subtile": 1, + } + opt_flags.update_opt_flags_constraints(constraints) + # transpose the tensor so that the quantization axis is on dim1 + quant_tensor = quant_tensor.transpose(-2, -1) + scale = scale.transpose(-2, -1) + quant_tensor = convert_layout( + wrap_torch_tensor(quant_tensor, dtype=FP4), value_layout, **value_layout_opts + ) + scale = convert_layout(wrap_torch_tensor(scale), scale_layout, **scale_layout_opts) + return quant_tensor, InFlexData(), scale + + +def _dequant_mxfp4( + x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype +) -> torch.Tensor: + try: + from quark.torch.kernel import mx + except ImportError as err: + raise ImportError( + "The package `amd-quark` is required to use " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`." 
+ ) from err + + return mx.dq_mxfp4(x, scale, float_dtype) + + +def _dequant_mxfp4_fake( + x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype +) -> torch.Tensor: + return torch.empty( + (*x.shape[:-1], x.shape[-1] * 2), dtype=float_dtype, device=x.device + ) + + +def _quant_dequant_mxfp4( + x: torch.Tensor, scale_calculation_mode: str = "even" +) -> torch.Tensor: + try: + from quark.torch.kernel import mx + except ImportError as err: + raise ImportError( + "The package `amd-quark` is required to use " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`." + ) from err + + return mx.qdq_mxfp4(x, scale_calculation_mode) + + +def _quant_dequant_mxfp4_fake( + x: torch.Tensor, scale_calculation_mode: str = "even" +) -> torch.Tensor: + return torch.empty_like(x) + + +try: + direct_register_custom_op( + op_name="dequant_mxfp4", + op_func=_dequant_mxfp4, + mutates_args=[], + fake_impl=_dequant_mxfp4_fake, + ) + dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4 +except AttributeError as error: + raise error + +try: + direct_register_custom_op( + op_name="quant_dequant_mxfp4", + op_func=_quant_dequant_mxfp4, + mutates_args=[], + fake_impl=_quant_dequant_mxfp4_fake, + ) + quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4 +except AttributeError as error: + raise error + + +class Mxfp4Config(QuantizationConfig): + + def __init__(self, ignored_layers: Optional[list[str]] = None): + super().__init__() + self.ignored_layers = ignored_layers + + @classmethod + def from_config(cls, config): + return cls() + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_name(cls) -> str: + return "mxfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod + + if isinstance(layer, LinearBase): + if self.ignored_layers and is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedLinearMethod() + elif isinstance(layer, FusedMoE): + return Mxfp4MoEMethod(use_triton_kernels=True, with_bias=True) + else: + raise NotImplementedError("Mxfp4 attention layer is not implemented") + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class Mxfp4MoEMethod(FusedMoEMethodBase): + + def __init__(self, use_triton_kernels: bool = True, with_bias: bool = True): + super().__init__() + self.topk_indices_dtype = None + self.use_triton_kernels = use_triton_kernels + self.with_bias = with_bias + self.triton_kernel_moe_forward = None + self.triton_kernel_moe_with_bias_forward = None + if torch.cuda.is_available() and has_triton_kernels: + from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import ( + triton_kernel_moe_forward as _tk_forward, + ) + from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import ( + triton_kernel_moe_with_bias_forward as _tk_with_bias_forward, + ) + + self.triton_kernel_moe_forward = _tk_forward + self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: 
int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # print(f"hi {self=} create_weights {layer=}") + self.num_experts = num_experts + weight_dtype = torch.uint8 + scale_dtype = torch.uint8 + + intermediate_size *= 2 + mxfp4_block = 32 + + self.intermediate_size = intermediate_size + self.hidden_size = hidden_size + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.zeros( + num_experts, 2 * intermediate_size, hidden_size // 2, dtype=weight_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size, + hidden_size // mxfp4_block, + dtype=scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w13_weight_bias = torch.nn.Parameter( + torch.zeros(num_experts, 2 * intermediate_size, dtype=torch.bfloat16), + requires_grad=False, + ) + layer.register_parameter("w13_weight_bias", w13_weight_bias) + set_weight_attrs(w13_weight_bias, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.zeros( + num_experts, hidden_size, intermediate_size // 2, dtype=weight_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size // mxfp4_block, + dtype=scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + w2_weight_bias = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=torch.bfloat16), + requires_grad=False, + ) + layer.register_parameter("w2_weight_bias", w2_weight_bias) + set_weight_attrs(w2_weight_bias, extra_weight_attrs) + + def process_weights_after_loading(self, layer): + + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + + w13_weight_bias = layer.w13_weight_bias.to(torch.float32) + w2_weight_bias = layer.w2_weight_bias.to(torch.float32) + + layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False) + layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False) + + num_warps = 8 + + w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( + layer.w13_weight, layer.w13_weight_scale, num_warps + ) + w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( + layer.w2_weight, layer.w2_weight_scale, num_warps + ) + + self.w13_precision_config = PrecisionConfig( + weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex) + ) + self.w2_precision_config = PrecisionConfig( + weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex) + ) + + self.w13_weight_triton_tensor = w13_weight + self.w2_weight_triton_tensor = w2_weight + + # need to delete the original weights to save memory on single GPU + del layer.w13_weight + del layer.w2_weight + layer.w13_weight = None + layer.w2_weight = None + torch.cuda.empty_cache() + + def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int): + # Number of tokens in the input tensor. + num_tokens = x.shape[0] + # Factor to account for the imbalance of the experts. + # factor equals to the + # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # - 1.0 means perfect expert distribution. 
+ # - > 1.0 means some experts have more + # tokens than the perfect distribution. + # - < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // self.num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile + # as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + + return tile_tokens_dim + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + topk_output: TopKOutput, + *, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, + activation_alpha: Optional[float] = None, + swiglu_limit: Optional[float] = None, + ) -> torch.Tensor: + # avoid import error when triton_kernel is not installed + # from vllm.model_executor.layers.fused_moe.triton_kernels_moe import ( + # triton_kernel_moe_forward) + + """ + if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE + or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE): + assert not self.moe.use_ep, ( + "EP is not supported for flashinfer mxfp4 moe backend yet.") + if envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE: + assert x.dtype == torch.bfloat16 + x_quant = x + x_scale = None + else: + x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 + x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1) + trtllm_gen_output = trtllm_fp4_block_scale_moe( + router_logits.to(torch.bfloat16), + None, # routing_bias + x_quant, + x_scale, + layer.w13_weight, # uint8 (e2m1 x 2) + layer.w13_weight_scale, # uint8 (e4m3 x 2) + layer.w13_weight_bias, # fp32 per expert per channel + layer.gemm1_alpha, # fp32 per expert + layer.gemm1_beta, # fp32 per expert + layer.gemm1_clamp_limit, # fp32 per expert + layer.w2_weight, # uint8 (e2m1 x 2) + layer.w2_weight_scale, # ue8m0 + layer.w2_weight_bias, # fp32 per expert per channel + None, # output1_scale_scalar + None, # output1_scale_gate_scalar + None, # output2_scale_scalar + self.num_experts, + top_k, + None, # n_group + None, # topk_group + self.intermediate_size, # padded to multiple of 256 + 0, # local_expert_offset + self.num_experts, # local num experts + None, + self._get_tile_tokens_dim(x, top_k), + 1, # routing_method_type, renormalize + True, # do finalize + )[0] + return trtllm_gen_output + """ + + if self.use_triton_kernels: + if self.with_bias: + # TODO why we do not put weights on layer? 
+ assert layer.w13_weight is None + assert layer.w2_weight is None + return self.triton_kernel_moe_with_bias_forward( + hidden_states=x, + w1=self.w13_weight_triton_tensor, + w1_pcg=self.w13_precision_config, + w2=self.w2_weight_triton_tensor, + w2_pcg=self.w2_precision_config, + b1=layer.w13_weight_bias, + b2=layer.w2_weight_bias, + topk_output=topk_output, + activation=activation, + activation_alpha=activation_alpha, + swiglu_limit=swiglu_limit, + ) + else: + return self.triton_kernel_moe_forward( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_output=topk_output, + ) + else: + raise NotImplementedError() diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 8fc4a5be1..c5558e3c1 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -272,6 +272,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): activation=activation, activation_alpha=activation_alpha, swiglu_limit=swiglu_limit, + w1_pcg=None, + w2_pcg=None, ) else: return self.triton_kernel_moe_forward( diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index cf40c652b..4ca9c40c5 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -25,6 +25,8 @@ from torch import nn from transformers import PretrainedConfig from sglang.srt.distributed import ( + get_moe_expert_parallel_rank, + get_moe_expert_parallel_world_size, get_moe_tensor_parallel_rank, get_pp_group, get_tensor_model_parallel_rank, @@ -108,11 +110,15 @@ class GptOssSparseMoeBlock(nn.Module): experts_type = get_moe_impl_class() extra_kwargs = {} if experts_type.__name__ == "FusedMoE": + quant_config_name = ( + quant_config.get_name() if quant_config is not None else None + ) extra_kwargs = { "enable_flashinfer_cutlass_moe": global_server_args_dict[ "enable_flashinfer_cutlass_moe" ], - "use_weight_loader_fused": True, # for moe gate_up_proj and down_proj and their bias loading + # for moe gate_up_proj and down_proj and their bias loading + "use_weight_loader_fused": quant_config_name != "mxfp4", } self.experts = experts_type( num_experts=config.num_local_experts @@ -350,7 +356,6 @@ class GptOssDecoderLayer(nn.Module): head_dim=head_dim, rms_norm_eps=rms_norm_eps, attention_bias=attention_bias, - quant_config=quant_config, prefix=add_prefix("self_attn", prefix), sliding_window_size=self.sliding_window_size, layer_type=config.layer_types[layer_id], @@ -538,7 +543,7 @@ class GptOssForCausalLM(nn.Module): self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, - quant_config=quant_config, + # quant_config=quant_config, prefix=add_prefix("lm_head", prefix), use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) @@ -652,11 +657,188 @@ class GptOssForCausalLM(nn.Module): return weight_mapping + # TODO beautify code def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn: bool = False, weight_name_mapping: dict = None, + ): + quant_config_name = ( + self.quant_config.get_name() if self.quant_config is not None else None + ) + if quant_config_name != "mxfp4": + self._load_normal_weights( + weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping + ) + else: + self._load_weights_mxfp4( + weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping + ) + + def _load_weights_mxfp4(self, weights, is_nextn, weight_name_mapping): + mxfp4_weights = [] + normal_weights = [] + + for name, weight 
in weights: + if ( + ".experts" in name + and self.quant_config is not None + and self.quant_config.get_name() == "mxfp4" + ): + mxfp4_weights.append((name, weight)) + else: + normal_weights.append((name, weight)) + + mxfp4_loaded_params = self._load_mxfp4_experts_weights(mxfp4_weights) + self._load_normal_weights( + normal_weights, + is_nextn=is_nextn, + weight_name_mapping=weight_name_mapping, + other_loaded_param_names=mxfp4_loaded_params, + ) + + def _load_mxfp4_experts_weights(self, weights): + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + mxfp4_block = 32 + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + intermediate_size = self.config.intermediate_size + intermediate_size_block = intermediate_size // mxfp4_block + per_rank_intermediate_size_block = intermediate_size_block // tp_size + per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block + + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size) + + # Attention heads per rank + heads_per_rank = self.config.num_attention_heads // tp_size + head_start = tp_rank * heads_per_rank + + num_experts = self.config.num_local_experts + + for name, weight in weights: + weight = weight.cuda() + + if "gate_up_proj_blocks" in name: + # Handle MLP gate and up projection weights + new_name = name.replace("gate_up_proj_blocks", "w13_weight") + + # flat weight from (E, 2 * N, block_size, entry_per_block) + # to (E, 2 * N, -1), shouldn't trigger copy for contiguous + weight = weight.view( + num_experts, 2 * intermediate_size, -1 + ).contiguous() + + narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None, + ) + loaded_params.add(new_name) + + elif "down_proj_blocks" in name: + # Handle MLP down projection weights + new_name = name.replace("down_proj_blocks", "w2_weight") + # same flatten here, but since 2 mx4 value are packed in 1 + # uint8, divide by 2 + weight = weight.view( + num_experts, -1, intermediate_size // 2 + ).contiguous() + narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None, + ) + loaded_params.add(new_name) + + elif "gate_up_proj_scales" in name: + # Handle MLP gate and up projection weights scale + new_name = name.replace("gate_up_proj_scales", "w13_weight_scale") + narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...] 
+ + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None, + ) + loaded_params.add(new_name) + + elif "down_proj_scales" in name: + # Handle MLP down projection weights + new_name = name.replace("down_proj_scales", "w2_weight_scale") + narrow_weight = weight[ + ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block + ] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None, + ) + loaded_params.add(new_name) + elif "gate_up_proj_bias" in name: + # Handle MLP gate and up projection biases + new_name = name.replace("gate_up_proj_bias", "w13_weight_bias") + + narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None, + ) + loaded_params.add(new_name) + + elif "down_proj_bias" in name: + if get_moe_tensor_parallel_rank() != 0: + weight = torch.zeros_like(weight) + + # Handle MLP down projection bias + new_name = name.replace("down_proj_bias", "w2_weight_bias") + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, weight, weight_name=new_name, shard_id=None, expert_id=None + ) + loaded_params.add(new_name) + + return loaded_params + + def _load_normal_weights( + self, + weights, + is_nextn: bool, + weight_name_mapping: dict, + other_loaded_param_names=[], ): tp_rank = get_tensor_model_parallel_rank() if is_nextn: @@ -725,15 +907,33 @@ class GptOssForCausalLM(nn.Module): ("qkv_proj", "v_proj", "v"), ] - expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused( - ckpt_gate_up_proj_name="gate_up_proj", - ckpt_down_proj_name="down_proj", - ckpt_gate_up_proj_bias_name="gate_up_proj_bias", - ckpt_down_proj_bias_name="down_proj_bias", - ) + if self.quant_config is not None and (self.quant_config.get_name() == "mxfp4"): + expert_params_mapping = ( + get_moe_impl_class().make_expert_params_mapping_fused_mxfp4( + ckpt_gate_up_proj_name="gate_up_proj_blocks", + ckpt_down_proj_name="down_proj_blocks", + ckpt_gate_up_proj_bias_name="gate_up_proj_bias", + ckpt_down_proj_bias_name="down_proj_bias", + ckpt_gate_up_proj_scale_name="gate_up_proj_scales", + ckpt_down_proj_scale_name="down_proj_scales", + ) + ) + else: + expert_params_mapping = ( + get_moe_impl_class().make_expert_params_mapping_fused( + ckpt_gate_up_proj_name="gate_up_proj", + ckpt_down_proj_name="down_proj", + ckpt_gate_up_proj_bias_name="gate_up_proj_bias", + ckpt_down_proj_bias_name="down_proj_bias", + ) + ) params_dict = dict(self.named_parameters()) params_checker = {k: False for k, v in params_dict.items()} + + for other_loaded_param_name in other_loaded_param_names: + params_checker[other_loaded_param_name] = True + for name, loaded_weight in weights: loaded_weight = _WeightCreator.maybe_materialize(loaded_weight) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 225caaf60..69c840a7b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -464,6 +464,16 @@ class ServerArgs: self.enable_triton_kernel_moe = True self.disable_hybrid_swa_memory = True + 
quantization_config = getattr( + self.get_hf_config(), "quantization_config", None + ) + if ( + quantization_config is not None + and quantization_config.get("quant_method") == "mxfp4" + ): + # use bf16 for mxfp4 triton kernels + self.dtype = "bfloat16" + # Set page size if self.page_size is None: self.page_size = 1 diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 2772cd119..2eb0d28b2 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2124,6 +2124,10 @@ def next_power_of_2(n: int): return 1 << (n - 1).bit_length() if n > 0 else 1 +def round_up(x: int, y: int) -> int: + return ((x - 1) // y + 1) * y + + setattr(triton, "next_power_of_2", next_power_of_2)
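
For reference, the storage math behind the mxfp4 shapes in create_weights and the expert loaders: two e2m1 (fp4) codes are packed into one uint8, and every block of 32 values along the reduction axis shares one uint8 e8m0 scale, which is where the "// 2" and "// mxfp4_block" factors come from. The following is only a back-of-the-envelope sketch with illustrative sizes, not the exact allocation performed by the patch:

MXFP4_BLOCK = 32

def mxfp4_bytes(rows: int, cols: int) -> int:
    # 2 fp4 values per uint8, plus 1 e8m0 scale byte per 32-value block
    packed_values = rows * (cols // 2)
    scales = rows * (cols // MXFP4_BLOCK)
    return packed_values + scales

def bf16_bytes(rows: int, cols: int) -> int:
    return rows * cols * 2

# Illustrative per-expert matrix, divisible by 32 (hypothetical sizes).
rows, cols = 5760, 2880
print(mxfp4_bytes(rows, cols), "bytes as mxfp4 vs", bf16_bytes(rows, cols), "bytes as bf16")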
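
The scale tensors hold e8m0 values, i.e. exponent-only bytes; per the OCP MX spec the represented scale is 2 ** (byte - 127). A tiny decode sketch, illustrative only, since the real kernels decode these scales on the GPU:

import torch

def decode_e8m0(scale_bytes: torch.Tensor) -> torch.Tensor:
    # exponent-only byte with bias 127 -> float scale
    return torch.pow(2.0, scale_bytes.to(torch.float32) - 127.0)

print(decode_e8m0(torch.tensor([127, 130, 120], dtype=torch.uint8)))  # 1.0, 8.0, 0.0078125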
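
The mxfp4 branches of the FusedMoE weight loaders copy each checkpoint tensor into the leading slice of the preallocated parameter instead of requiring exact shape equality, so the parameter may be padded relative to the checkpoint. A minimal sketch of that copy pattern with made-up sizes:

import torch

# Leading-slice copy: only the top-left region of the (possibly padded)
# parameter is overwritten; the padding stays zero (hypothetical sizes).
param = torch.zeros(4, 80, 32, dtype=torch.uint8)
loaded = torch.randint(0, 256, (4, 64, 32), dtype=torch.uint8)

dim1, dim2 = loaded.shape[1], loaded.shape[2]
param[:, :dim1, :dim2].copy_(loaded)
assert torch.equal(param[:, :dim1, :dim2], loaded)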
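
The tensor-parallel slicing in _load_mxfp4_experts_weights rounds the per-rank intermediate size down to a whole number of 32-value blocks, then slices gate_up_proj on its doubled output dim and down_proj on its packed input dim (// 2 for the packed fp4 bytes, // 32 for the per-block scales). A standalone restatement of those bounds; the dict keys are just labels for this sketch:

MXFP4_BLOCK = 32

def mxfp4_tp_bounds(intermediate_size: int, tp_rank: int, tp_size: int):
    blocks = intermediate_size // MXFP4_BLOCK
    per_rank = (blocks // tp_size) * MXFP4_BLOCK
    start = tp_rank * per_rank
    end = min((tp_rank + 1) * per_rank, intermediate_size)
    return {
        "gate_up_rows": (2 * start, 2 * end),        # slice on dim 1 of (E, 2*I, ...)
        "down_packed_cols": (start // 2, end // 2),  # slice on packed last dim of (E, H, I//2)
        "down_scale_cols": (start // MXFP4_BLOCK, end // MXFP4_BLOCK),
    }

# Example: intermediate_size=2880 (illustrative), 4-way TP, rank 1.
print(mxfp4_tp_bounds(2880, tp_rank=1, tp_size=4))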
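
_get_tile_tokens_dim and the round_up / next_power_of_2 helpers reduce to a few lines of integer math; restated standalone for clarity, with an illustrative call:

def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length() if n > 0 else 1

def round_up(x: int, y: int) -> int:
    return ((x - 1) // y + 1) * y

def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int,
                    imbalance_factor: float = 1.3) -> int:
    # expected tokens per expert under balanced routing, inflated for imbalance,
    # rounded up to a power of two, then clamped to the 8..64 tile range
    per_expert = (num_tokens * top_k) // num_experts
    per_expert = int(per_expert * imbalance_factor)
    return min(max(next_power_of_2(per_expert), 8), 64)

print(tile_tokens_dim(num_tokens=64, top_k=4, num_experts=32))  # -> 16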