From d4bf5a8524820d3a232f7fc8e349d5e7d0d2880d Mon Sep 17 00:00:00 2001 From: kk <43161300+kkHuang-amd@users.noreply.github.com> Date: Tue, 5 Aug 2025 09:14:52 +0800 Subject: [PATCH] Support OCP MXFP4 quantization on AMD GPUs (#8255) Co-authored-by: wunhuang Co-authored-by: Hubert Lu --- python/sglang/srt/configs/model_config.py | 2 + .../srt/layers/quantization/__init__.py | 14 +- python/sglang/srt/layers/quantization/fp4.py | 822 ++++++++++++++++++ .../srt/layers/quantization/quark/__init__.py | 0 .../quantization/quark/schemes/__init__.py | 6 + .../quark/schemes/quark_scheme.py | 55 ++ .../quark/schemes/quark_w4a4_mxfp4.py | 118 +++ .../srt/layers/quantization/quark/utils.py | 107 +++ .../sglang/srt/model_loader/weight_utils.py | 10 + python/sglang/srt/models/deepseek_v2.py | 14 + python/sglang/srt/server_args.py | 1 + python/sglang/srt/utils.py | 11 + 12 files changed, 1159 insertions(+), 1 deletion(-) create mode 100644 python/sglang/srt/layers/quantization/fp4.py create mode 100644 python/sglang/srt/layers/quantization/quark/__init__.py create mode 100644 python/sglang/srt/layers/quantization/quark/schemes/__init__.py create mode 100644 python/sglang/srt/layers/quantization/quark/schemes/quark_scheme.py create mode 100644 python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py create mode 100644 python/sglang/srt/layers/quantization/quark/utils.py diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index f3643d154..3091ed4fe 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -401,6 +401,8 @@ class ModelConfig: "fbgemm_fp8", "w8a8_fp8", "petit_nvfp4", + "quark", + "mxfp4", ] optimized_quantization_methods = [ "fp8", diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 496cbc8f5..455e8ac8f 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -47,6 +47,12 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( CompressedTensorsConfig, ) +from sglang.srt.utils import mxfp_supported + +is_mxfp_supported = mxfp_supported() +if is_mxfp_supported: + from sglang.srt.layers.quantization.fp4 import MxFp4Config + from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.gptq import ( GPTQConfig, @@ -84,7 +90,13 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "w4afp8": W4AFp8Config, "petit_nvfp4": PetitNvFp4Config, } - +if is_mxfp_supported: + BASE_QUANTIZATION_METHODS.update( + { + "quark": MxFp4Config, + "mxfp4": MxFp4Config, + } + ) # VLLM-dependent quantization methods VLLM_QUANTIZATION_METHODS = { "aqlm": AQLMConfig, diff --git a/python/sglang/srt/layers/quantization/fp4.py b/python/sglang/srt/layers/quantization/fp4.py new file mode 100644 index 000000000..ad40ed142 --- /dev/null +++ b/python/sglang/srt/layers/quantization/fp4.py @@ -0,0 +1,822 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import fnmatch +import logging +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast + +import aiter +import torch +import torch.nn.functional as F +from aiter import ActivationType, QuantType, dtypes +from aiter.fused_moe import fused_moe +from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages +from aiter.ops.gemm_op_a4w4 
import gemm_a4w4 +from aiter.ops.quant import get_torch_quant +from aiter.ops.shuffle import shuffle_weight +from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 +from aiter.ops.triton.quant import dynamic_mxfp4_quant +from aiter.utility.fp4_utils import e8m0_shuffle +from torch.nn import Module + +from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod +from sglang.srt.layers.parameter import ModelWeightParameter +from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod +from sglang.srt.layers.quantization.quark.schemes import QuarkScheme, QuarkW4A4MXFP4 +from sglang.srt.layers.quantization.quark.utils import deep_compare, should_ignore_layer +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.utils import ( + get_bool_env_var, + get_device_capability, + log_info_on_rank0, + mxfp_supported, + set_weight_attrs, +) + +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + +logger = logging.getLogger(__name__) + +use_dynamic_mxfp4_linear = get_bool_env_var("SGLANG_USE_DYNAMIC_MXFP4_linear") + +OCP_MX_BLOCK_SIZE = 32 + + +class MxFp4Config(QuantizationConfig): + + def __init__( + self, + is_checkpoint_fp4_serialized: bool = False, + quant_config: dict[str, Any] = None, + kv_cache_group: Optional[list[str]] = None, + kv_cache_config: Optional[dict[str, Any]] = None, + pack_method: str = "reorder", + ignored_layers: Optional[List[str]] = None, + ): + super().__init__() + if kv_cache_group is None: + kv_cache_group = [] + + self.is_checkpoint_fp4_serialized = is_checkpoint_fp4_serialized + self.quant_config = quant_config + self.kv_cache_group = kv_cache_group + self.kv_cache_config = kv_cache_config + self.pack_method = pack_method + + self.packed_modules_mapping = ( + self.quant_config["packed_modules_mapping"] + if is_checkpoint_fp4_serialized + else None + ) + + self.ignored_layers = ignored_layers or [] + + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def get_name(self) -> str: + return "fp4" + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + + # Check if the layer is skipped for quantization. 
+ if len(self.ignored_layers) > 0 and should_ignore_layer( + prefix, + ignore=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedLinearMethod() + + if isinstance(layer, LinearBase): + if self.is_checkpoint_fp4_serialized: + scheme = self.get_scheme(layer=layer, layer_name=prefix) + layer.scheme = scheme + return MxFp4LinearMethod(self) + + elif use_dynamic_mxfp4_linear: + return MxFp4LinearMethod(self) + else: + return UnquantizedLinearMethod() + + if isinstance(layer, RadixAttention): + return MxFp4KVCacheMethod(self) + + if isinstance(layer, FusedMoE): + return MxFp4MoEMethod.get_moe_method(self, module=layer, layer_name=prefix) + + return None + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "MxFp4Config": + if not mxfp_supported(): + platform = torch.cuda.get_device_properties(0).gcnArchName + raise ValueError( + f"Current platform {platform} not support mxfp4 computation" + ) + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp4_serialized = ( + True if quant_method else False + ) # "quark" in quant_method + + kv_cache_group = [] + pack_method = None + + if is_checkpoint_fp4_serialized: + export_config = config.get("export") + if export_config is None: + raise ValueError( + "The export key should be included in " + "the configurations of Quark quantized model" + ) + + kv_cache_group = cast(list[str], export_config.get("kv_cache_group")) + pack_method = cast(str, export_config.get("pack_method")) + + # In the export model of quark, the quantization configuration + # of kv_cache is stored in layer_quant_config. First, it is + # judged whether kv_cache_group exists, and then it is judged + # whether layer_quant_config has a quantization configuration + # that matches kv_cache. + if len(kv_cache_group) == 0: + kv_cache_config = None + else: + kv_cache_set = set(kv_cache_group) + layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config")) + layer_quant_names = list(layer_quant_config.keys()) + layer_quant_set = set(layer_quant_names) + + if not kv_cache_set.issubset(layer_quant_set): + raise ValueError( + "The Quark quantized model has the " + "kv_cache_group parameter setting, " + "but no kv_cache quantization settings " + "were found in the quantization " + "configuration." + ) + + q_configs = [ + cast(dict[str, Any], layer_quant_config.get(name)) + for name in kv_cache_group + ] + if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs): + raise ValueError( + "The quantization method used for kv_cache should " + "be the same, but the quantization method for the " + "kv_cache layer in the config is different." + ) + kv_cache_config = q_configs[0].get("output_tensors") + if kv_cache_config is None: + raise ValueError("The kv_cache quantization configuration is empty.") + + # Since we have already set kv_cache quantization configurations, + # we will remove the quantization configuration for the + # output_tensors corresponding to the kv_cache layer. + for q_config in q_configs: + q_config["output_tensors"] = None + + # In case q_proj output is also quantized, remove the configuration + # to keep qkv consistency. 
+ q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj")) + if q_proj_q_config is not None: + q_proj_q_config["output_tensors"] = None + + ignored_layers = cls.get_from_keys_or(config, ["exclude"], None) + + return cls( + is_checkpoint_fp4_serialized=is_checkpoint_fp4_serialized, + quant_config=config, + kv_cache_group=kv_cache_group, + kv_cache_config=kv_cache_config, + pack_method=pack_method, + ignored_layers=ignored_layers, + ) + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: + capability_tuple = get_device_capability() + + if capability_tuple is not None: + assert 0 <= capability_tuple[1] < 10 + capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported = capability >= min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + f"the current GPU. Min capability: {min_capability}. ", + f"Current capability: {capability}.", + ) + return supported + else: + return False + + def _is_mx_fp4( + self, + weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]], + ) -> bool: + # Confirm weights and input quantized. + if weight_quant is None or input_quant is None: + logger.debug( + "Quark model is not in MX-FP4 format: " + "weight_quant or input_quant not set" + ) + return False + + # Input and weight dtype needs to be fp4. + if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4": + logger.debug("Quark model is not in MX-FP4 format: dtype not fp4") + return False + + # Input and weight qscheme needs to be per group. + if ( + weight_quant.get("qscheme") != "per_group" + or input_quant.get("qscheme") != "per_group" + ): + logger.debug("Quark model is not in MX-FP4 format: not per_group") + return False + + # Input and weight group size needs to be 32. + if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32: + logger.debug("Quark model is not in MX-FP4 format: not group_size=32") + return False + + # Weights need to use static quantization. + if weight_quant.get("is_dynamic") is True: + logger.debug("Quark model is not in MX-FP4 format: not weight static") + return False + + # Activations need to use dynamic quantization. + if input_quant.get("is_dynamic") is False: + logger.debug("Quark model is not in MX-FP4 format: not activation dynamic") + return False + + # Activations and weight scales need to be in e8m0 format. + if ( + weight_quant.get("scale_format") != "e8m0" + or input_quant.get("scale_format") != "e8m0" + ): + logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0") + return False + + return True + + def _find_matched_config( + self, layer_name: str, module: torch.nn.Module + ) -> dict[str, Any]: + + proj_name = layer_name.split(".")[-1] + if proj_name in self.packed_modules_mapping: + shard_proj_names = self.packed_modules_mapping[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + shard_configs = [ + self._find_matched_config(shard_name, module) + for shard_name in shard_names + ] + if not all( + deep_compare(q_config, shard_configs[0]) for q_config in shard_configs + ): + raise ValueError( + f"Found a different quantization configuration for " + f"{shard_proj_names=} in {layer_name=}. vLLM " + "requires all to use the same scheme." 
+ ) + return shard_configs[0] + else: + layer_quant_config = cast( + dict[str, Any], self.quant_config.get("layer_quant_config") + ) + for name_pattern in layer_quant_config: + if fnmatch.fnmatch(layer_name, name_pattern): + return layer_quant_config[name_pattern] + + layer_type = cast(str, type(module)) + layer_type_quant_config = cast( + dict[str, Any], self.quant_config.get("layer_type_quant_config") + ) + if layer_type in layer_type_quant_config: + return layer_type_quant_config[layer_type] + + global_quant_config = cast( + dict[str, Any], self.quant_config.get("global_quant_config") + ) + return global_quant_config + + def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": + if config.get("output_tensors") or config.get("bias"): + raise NotImplementedError( + "Currently, Quark models with output_tensors " + "and bias quantized are not supported" + ) + weight_config = cast(dict[str, Any], config.get("weight")) + input_config = cast(dict[str, Any], config.get("input_tensors")) + + if self._is_mx_fp4(weight_config, input_config): + return QuarkW4A4MXFP4(weight_config, input_config) + + raise NotImplementedError( + "No quark compatible scheme was found. " + f"{weight_config=}, " + f"{input_config=}" + ) + + def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme": + + layer_quant_config = self._find_matched_config(layer_name, layer) + + # Find the quant_scheme + scheme = self._get_scheme_from_config(layer_quant_config) + + # Raise error if device does not support the scheme + # (e.g. fp8 needs ada lovelace) + self._check_scheme_supported(scheme.get_min_capability()) + + return scheme + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class MxFp4LinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: MxFp4Config): + self.quantization_config = quantization_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + return + # if self.quantization_config.is_checkpoint_fp4_serialized: + # layer.scheme.process_weights_after_loading(layer) + # else: + # #w, w_scales = dynamic_mxfp4_quant(layer.weight.data) + # ##log_info_on_rank0(logger, f"w.shape: {w.shape}") + + # #wshuffle = w#shuffle_weight(w, layout=(16, 16)) + # #w_scales_shuffle = w_scales#e8m0_shuffle(w_scales).view(dtypes.fp8_e8m0) + + # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32) + + # w, w_scales_shuffle = quant_func(layer.weight.data, shuffle=True) + + # wshuffle = shuffle_weight(w, layout=(16, 16)) + + # layer.weight = torch.nn.Parameter(wshuffle, + # requires_grad=False) + # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle, + # requires_grad=False) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. 
See LinearMethodBase for param + details + """ + weight_loader = extra_weight_attrs.get("weight_loader") + + if self.quantization_config.is_checkpoint_fp4_serialized: + layer.scheme.create_weights( + layer=layer, + input_size=input_size, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader, + ) + else: + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + weight_dtype = params_dtype + + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("weight", weight) + layer.register_parameter("weight_scale", None) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. See LinearMethodBase for param details + + """ + if self.quantization_config.is_checkpoint_fp4_serialized: + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x, bias=bias) + else: + out_dtype = x.dtype + + # ck or asm implement + # M = x.shape[0] + # N = layer.weight.shape[0] + + # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32) + + # x, x_scales_shuffle = quant_func(x, shuffle=True) + + # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=out_dtype) + + # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias) + + # return out[:M] + + # triton implement + x_q, x_s = dynamic_mxfp4_quant(x) + y = torch.empty( + x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype + ) + + out = gemm_afp4wfp4( + x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y + ) + + return out + + +class MxFp4MoEMethod: + def __new__(cls, *args, **kwargs): + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + @staticmethod + def get_moe_method( + quant_config: "MxFp4Config", # type: ignore # noqa E501 # noqa F821 + module: torch.nn.Module, + layer_name: str, + ) -> "MxFp4MoEMethod": + + if quant_config.is_checkpoint_fp4_serialized: + layer_quant_config = quant_config._find_matched_config(layer_name, module) + + if layer_quant_config.get("output_tensors") or layer_quant_config.get( + "bias" + ): + raise NotImplementedError( + "Currently, Quark models with " + "output_tensors and bias " + "quantized are not supported" + ) + weight_config = layer_quant_config.get("weight") + input_config = layer_quant_config.get("input_tensors") + + if quant_config._is_mx_fp4(weight_config, input_config): + return W4A4MXFp4MoEStaticMethod(weight_config, input_config) + else: + raise RuntimeError("Unsupported FusedMoe scheme") + else: + return W4A4MXFp4MoEDynamicMethod(quant_config) + 
+ +class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod): + def __init__(self, quant_config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + + layer.w13_input_scale = None + layer.w2_input_scale = None + + def mxfp4_quantize(self, w): + w_shape = w.shape + w_need_reshape = True if w.dim() != 2 else False + + if w_need_reshape: + w_last_dim_size = w_shape[-1] + w = w.view(-1, w_last_dim_size) + + # log_info_on_rank0(logger, f"[Pre-quant] w.shape: {w.shape}") + w, mx_scales = dynamic_mxfp4_quant(w) + # log_info_on_rank0(logger, f"[Post-quant] w.shape: {w.shape} mx_scales.shape: {mx_scales.shape}") + + if w_need_reshape: + w_new_shape = w_shape[:-1] + (w.shape[-1],) + w = w.view(w_new_shape) + + # log_info_on_rank0(logger, f"[re-shape] w.shape: {w.shape} mx_scales.shape: {mx_scales.shape}") + + mx_scales = e8m0_shuffle(mx_scales) + + return w, mx_scales + + def process_weights_after_loading(self, layer: Module) -> None: + w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data) + w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data) + + layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False) + + layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + topk_output: TopKOutput, + *, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, + ) -> torch.Tensor: + topk_weights, topk_ids, _ = topk_output + + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + quant_type=QuantType.per_1x32, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + activation=( + ActivationType.Silu if activation == "silu" else ActivationType.Gelu + ), + doweight_stage1=False, + ) + + +class 
W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod): + + def __init__(self, weight_config: dict[str, Any], input_config: dict[str, Any]): + self.weight_quant = weight_config + self.input_quant = input_config + + weight_qscheme = self.weight_quant.get("qscheme") + input_qscheme = self.input_quant.get("qscheme") + if not (weight_qscheme == "per_group" and input_qscheme == "per_group"): + raise ValueError( + "For MX(FP4) Fused MoE layers, only per-group scales " + "for weights and activations are supported. Found " + f"{weight_qscheme=}, {input_qscheme=}" + ) # noqa E501 + + self.static_input_scales = not self.input_quant.get("is_dynamic") + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) + + params_dtype = torch.uint8 + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // OCP_MX_BLOCK_SIZE, + dtype=params_dtype, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // OCP_MX_BLOCK_SIZE, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + float_dtype = torch.get_default_dtype() + + # Pre-shuffle weight scales + s0, s1, _ = layer.w13_weight_scale.shape + w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1) + w13_weight_scale = e8m0_shuffle(w13_weight_scale) + layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1) + + s0, s1, _ = layer.w2_weight_scale.shape + w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1) + w2_weight_scale = e8m0_shuffle(w2_weight_scale) + layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + topk_output: TopKOutput, + *, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, + ) -> torch.Tensor: + topk_weights, topk_ids, _ = topk_output + + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + quant_type=QuantType.per_1x32, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + activation=( + 
ActivationType.Silu if activation == "silu" else ActivationType.Gelu + ), + doweight_stage1=False, + ) + + +class MxFp4KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from quark checkpoints. + """ + + def __init__(self, quant_config: MxFp4Config): + self.validate_kv_cache_config(quant_config.kv_cache_config) + super().__init__(quant_config) + + @staticmethod + def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]): + """ + Validator for the kv cache configuration. Useful for controlling the + kv cache quantization schemes, that are being supported in vLLM + :param kv_cache_config: the quark kv cache scheme + """ + if kv_cache_config is None: + return + + dtype = kv_cache_config.get("dtype") + if dtype != "fp8_e4m3": + raise NotImplementedError( + "Currently supported kv cache quantization is " + f"dtype=fp8_e4m3, however received {dtype}" + ) + + qscheme = kv_cache_config.get("qscheme") + if qscheme != "per_tensor": + raise NotImplementedError( + "Only support per-tensor scaling factor " + "for quark KV cache. " + f"Expected qscheme: per_tensor, found qscheme: {qscheme}" + ) diff --git a/python/sglang/srt/layers/quantization/quark/__init__.py b/python/sglang/srt/layers/quantization/quark/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/sglang/srt/layers/quantization/quark/schemes/__init__.py b/python/sglang/srt/layers/quantization/quark/schemes/__init__.py new file mode 100644 index 000000000..91b08c512 --- /dev/null +++ b/python/sglang/srt/layers/quantization/quark/schemes/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .quark_scheme import QuarkScheme +from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4 + +__all__ = ["QuarkScheme", "QuarkW4A4MXFP4"] diff --git a/python/sglang/srt/layers/quantization/quark/schemes/quark_scheme.py b/python/sglang/srt/layers/quantization/quark/schemes/quark_scheme.py new file mode 100644 index 000000000..aab5c9c1e --- /dev/null +++ b/python/sglang/srt/layers/quantization/quark/schemes/quark_scheme.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +__all__ = ["QuarkScheme"] + + +class QuarkScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by Quark. + """ + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + """ + Get minimum device capability. + """ + raise NotImplementedError + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + :param bias: bias parameter + + """ + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + """ + Called after weight loading is complete for any cleanup that + needs to occur. 
+ """ + raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py new file mode 100644 index 000000000..e5fc22797 --- /dev/null +++ b/python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Optional + +import aiter +import torch +import torch.nn.functional as F +from aiter.ops.gemm_op_a4w4 import gemm_a4w4 +from aiter.ops.shuffle import shuffle_weight +from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 +from aiter.ops.triton.quant import dynamic_mxfp4_quant +from aiter.utility import dtypes +from aiter.utility.fp4_utils import e8m0_shuffle + +from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter +from sglang.srt.layers.quantization.quark.schemes import QuarkScheme +from sglang.srt.utils import get_bool_env_var + +__all__ = ["QuarkW4A4MXFP4"] + +OCP_MX_BLOCK_SIZE = 32 + + +class QuarkW4A4MXFP4(QuarkScheme): + + def __init__( + self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any] + ): + self.out_dtype = torch.get_default_dtype() + self.qscheme = "per_group" + self.weight_quant_spec = weight_quant_spec + self.input_quant_spec = input_quant_spec + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + return + + # for aiter implement + # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16)) + # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0) + + # layer.weight = torch.nn.Parameter(wshuffle, + # requires_grad=False) + # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle, + # requires_grad=False) + + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs + ): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = PackedvLLMParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=2, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // OCP_MX_BLOCK_SIZE, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + out_dtype = x.dtype + # M = x.shape[0] + # N = layer.weight.shape[0] + + # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32) + # x, x_scales_shuffle = quant_func(x, shuffle=True) + + # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype) + + # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias) + + # return out[:M] + + # triton implement + x_q, x_s = dynamic_mxfp4_quant(x) + y = torch.empty( + x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype + ) + + out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y) + + return out 
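Note: the packed layout that QuarkW4A4MXFP4.create_weights registers above follows the OCP MX convention used throughout this patch: two FP4 values packed per uint8 byte along the input dimension, and one shared e8m0 exponent byte per 32-element block (OCP_MX_BLOCK_SIZE). A minimal shape sketch, not part of the patch, with purely illustrative partition sizes:

# Illustrative shapes for the packed MXFP4 weight and its per-group e8m0 scale.
OCP_MX_BLOCK_SIZE = 32                      # one shared exponent per 32 FP4 values
out_features, in_features = 4096, 4096      # example partition sizes, not from the patch

weight_shape = (out_features, in_features // 2)                   # uint8, two FP4 values per byte
scale_shape = (out_features, in_features // OCP_MX_BLOCK_SIZE)    # uint8, one e8m0 scale per block
print(weight_shape, scale_shape)            # (4096, 2048) (4096, 128)
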
diff --git a/python/sglang/srt/layers/quantization/quark/utils.py b/python/sglang/srt/layers/quantization/quark/utils.py new file mode 100644 index 000000000..5ea91b5d8 --- /dev/null +++ b/python/sglang/srt/layers/quantization/quark/utils.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from collections.abc import Iterable, Mapping +from types import MappingProxyType +from typing import Any, Optional + + +def deep_compare(dict1: Any, dict2: Any) -> bool: + if type(dict1) is not type(dict2): + return False + if isinstance(dict1, dict): + if dict1.keys() != dict2.keys(): + return False + return all(deep_compare(dict1[k], dict2[k]) for k in dict1) + elif isinstance(dict1, list): + return set(dict1) == set(dict2) + else: + return dict1 == dict2 + + +def should_ignore_layer( + layer_name: Optional[str], + ignore: Iterable[str], + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), +) -> bool: + if layer_name is None: + return False + + # layer_name = model.layers.0.self_attn.qkv_proj + # proj_name = qkv_proj + proj_name = layer_name.split(".")[-1] + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in fused_mapping: + shard_proj_names = fused_mapping[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + + # Layer should be ignored if shards are ignored. + should_ignore_layer = None + for shard_name in shard_names: + should_ignore_shard = check_equal_or_regex_match( + layer_name=shard_name, targets=ignore + ) + + # If shard_idx=0, set layer ignore to match shard. + if should_ignore_layer is None: + should_ignore_layer = should_ignore_shard + + # If shard_idx=1+ confirm scheme matches prior shards. + elif should_ignore_shard != should_ignore_layer: + raise ValueError( + f"Found a different quantization schemes for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme." + ) + + # Unfused layers like down_proj and o_proj will match + # the safetensors checkpoint already. + else: + should_ignore_layer = check_equal_or_regex_match( + layer_name=layer_name, targets=ignore + ) + + assert should_ignore_layer is not None + + return should_ignore_layer + + +def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool: + """ + Checks whether a layer_name is exactly equal or a regex match for + if target starts with 're:' to any target in list. + """ + for target in targets: + if _is_equal_or_regex_match(layer_name, target): + return True + return False + + +def _is_equal_or_regex_match( + value: str, target: str, check_contains: bool = False +) -> bool: + """ + Checks whether a value is exactly equal or a regex match for target + if target starts with 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. 
+ """ + + if target.startswith("re:"): + pattern = target[3:] + if re.match(pattern, value): + return True + elif check_contains: + if target.lower() in value.lower(): + return True + elif target == value: + return True + return False diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 33f11b8af..a326e3f10 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -843,6 +843,16 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: return None return remapped_name + quark_scale_names = { + ".q_proj.output_scale": ".attn.q_scale", + ".k_proj.output_scale": ".attn.k_scale", + ".v_proj.output_scale": ".attn.v_scale", + "self_attn.prob_output_scale": ".attn.prob_scale", + } + for quark_scale_name, sglang_scale_name in quark_scale_names.items(): + if name.endswith(quark_scale_name): + return name.replace(quark_scale_name, sglang_scale_name) + # If there were no matches, return the untouched param name return name diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 009f926bf..913764b45 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -2061,6 +2061,8 @@ class DeepseekV2Model(nn.Module): class DeepseekV2ForCausalLM(nn.Module): + # for quark model load + packed_modules_mapping = {} def __init__( self, @@ -2069,6 +2071,18 @@ class DeepseekV2ForCausalLM(nn.Module): prefix: str = "", ) -> None: super().__init__() + + # for quark model load + # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None + self.fuse_qkv_a_proj = ( + hasattr(config, "q_lora_rank") and config.q_lora_rank is not None + ) + if self.fuse_qkv_a_proj: + self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [ + "q_a_proj", + "kv_a_proj_with_mqa", + ] + self.config = config self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index aacaaf1cd..60d8efb9e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -813,6 +813,7 @@ class ServerArgs: "moe_wna16", "qoq", "w4afp8", + "mxfp4", ], help="The quantization method.", ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index db841b3fd..055c0b115 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2832,6 +2832,17 @@ def parse_module_path(module_path, function_name, create_dummy): return final_module, None +def mxfp_supported(): + """ + Returns whether the current platform supports MX types. + """ + if torch.version.hip: + gcn_arch = torch.cuda.get_device_properties(0).gcnArchName + return any(gfx in gcn_arch for gfx in ["gfx95"]) + else: + return False + + # LoRA-related constants and utilities SUPPORTED_LORA_TARGET_MODULES = [ "q_proj",
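
For reference, the Triton code path taken by MxFp4LinearMethod.apply and QuarkW4A4MXFP4.apply_weights above reduces to: quantize the bf16 activation to MXFP4 on the fly with dynamic_mxfp4_quant, then run the FP4-activation x FP4-weight GEMM with gemm_afp4wfp4. The sketch below shows that round trip in isolation; it is a minimal sketch under assumptions, not part of the patch. The weight is quantized here with dynamic_mxfp4_quant purely for illustration (in the Quark-serialized path the packed FP4 weight and e8m0 scales are loaded from the checkpoint), and the tensor sizes are arbitrary.

import torch

from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from sglang.srt.utils import mxfp_supported

# The MXFP4 kernels in this patch target MX-capable AMD GPUs (gfx95*).
assert mxfp_supported(), "current GPU does not support MXFP4 computation"

M, N, K = 32, 4096, 4096   # arbitrary example sizes; K is a multiple of the 32-wide MX block
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# Per-1x32 dynamic quantization: FP4 values packed two per uint8, one e8m0 scale per 32-wide block.
x_q, x_s = dynamic_mxfp4_quant(x)   # x_q: [M, K // 2] uint8, x_s: [M, K // 32]
w_q, w_s = dynamic_mxfp4_quant(w)   # w_q: [N, K // 2] uint8, w_s: [N, K // 32]

# FP4 x FP4 GEMM accumulated back to the activation dtype, same argument order as apply() above.
y = torch.empty(M, N, dtype=x.dtype, device=x.device)
out = gemm_afp4wfp4(x_q, w_q, x_s, w_s, x.dtype, y)   # out: [M, N] bf16

With a Quark-exported MXFP4 checkpoint this path is selected automatically once the server is launched with --quantization mxfp4, the choice added to server_args.py above; checkpoints whose quantization config uses the quark method map to the same MxFp4Config.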