From 1c1f8a118ef869327aba9b46100fe27fe36de6d5 Mon Sep 17 00:00:00 2001 From: kk <43161300+kkHuang-amd@users.noreply.github.com> Date: Sun, 17 Aug 2025 10:01:54 +0800 Subject: [PATCH] Combine fp4.py and mxfp4.py into one file and support dynamic mxfp4 quantization in mxfp4.py (#9049) Co-authored-by: wunhuang --- .../srt/layers/moe/fused_moe_triton/layer.py | 7 +- .../srt/layers/quantization/__init__.py | 17 +- python/sglang/srt/layers/quantization/fp4.py | 540 ------------------ .../sglang/srt/layers/quantization/mxfp4.py | 161 +++++- .../srt/layers/quantization/quark/quark.py | 390 +++++++++++++ .../layers/quantization/quark/quark_moe.py | 197 +++++++ python/sglang/srt/server_args.py | 5 +- 7 files changed, 760 insertions(+), 557 deletions(-) delete mode 100644 python/sglang/srt/layers/quantization/fp4.py create mode 100644 python/sglang/srt/layers/quantization/quark/quark.py create mode 100644 python/sglang/srt/layers/quantization/quark/quark_moe.py diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index c5b314988..766f79404 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -474,6 +474,7 @@ class FusedMoE(torch.nn.Module): not expert_id and self.quant_config is not None and self.quant_config.get_name() == "mxfp4" + and self.quant_config.is_static_cfg() ): if "bias" in weight_name: dim1 = loaded_weight.shape[1] @@ -724,7 +725,11 @@ class FusedMoE(torch.nn.Module): ) -> None: tp_rank = self.moe_tp_rank - if self.quant_config is not None and self.quant_config.get_name() == "mxfp4": + if ( + self.quant_config is not None + and self.quant_config.get_name() == "mxfp4" + and self.quant_config.is_static_cfg() + ): if "bias" in weight_name: dim1 = loaded_weight.shape[1] param.data[:, :dim1].copy_(loaded_weight) diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 971a894cf..d001bb646 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -48,12 +48,6 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( CompressedTensorsConfig, ) -from sglang.srt.utils import is_cuda, is_hip, mxfp_supported - -is_mxfp_supported = mxfp_supported() -if is_mxfp_supported: - from sglang.srt.layers.quantization.fp4 import MxFp4Config - from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig from sglang.srt.layers.quantization.modelopt_quant import ( @@ -67,6 +61,9 @@ from sglang.srt.layers.quantization.qoq import QoQConfig from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config +from sglang.srt.utils import is_cuda, is_hip, mxfp_supported + +_is_mxfp_supported = mxfp_supported() if TYPE_CHECKING: from sglang.srt.layers.moe.topk import TopKOutput @@ -98,11 +95,13 @@ if is_cuda(): "mxfp4": Mxfp4Config, } ) -elif is_mxfp_supported and is_hip(): +elif _is_mxfp_supported and is_hip(): + from sglang.srt.layers.quantization.quark.quark import QuarkConfig + BASE_QUANTIZATION_METHODS.update( { - "quark": MxFp4Config, - "mxfp4": MxFp4Config, + "quark": QuarkConfig, + "mxfp4": Mxfp4Config, } ) # VLLM-dependent quantization 
methods diff --git a/python/sglang/srt/layers/quantization/fp4.py b/python/sglang/srt/layers/quantization/fp4.py deleted file mode 100644 index a03f200c8..000000000 --- a/python/sglang/srt/layers/quantization/fp4.py +++ /dev/null @@ -1,540 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import fnmatch -import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast - -import aiter -import torch -import torch.nn.functional as F -from aiter import ActivationType, QuantType, dtypes -from aiter.fused_moe import fused_moe -from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages -from aiter.ops.gemm_op_a4w4 import gemm_a4w4 -from aiter.ops.quant import get_torch_quant -from aiter.ops.shuffle import shuffle_weight -from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 -from aiter.ops.triton.quant import dynamic_mxfp4_quant -from aiter.utility.fp4_utils import e8m0_shuffle -from torch.nn import Module - -from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod -from sglang.srt.layers.parameter import ModelWeightParameter -from sglang.srt.layers.quantization.base_config import ( - FusedMoEMethodBase, - LinearMethodBase, - QuantizationConfig, - QuantizeMethodBase, -) -from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod -from sglang.srt.layers.quantization.quark.schemes import QuarkScheme, QuarkW4A4MXFP4 -from sglang.srt.layers.quantization.quark.utils import deep_compare, should_ignore_layer -from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.utils import ( - get_bool_env_var, - get_device_capability, - log_info_on_rank0, - mxfp_supported, - set_weight_attrs, -) - -if TYPE_CHECKING: - from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig - from sglang.srt.layers.moe.topk import TopKOutput - -logger = logging.getLogger(__name__) - -use_dynamic_mxfp4_linear = get_bool_env_var("SGLANG_USE_DYNAMIC_MXFP4_linear") - -OCP_MX_BLOCK_SIZE = 32 - - -class Mxfp4Config(QuantizationConfig): - - def __init__(self, ignored_layers: Optional[list[str]] = None): - super().__init__() - self.ignored_layers = ignored_layers - - @classmethod - def from_config(cls, config): - return cls() - - @classmethod - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "mxfp4" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.bfloat16] - - @classmethod - def get_config_filenames(cls) -> list[str]: - return [] - - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: - from vllm.attention.layer import Attention # Avoid circular import - - if isinstance(layer, LinearBase): - if self.ignored_layers and is_layer_skipped( - prefix=prefix, - ignored_layers=self.ignored_layers, - fused_mapping=self.packed_modules_mapping, - ): - return UnquantizedLinearMethod() - raise NotImplementedError("Mxfp4 linear layer is not implemented") - elif isinstance(layer, FusedMoE): - return Mxfp4MoEMethod(layer.moe_config) - elif isinstance(layer, Attention): - raise NotImplementedError("Mxfp4 attention layer is not implemented") - return None - - -class MxFp4LinearMethod(LinearMethodBase): - - def __init__(self, quantization_config: MxFp4Config): - self.quantization_config = quantization_config - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - return - # if self.quantization_config.is_checkpoint_fp4_serialized: - # 
layer.scheme.process_weights_after_loading(layer) - # else: - # #w, w_scales = dynamic_mxfp4_quant(layer.weight.data) - # ##log_info_on_rank0(logger, f"w.shape: {w.shape}") - - # #wshuffle = w#shuffle_weight(w, layout=(16, 16)) - # #w_scales_shuffle = w_scales#e8m0_shuffle(w_scales).view(dtypes.fp8_e8m0) - - # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32) - - # w, w_scales_shuffle = quant_func(layer.weight.data, shuffle=True) - - # wshuffle = shuffle_weight(w, layout=(16, 16)) - - # layer.weight = torch.nn.Parameter(wshuffle, - # requires_grad=False) - # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle, - # requires_grad=False) - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - """ - Use the CompressedTensorsScheme associated with each layer to create - the necessary parameters for the layer. See LinearMethodBase for param - details - """ - weight_loader = extra_weight_attrs.get("weight_loader") - - if self.quantization_config.is_checkpoint_fp4_serialized: - layer.scheme.create_weights( - layer=layer, - input_size=input_size, - input_size_per_partition=input_size_per_partition, - output_partition_sizes=output_partition_sizes, - output_size=output_size, - params_dtype=params_dtype, - weight_loader=weight_loader, - ) - else: - output_size_per_partition = sum(output_partition_sizes) - layer.logical_widths = output_partition_sizes - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.orig_dtype = params_dtype - - weight_dtype = params_dtype - - weight = ModelWeightParameter( - data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=weight_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - - layer.register_parameter("weight", weight) - layer.register_parameter("weight_scale", None) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ): - """ - Use the output of create_weights and the CompressedTensorsScheme - associated with the layer to apply the forward pass with the - layer input. 
See LinearMethodBase for param details - - """ - if self.quantization_config.is_checkpoint_fp4_serialized: - scheme = layer.scheme - if scheme is None: - raise ValueError("A scheme must be defined for each layer") - return scheme.apply_weights(layer, x, bias=bias) - else: - out_dtype = x.dtype - - # ck or asm implement - # M = x.shape[0] - # N = layer.weight.shape[0] - - # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32) - - # x, x_scales_shuffle = quant_func(x, shuffle=True) - - # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=out_dtype) - - # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias) - - # return out[:M] - - # triton implement - x_q, x_s = dynamic_mxfp4_quant(x) - y = torch.empty( - x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype - ) - - out = gemm_afp4wfp4( - x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y - ) - - return out - - -class MxFp4MoEMethod(FusedMoEMethodBase): - - def __init__(self, quant_config: Mxfp4Config): - self.quant_config = quant_config - - @staticmethod - def get_moe_method( - quant_config: "MxFp4Config", # type: ignore # noqa E501 # noqa F821 - module: torch.nn.Module, - layer_name: str, - ) -> "MxFp4MoEMethod": - - if quant_config.is_checkpoint_fp4_serialized: - layer_quant_config = quant_config._find_matched_config(layer_name, module) - - if layer_quant_config.get("output_tensors") or layer_quant_config.get( - "bias" - ): - raise NotImplementedError( - "Currently, Quark models with " - "output_tensors and bias " - "quantized are not supported" - ) - weight_config = layer_quant_config.get("weight") - input_config = layer_quant_config.get("input_tensors") - - if quant_config._is_mx_fp4(weight_config, input_config): - return W4A4MXFp4MoEStaticMethod(weight_config, input_config) - else: - raise RuntimeError("Unsupported FusedMoe scheme") - else: - return W4A4MXFp4MoEDynamicMethod(quant_config) - - -class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod): - def __init__(self, quant_config): - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - - from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported - - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. 
- w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False - ) - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - - layer.w13_input_scale = None - layer.w2_input_scale = None - - def mxfp4_quantize(self, w): - w_shape = w.shape - w_need_reshape = True if w.dim() != 2 else False - - if w_need_reshape: - w_last_dim_size = w_shape[-1] - w = w.view(-1, w_last_dim_size) - - # log_info_on_rank0(logger, f"[Pre-quant] w.shape: {w.shape}") - w, mx_scales = dynamic_mxfp4_quant(w) - # log_info_on_rank0(logger, f"[Post-quant] w.shape: {w.shape} mx_scales.shape: {mx_scales.shape}") - - if w_need_reshape: - w_new_shape = w_shape[:-1] + (w.shape[-1],) - w = w.view(w_new_shape) - - # log_info_on_rank0(logger, f"[re-shape] w.shape: {w.shape} mx_scales.shape: {mx_scales.shape}") - - mx_scales = e8m0_shuffle(mx_scales) - - return w, mx_scales - - def process_weights_after_loading(self, layer: Module) -> None: - w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data) - w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data) - - layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False) - - layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: - topk_weights, topk_ids, _ = topk_output - - return fused_moe( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - quant_type=QuantType.per_1x32, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - activation=( - ActivationType.Silu - if moe_runner_config.activation == "silu" - else ActivationType.Gelu - ), - doweight_stage1=False, - ) - - -class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod): - - def __init__(self, weight_config: dict[str, Any], input_config: dict[str, Any]): - self.weight_quant = weight_config - self.input_quant = input_config - - weight_qscheme = self.weight_quant.get("qscheme") - input_qscheme = self.input_quant.get("qscheme") - if not (weight_qscheme == "per_group" and input_qscheme == "per_group"): - raise ValueError( - "For MX(FP4) Fused MoE layers, only per-group scales " - "for weights and activations are supported. 
Found " - f"{weight_qscheme=}, {input_qscheme=}" - ) # noqa E501 - - self.static_input_scales = not self.input_quant.get("is_dynamic") - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - - from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported - - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} - ) - - params_dtype = torch.uint8 - - # WEIGHTS - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // 2, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // 2, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - w13_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // OCP_MX_BLOCK_SIZE, - dtype=params_dtype, - ), - requires_grad=False, - ) - w2_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, - hidden_size, - intermediate_size_per_partition // OCP_MX_BLOCK_SIZE, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - - layer.register_parameter("w13_weight_scale", w13_weight_scale) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - float_dtype = torch.get_default_dtype() - - # Pre-shuffle weight scales - s0, s1, _ = layer.w13_weight_scale.shape - w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1) - w13_weight_scale = e8m0_shuffle(w13_weight_scale) - layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1) - - s0, s1, _ = layer.w2_weight_scale.shape - w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1) - w2_weight_scale = e8m0_shuffle(w2_weight_scale) - layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - topk_output: TopKOutput, - moe_runner_config: MoeRunnerConfig, - ) -> torch.Tensor: - topk_weights, topk_ids, _ = topk_output - - return fused_moe( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - quant_type=QuantType.per_1x32, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - activation=( - ActivationType.Silu - if moe_runner_config.activation == "silu" - else ActivationType.Gelu - ), - doweight_stage1=False, - ) - - -class MxFp4KVCacheMethod(BaseKVCacheMethod): - """ - Supports loading kv-cache scaling factors from quark checkpoints. - """ - - def __init__(self, quant_config: MxFp4Config): - self.validate_kv_cache_config(quant_config.kv_cache_config) - super().__init__(quant_config) - - @staticmethod - def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]): - """ - Validator for the kv cache configuration. 
Useful for controlling the - kv cache quantization schemes, that are being supported in vLLM - :param kv_cache_config: the quark kv cache scheme - """ - if kv_cache_config is None: - return - - dtype = kv_cache_config.get("dtype") - if dtype != "fp8_e4m3": - raise NotImplementedError( - "Currently supported kv cache quantization is " - f"dtype=fp8_e4m3, however received {dtype}" - ) - - qscheme = kv_cache_config.get("qscheme") - if qscheme != "per_tensor": - raise NotImplementedError( - "Only support per-tensor scaling factor " - "for quark KV cache. " - f"Expected qscheme: per_tensor, found qscheme: {qscheme}" - ) diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index fedf4c0b0..4cb28d421 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -38,6 +38,7 @@ from sglang.srt.utils import ( is_hip, is_triton_kernels_available, log_info_on_rank0, + mxfp_supported, next_power_of_2, round_up, set_weight_attrs, @@ -61,7 +62,14 @@ if TYPE_CHECKING: from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopKOutput -OCP_MX_BLOCK_SIZE = 32 +_is_hip = is_hip() + +if _is_hip: + # import aiter + from aiter import ActivationType, QuantType, dtypes + from aiter.fused_moe import fused_moe + from aiter.ops.triton.quant import dynamic_mxfp4_quant + from aiter.utility.fp4_utils import e8m0_shuffle def _swizzle_mxfp4(quant_tensor, scale, num_warps): @@ -162,13 +170,34 @@ except AttributeError as error: class Mxfp4Config(QuantizationConfig): - def __init__(self, ignored_layers: Optional[list[str]] = None): + def __init__( + self, + ignored_layers: Optional[list[str]] = None, + is_checkpoint_mxfp4_serialized: bool = False, + ): super().__init__() + self.is_checkpoint_mxfp4_serialized = is_checkpoint_mxfp4_serialized self.ignored_layers = ignored_layers @classmethod def from_config(cls, config): - return cls() + + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_mxfp4_serialized = "mxfp4" in quant_method + + if _is_hip: + if mxfp_supported(): + return cls( + is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized + ) + else: + + platform = torch.cuda.get_device_properties(0).gcnArchName + raise ValueError( + f"Current platform {platform} not support mxfp4 computation" + ) + + return cls(is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized) @classmethod def get_min_capability(cls) -> int: @@ -186,6 +215,9 @@ class Mxfp4Config(QuantizationConfig): def get_config_filenames(cls) -> list[str]: return [] + def is_static_cfg(self): + return self.is_checkpoint_mxfp4_serialized + def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional["QuantizeMethodBase"]: @@ -201,10 +233,16 @@ class Mxfp4Config(QuantizationConfig): fused_mapping=self.packed_modules_mapping, ): return UnquantizedLinearMethod() + elif _is_hip: + return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): - return Mxfp4MoEMethod(prefix) + if self.is_checkpoint_mxfp4_serialized: + return Mxfp4MoEMethod(prefix=prefix) + else: + return Mxfp4DynamicQuantMoEMethod() else: - raise NotImplementedError("Mxfp4 attention layer is not implemented") + if self.is_checkpoint_mxfp4_serialized: + raise NotImplementedError("Mxfp4 attention layer is not implemented") return None def get_scaled_act_names(self) -> List[str]: @@ -655,3 +693,116 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): b1=layer.w13_weight_bias, b2=layer.w2_weight_bias, ) + + 
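+# NOTE: the class below handles checkpoints that are *not* MXFP4-serialized:
+# weights are loaded in their original dtype and quantized once after load.
+# A minimal sketch of that flow, assuming the aiter helpers imported above
+# (dynamic_mxfp4_quant, e8m0_shuffle) keep their current semantics:
+#
+#     w = layer.w13_weight.data              # (E, 2*I, H), bf16/fp16
+#     w = w.view(-1, w.shape[-1])            # collapse to 2-D for the kernel
+#     w_fp4, scales = dynamic_mxfp4_quant(w) # packs two fp4 values per byte
+#     scales = e8m0_shuffle(scales)          # kernel-friendly scale layout
+#
+# The static path (Mxfp4MoEMethod above) instead loads pre-quantized uint8
+# weights and only pre-shuffles the checkpoint's e8m0 scales.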
+class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + + layer.w13_input_scale = None + layer.w2_input_scale = None + + def mxfp4_quantize(self, w): + w_shape = w.shape + w_need_reshape = True if w.dim() != 2 else False + + if w_need_reshape: + w_last_dim_size = w_shape[-1] + w = w.view(-1, w_last_dim_size) + + w, mx_scales = dynamic_mxfp4_quant(w) + + if w_need_reshape: + w_new_shape = w_shape[:-1] + (w.shape[-1],) + w = w.view(w_new_shape) + + mx_scales = e8m0_shuffle(mx_scales) + + return w, mx_scales + + def process_weights_after_loading(self, layer: Module) -> None: + w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data) + w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data) + + layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False) + + layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + topk_output: TopKOutput, + moe_runner_config: MoeRunnerConfig, + ) -> torch.Tensor: + topk_weights, topk_ids, _ = topk_output + + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + quant_type=QuantType.per_1x32, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + activation=( + ActivationType.Silu + if moe_runner_config.activation == "silu" + else ActivationType.Gelu + ), + doweight_stage1=False, + ) diff --git a/python/sglang/srt/layers/quantization/quark/quark.py b/python/sglang/srt/layers/quantization/quark/quark.py new file mode 100644 index 000000000..6d5a66544 --- /dev/null +++ b/python/sglang/srt/layers/quantization/quark/quark.py @@ -0,0 +1,390 @@ +# SPDX-License-Identifier: Apache-2.0 + +import fnmatch +import logging +from typing import Any, List, Optional, cast + +import torch + +from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod +from 
sglang.srt.layers.quantization.base_config import ( # noqa: E501 + LinearMethodBase, + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod +from sglang.srt.layers.quantization.quark.quark_moe import QuarkMoEMethod +from sglang.srt.layers.quantization.quark.schemes import QuarkScheme, QuarkW4A4MXFP4 +from sglang.srt.layers.quantization.quark.utils import deep_compare, should_ignore_layer +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.utils import get_device_capability + +__all__ = ["QuarkLinearMethod"] + +logger = logging.getLogger(__name__) + + +class QuarkConfig(QuantizationConfig): + + def __init__( + self, + quant_config: dict[str, Any], + kv_cache_group: Optional[list[str]] = None, + kv_cache_config: Optional[dict[str, Any]] = None, + pack_method: str = "reorder", + ): + super().__init__() + if kv_cache_group is None: + kv_cache_group = [] + self.quant_config = quant_config + self.kv_cache_group = kv_cache_group + self.kv_cache_config = kv_cache_config + self.pack_method = pack_method + + self.packed_modules_mapping = self.quant_config["packed_modules_mapping"] + + def get_linear_method(self) -> "QuarkLinearMethod": + return QuarkLinearMethod(self) + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def get_name(self) -> str: + return "quark" + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + # Check if the layer is skipped for quantization. + exclude_layers = cast(list[str], self.quant_config.get("exclude")) + if should_ignore_layer( + prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping + ): + return UnquantizedLinearMethod() + + if isinstance(layer, LinearBase): + scheme = self.get_scheme(layer=layer, layer_name=prefix) + layer.scheme = scheme + return QuarkLinearMethod(self) + + if isinstance(layer, RadixAttention): + return QuarkKVCacheMethod(self) + + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE + + if isinstance(layer, FusedMoE): + return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix) + + return None + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "QuarkConfig": + export_config = config.get("export") + if export_config is None: + raise ValueError( + "The export key should be included in " + "the configurations of Quark quantized model" + ) + + kv_cache_group = cast(list[str], export_config.get("kv_cache_group")) + pack_method = cast(str, export_config.get("pack_method")) + + # In the export model of quark, the quantization configuration + # of kv_cache is stored in layer_quant_config. First, it is + # judged whether kv_cache_group exists, and then it is judged + # whether layer_quant_config has a quantization configuration + # that matches kv_cache. + if len(kv_cache_group) == 0: + kv_cache_config = None + else: + kv_cache_set = set(kv_cache_group) + layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config")) + layer_quant_names = list(layer_quant_config.keys()) + layer_quant_set = set(layer_quant_names) + + if not kv_cache_set.issubset(layer_quant_set): + raise ValueError( + "The Quark quantized model has the " + "kv_cache_group parameter setting, " + "but no kv_cache quantization settings " + "were found in the quantization " + "configuration." 
+ ) + + q_configs = [ + cast(dict[str, Any], layer_quant_config.get(name)) + for name in kv_cache_group + ] + if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs): + raise ValueError( + "The quantization method used for kv_cache should " + "be the same, but the quantization method for the " + "kv_cache layer in the config is different." + ) + kv_cache_config = q_configs[0].get("output_tensors") + if kv_cache_config is None: + raise ValueError("The kv_cache quantization configuration is empty.") + + # Since we have already set kv_cache quantization configurations, + # we will remove the quantization configuration for the + # output_tensors corresponding to the kv_cache layer. + for q_config in q_configs: + q_config["output_tensors"] = None + + # In case q_proj output is also quantized, remove the configuration + # to keep qkv consistency. + q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj")) + if q_proj_q_config is not None: + q_proj_q_config["output_tensors"] = None + + return cls( + quant_config=config, + kv_cache_group=kv_cache_group, + kv_cache_config=kv_cache_config, + pack_method=pack_method, + ) + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: + capability_tuple = get_device_capability() + + if capability_tuple is not None: + assert 0 <= capability_tuple[1] < 10 + capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported = capability >= min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + f"the current GPU. Min capability: {min_capability}. ", + f"Current capability: {capability}.", + ) + return supported + else: + return False + + def _is_mx_fp4( + self, + weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]], + ) -> bool: + # Confirm weights and input quantized. + if weight_quant is None or input_quant is None: + logger.debug( + "Quark model is not in MX-FP4 format: " + "weight_quant or input_quant not set" + ) + return False + + # Input and weight dtype needs to be fp4. + if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4": + logger.debug("Quark model is not in MX-FP4 format: dtype not fp4") + return False + + # Input and weight qscheme needs to be per group. + if ( + weight_quant.get("qscheme") != "per_group" + or input_quant.get("qscheme") != "per_group" + ): + logger.debug("Quark model is not in MX-FP4 format: not per_group") + return False + + # Input and weight group size needs to be 32. + if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32: + logger.debug("Quark model is not in MX-FP4 format: not group_size=32") + return False + + # Weights need to use static quantization. + if weight_quant.get("is_dynamic") is True: + logger.debug("Quark model is not in MX-FP4 format: not weight static") + return False + + # Activations need to use dynamic quantization. + if input_quant.get("is_dynamic") is False: + logger.debug("Quark model is not in MX-FP4 format: not activation dynamic") + return False + + # Activations and weight scales need to be in e8m0 format. 
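+        # (An e8m0 scale is a single uint8 holding a shared power-of-two
+        # exponent for each 32-element block, per the OCP Microscaling
+        # spec; it has no sign or mantissa bits, hence the name.)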
+ if ( + weight_quant.get("scale_format") != "e8m0" + or input_quant.get("scale_format") != "e8m0" + ): + logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0") + return False + + return True + + def _find_matched_config( + self, layer_name: str, module: torch.nn.Module + ) -> dict[str, Any]: + + proj_name = layer_name.split(".")[-1] + if proj_name in self.packed_modules_mapping: + shard_proj_names = self.packed_modules_mapping[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + shard_configs = [ + self._find_matched_config(shard_name, module) + for shard_name in shard_names + ] + if not all( + deep_compare(q_config, shard_configs[0]) for q_config in shard_configs + ): + raise ValueError( + f"Found a different quantization configuration for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme." + ) + return shard_configs[0] + else: + layer_quant_config = cast( + dict[str, Any], self.quant_config.get("layer_quant_config") + ) + for name_pattern in layer_quant_config: + if fnmatch.fnmatch(layer_name, name_pattern): + return layer_quant_config[name_pattern] + + layer_type = type(module).__name__ + layer_type_quant_config = cast( + dict[str, Any], self.quant_config.get("layer_type_quant_config") + ) + if layer_type in layer_type_quant_config: + return layer_type_quant_config[layer_type] + + global_quant_config = cast( + dict[str, Any], self.quant_config.get("global_quant_config") + ) + return global_quant_config + + def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": + if config.get("output_tensors") or config.get("bias"): + raise NotImplementedError( + "Currently, Quark models with output_tensors " + "and bias quantized are not supported" + ) + weight_config = cast(dict[str, Any], config.get("weight")) + input_config = cast(dict[str, Any], config.get("input_tensors")) + + if self._is_mx_fp4(weight_config, input_config): + return QuarkW4A4MXFP4(weight_config, input_config) + + raise NotImplementedError( + "No quark compatible scheme was found. " + f"Weight config: {weight_config}, " + f"Input config: {input_config}" + ) + + def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme": + + layer_quant_config = self._find_matched_config(layer_name, layer) + + # Find the quant_scheme + scheme = self._get_scheme_from_config(layer_quant_config) + + # Raise error if device does not support the scheme + # (e.g. fp8 needs ada lovelace) + self._check_scheme_supported(scheme.get_min_capability()) + + return scheme + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class QuarkLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: QuarkConfig): + self.quantization_config = quantization_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.scheme.process_weights_after_loading(layer) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. 
See LinearMethodBase for param + details + """ + weight_loader = extra_weight_attrs.get("weight_loader") + layer.scheme.create_weights( + layer=layer, + input_size=input_size, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader, + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. See LinearMethodBase for param details + + """ + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x, bias=bias) + + +class QuarkKVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from quark checkpoints. + """ + + def __init__(self, quant_config: QuarkConfig): + self.validate_kv_cache_config(quant_config.kv_cache_config) + super().__init__(quant_config) + + @staticmethod + def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]): + """ + Validator for the kv cache configuration. Useful for controlling the + kv cache quantization schemes, that are being supported in vLLM + :param kv_cache_config: the quark kv cache scheme + """ + if kv_cache_config is None: + return + + dtype = kv_cache_config.get("dtype") + if dtype != "fp8_e4m3": + raise NotImplementedError( + "Currently supported kv cache quantization is " + f"dtype=fp8_e4m3, however received {dtype}" + ) + + qscheme = kv_cache_config.get("qscheme") + if qscheme != "per_tensor": + raise NotImplementedError( + "Only support per-tensor scaling factor " + "for quark KV cache. 
" + f"Expected qscheme: per_tensor, found qscheme: {qscheme}" + ) diff --git a/python/sglang/srt/layers/quantization/quark/quark_moe.py b/python/sglang/srt/layers/quantization/quark/quark_moe.py new file mode 100644 index 000000000..194fa414d --- /dev/null +++ b/python/sglang/srt/layers/quantization/quark/quark_moe.py @@ -0,0 +1,197 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Callable, Optional + +import torch +from aiter import ActivationType, QuantType, biased_grouped_topk +from aiter.fused_moe import fused_moe +from aiter.utility.fp4_utils import e8m0_shuffle + +from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs + +logger = logging.getLogger(__name__) + +__all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"] + +OCP_MX_BLOCK_SIZE = 32 + +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + + +class QuarkMoEMethod: + def __new__(cls, *args, **kwargs): + from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + @staticmethod + def get_moe_method( + quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 + module: torch.nn.Module, + layer_name: str, + ) -> "QuarkMoEMethod": + layer_quant_config = quant_config._find_matched_config(layer_name, module) + + if layer_quant_config.get("output_tensors") or layer_quant_config.get("bias"): + raise NotImplementedError( + "Currently, Quark models with " + "output_tensors and bias " + "quantized are not supported" + ) + weight_config = layer_quant_config.get("weight") + input_config = layer_quant_config.get("input_tensors") + + if quant_config._is_mx_fp4(weight_config, input_config): + return QuarkW4A4MXFp4MoEMethod(weight_config, input_config) + else: + raise RuntimeError("Unsupported FusedMoe scheme") + + +class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): + + def __init__(self, weight_config: dict[str, Any], input_config: dict[str, Any]): + self.weight_quant = weight_config + self.input_quant = input_config + + weight_qscheme = self.weight_quant.get("qscheme") + input_qscheme = self.input_quant.get("qscheme") + if not (weight_qscheme == "per_group" and input_qscheme == "per_group"): + raise ValueError( + "For MX(FP4) Fused MoE layers, only per-group scales " + "for weights and activations are supported. 
Found " + f"{weight_qscheme}, {input_qscheme}" + ) # noqa E501 + + self.static_input_scales = not self.input_quant.get("is_dynamic") + self.with_bias = False + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) + + params_dtype = torch.uint8 + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // OCP_MX_BLOCK_SIZE, + dtype=params_dtype, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // OCP_MX_BLOCK_SIZE, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + float_dtype = torch.get_default_dtype() + + # Pre-shuffle weight scales + s0, s1, _ = layer.w13_weight_scale.shape + w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1) + w13_weight_scale = e8m0_shuffle(w13_weight_scale) + # layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, requires_grad=False) + layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1) + + s0, s1, _ = layer.w2_weight_scale.shape + w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1) + w2_weight_scale = e8m0_shuffle(w2_weight_scale) + # layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False) + layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + topk_output: TopKOutput, + moe_runner_config: MoeRunnerConfig, + ) -> torch.Tensor: + topk_weights, topk_ids, _ = topk_output + + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + quant_type=QuantType.per_1x32, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + activation=( + ActivationType.Silu + if moe_runner_config.activation == "silu" + else ActivationType.Gelu + ), + doweight_stage1=False, + ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index fd2bd1580..fdd1f80dd 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -33,6 +33,7 @@ from sglang.srt.utils import ( configure_ipv6, get_device, get_device_memory_capacity, + is_cuda, is_flashinfer_available, 
is_hip, is_port_available, @@ -2165,9 +2166,9 @@ class ServerArgs: model_arch = hf_config.architectures[0] if model_arch in ["GptOssForCausalLM"]: if self.attention_backend is None: - if is_sm100_supported(): + if is_cuda() and is_sm100_supported(): self.attention_backend = "trtllm_mha" - elif is_sm90_supported(): + elif is_cuda() and is_sm90_supported(): self.attention_backend = "fa3" else: self.attention_backend = "triton"
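
Usage sketch (assuming the standard launch entry point; --model-path and
--quantization are existing server args, nothing new is added here):

    python -m sglang.launch_server --model-path <model> --quantization mxfp4

With this change, Mxfp4Config.get_quant_method returns Mxfp4MoEMethod for
MXFP4-serialized checkpoints and the new Mxfp4DynamicQuantMoEMethod (load-time
quantization) otherwise, while checkpoints with quant_method="quark" dispatch
to the new QuarkConfig in quantization/quark/quark.py instead of the removed
fp4.py.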