Combine fp4.py and mxfp4.py into one file and support dynamic mxfp4 quantization in mxfp4.py (#9049)

Co-authored-by: wunhuang <wunhuang@amd.com>
This commit is contained in:
kk
2025-08-17 10:01:54 +08:00
committed by GitHub
parent 384f8ab5ce
commit 1c1f8a118e
7 changed files with 760 additions and 557 deletions

View File

@@ -474,6 +474,7 @@ class FusedMoE(torch.nn.Module):
not expert_id
and self.quant_config is not None
and self.quant_config.get_name() == "mxfp4"
and self.quant_config.is_static_cfg()
):
if "bias" in weight_name:
dim1 = loaded_weight.shape[1]
@@ -724,7 +725,11 @@ class FusedMoE(torch.nn.Module):
) -> None:
tp_rank = self.moe_tp_rank
if self.quant_config is not None and self.quant_config.get_name() == "mxfp4":
if (
self.quant_config is not None
and self.quant_config.get_name() == "mxfp4"
and self.quant_config.is_static_cfg()
):
if "bias" in weight_name:
dim1 = loaded_weight.shape[1]
param.data[:, :dim1].copy_(loaded_weight)

View File

@@ -48,12 +48,6 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
is_mxfp_supported = mxfp_supported()
if is_mxfp_supported:
from sglang.srt.layers.quantization.fp4 import MxFp4Config
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import (
@@ -67,6 +61,9 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
_is_mxfp_supported = mxfp_supported()
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
@@ -98,11 +95,13 @@ if is_cuda():
"mxfp4": Mxfp4Config,
}
)
elif is_mxfp_supported and is_hip():
elif _is_mxfp_supported and is_hip():
from sglang.srt.layers.quantization.quark.quark import QuarkConfig
BASE_QUANTIZATION_METHODS.update(
{
"quark": MxFp4Config,
"mxfp4": MxFp4Config,
"quark": QuarkConfig,
"mxfp4": Mxfp4Config,
}
)
# VLLM-dependent quantization methods

View File

@@ -1,540 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import fnmatch
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast
import aiter
import torch
import torch.nn.functional as F
from aiter import ActivationType, QuantType, dtypes
from aiter.fused_moe import fused_moe
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
from aiter.ops.quant import get_torch_quant
from aiter.ops.shuffle import shuffle_weight
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from aiter.utility.fp4_utils import e8m0_shuffle
from torch.nn import Module
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.parameter import ModelWeightParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.quark.schemes import QuarkScheme, QuarkW4A4MXFP4
from sglang.srt.layers.quantization.quark.utils import deep_compare, should_ignore_layer
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import (
get_bool_env_var,
get_device_capability,
log_info_on_rank0,
mxfp_supported,
set_weight_attrs,
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
logger = logging.getLogger(__name__)
use_dynamic_mxfp4_linear = get_bool_env_var("SGLANG_USE_DYNAMIC_MXFP4_linear")
OCP_MX_BLOCK_SIZE = 32
class Mxfp4Config(QuantizationConfig):
def __init__(self, ignored_layers: Optional[list[str]] = None):
super().__init__()
self.ignored_layers = ignored_layers
@classmethod
def from_config(cls, config):
return cls()
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_name(cls) -> QuantizationMethods:
return "mxfp4"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_config_filenames(cls) -> list[str]:
return []
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
raise NotImplementedError("Mxfp4 linear layer is not implemented")
elif isinstance(layer, FusedMoE):
return Mxfp4MoEMethod(layer.moe_config)
elif isinstance(layer, Attention):
raise NotImplementedError("Mxfp4 attention layer is not implemented")
return None
class MxFp4LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: MxFp4Config):
self.quantization_config = quantization_config
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
return
# if self.quantization_config.is_checkpoint_fp4_serialized:
# layer.scheme.process_weights_after_loading(layer)
# else:
# #w, w_scales = dynamic_mxfp4_quant(layer.weight.data)
# ##log_info_on_rank0(logger, f"w.shape: {w.shape}")
# #wshuffle = w#shuffle_weight(w, layout=(16, 16))
# #w_scales_shuffle = w_scales#e8m0_shuffle(w_scales).view(dtypes.fp8_e8m0)
# quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
# w, w_scales_shuffle = quant_func(layer.weight.data, shuffle=True)
# wshuffle = shuffle_weight(w, layout=(16, 16))
# layer.weight = torch.nn.Parameter(wshuffle,
# requires_grad=False)
# layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
# requires_grad=False)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
"""
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param
details
"""
weight_loader = extra_weight_attrs.get("weight_loader")
if self.quantization_config.is_checkpoint_fp4_serialized:
layer.scheme.create_weights(
layer=layer,
input_size=input_size,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
output_size=output_size,
params_dtype=params_dtype,
weight_loader=weight_loader,
)
else:
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
weight_dtype = params_dtype
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
layer.register_parameter("weight_scale", None)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
layer input. See LinearMethodBase for param details
"""
if self.quantization_config.is_checkpoint_fp4_serialized:
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias)
else:
out_dtype = x.dtype
# ck or asm implement
# M = x.shape[0]
# N = layer.weight.shape[0]
# quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
# x, x_scales_shuffle = quant_func(x, shuffle=True)
# y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=out_dtype)
# out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
# return out[:M]
# triton implement
x_q, x_s = dynamic_mxfp4_quant(x)
y = torch.empty(
x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
)
out = gemm_afp4wfp4(
x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y
)
return out
class MxFp4MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: Mxfp4Config):
self.quant_config = quant_config
@staticmethod
def get_moe_method(
quant_config: "MxFp4Config", # type: ignore # noqa E501 # noqa F821
module: torch.nn.Module,
layer_name: str,
) -> "MxFp4MoEMethod":
if quant_config.is_checkpoint_fp4_serialized:
layer_quant_config = quant_config._find_matched_config(layer_name, module)
if layer_quant_config.get("output_tensors") or layer_quant_config.get(
"bias"
):
raise NotImplementedError(
"Currently, Quark models with "
"output_tensors and bias "
"quantized are not supported"
)
weight_config = layer_quant_config.get("weight")
input_config = layer_quant_config.get("input_tensors")
if quant_config._is_mx_fp4(weight_config, input_config):
return W4A4MXFp4MoEStaticMethod(weight_config, input_config)
else:
raise RuntimeError("Unsupported FusedMoe scheme")
else:
return W4A4MXFp4MoEDynamicMethod(quant_config)
class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
def __init__(self, quant_config):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
layer.w13_input_scale = None
layer.w2_input_scale = None
def mxfp4_quantize(self, w):
w_shape = w.shape
w_need_reshape = True if w.dim() != 2 else False
if w_need_reshape:
w_last_dim_size = w_shape[-1]
w = w.view(-1, w_last_dim_size)
# log_info_on_rank0(logger, f"[Pre-quant] w.shape: {w.shape}")
w, mx_scales = dynamic_mxfp4_quant(w)
# log_info_on_rank0(logger, f"[Post-quant] w.shape: {w.shape} mx_scales.shape: {mx_scales.shape}")
if w_need_reshape:
w_new_shape = w_shape[:-1] + (w.shape[-1],)
w = w.view(w_new_shape)
# log_info_on_rank0(logger, f"[re-shape] w.shape: {w.shape} mx_scales.shape: {mx_scales.shape}")
mx_scales = e8m0_shuffle(mx_scales)
return w, mx_scales
def process_weights_after_loading(self, layer: Module) -> None:
w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
topk_weights, topk_ids, _ = topk_output
return fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
activation=(
ActivationType.Silu
if moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
doweight_stage1=False,
)
class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
def __init__(self, weight_config: dict[str, Any], input_config: dict[str, Any]):
self.weight_quant = weight_config
self.input_quant = input_config
weight_qscheme = self.weight_quant.get("qscheme")
input_qscheme = self.input_quant.get("qscheme")
if not (weight_qscheme == "per_group" and input_qscheme == "per_group"):
raise ValueError(
"For MX(FP4) Fused MoE layers, only per-group scales "
"for weights and activations are supported. Found "
f"{weight_qscheme=}, {input_qscheme=}"
) # noqa E501
self.static_input_scales = not self.input_quant.get("is_dynamic")
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
)
params_dtype = torch.uint8
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // 2,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition // 2,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // OCP_MX_BLOCK_SIZE,
dtype=params_dtype,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
hidden_size,
intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
float_dtype = torch.get_default_dtype()
# Pre-shuffle weight scales
s0, s1, _ = layer.w13_weight_scale.shape
w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
w13_weight_scale = e8m0_shuffle(w13_weight_scale)
layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
s0, s1, _ = layer.w2_weight_scale.shape
w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
w2_weight_scale = e8m0_shuffle(w2_weight_scale)
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
topk_weights, topk_ids, _ = topk_output
return fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
activation=(
ActivationType.Silu
if moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
doweight_stage1=False,
)
class MxFp4KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from quark checkpoints.
"""
def __init__(self, quant_config: MxFp4Config):
self.validate_kv_cache_config(quant_config.kv_cache_config)
super().__init__(quant_config)
@staticmethod
def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
"""
Validator for the kv cache configuration. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM
:param kv_cache_config: the quark kv cache scheme
"""
if kv_cache_config is None:
return
dtype = kv_cache_config.get("dtype")
if dtype != "fp8_e4m3":
raise NotImplementedError(
"Currently supported kv cache quantization is "
f"dtype=fp8_e4m3, however received {dtype}"
)
qscheme = kv_cache_config.get("qscheme")
if qscheme != "per_tensor":
raise NotImplementedError(
"Only support per-tensor scaling factor "
"for quark KV cache. "
f"Expected qscheme: per_tensor, found qscheme: {qscheme}"
)

View File

@@ -38,6 +38,7 @@ from sglang.srt.utils import (
is_hip,
is_triton_kernels_available,
log_info_on_rank0,
mxfp_supported,
next_power_of_2,
round_up,
set_weight_attrs,
@@ -61,7 +62,14 @@ if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
OCP_MX_BLOCK_SIZE = 32
_is_hip = is_hip()
if _is_hip:
# import aiter
from aiter import ActivationType, QuantType, dtypes
from aiter.fused_moe import fused_moe
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from aiter.utility.fp4_utils import e8m0_shuffle
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
@@ -162,13 +170,34 @@ except AttributeError as error:
class Mxfp4Config(QuantizationConfig):
def __init__(self, ignored_layers: Optional[list[str]] = None):
def __init__(
self,
ignored_layers: Optional[list[str]] = None,
is_checkpoint_mxfp4_serialized: bool = False,
):
super().__init__()
self.is_checkpoint_mxfp4_serialized = is_checkpoint_mxfp4_serialized
self.ignored_layers = ignored_layers
@classmethod
def from_config(cls, config):
return cls()
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_mxfp4_serialized = "mxfp4" in quant_method
if _is_hip:
if mxfp_supported():
return cls(
is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized
)
else:
platform = torch.cuda.get_device_properties(0).gcnArchName
raise ValueError(
f"Current platform {platform} not support mxfp4 computation"
)
return cls(is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized)
@classmethod
def get_min_capability(cls) -> int:
@@ -186,6 +215,9 @@ class Mxfp4Config(QuantizationConfig):
def get_config_filenames(cls) -> list[str]:
return []
def is_static_cfg(self):
return self.is_checkpoint_mxfp4_serialized
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
@@ -201,10 +233,16 @@ class Mxfp4Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
elif _is_hip:
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return Mxfp4MoEMethod(prefix)
if self.is_checkpoint_mxfp4_serialized:
return Mxfp4MoEMethod(prefix=prefix)
else:
return Mxfp4DynamicQuantMoEMethod()
else:
raise NotImplementedError("Mxfp4 attention layer is not implemented")
if self.is_checkpoint_mxfp4_serialized:
raise NotImplementedError("Mxfp4 attention layer is not implemented")
return None
def get_scaled_act_names(self) -> List[str]:
@@ -655,3 +693,116 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
)
class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
layer.w13_input_scale = None
layer.w2_input_scale = None
def mxfp4_quantize(self, w):
w_shape = w.shape
w_need_reshape = True if w.dim() != 2 else False
if w_need_reshape:
w_last_dim_size = w_shape[-1]
w = w.view(-1, w_last_dim_size)
w, mx_scales = dynamic_mxfp4_quant(w)
if w_need_reshape:
w_new_shape = w_shape[:-1] + (w.shape[-1],)
w = w.view(w_new_shape)
mx_scales = e8m0_shuffle(mx_scales)
return w, mx_scales
def process_weights_after_loading(self, layer: Module) -> None:
w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
topk_weights, topk_ids, _ = topk_output
return fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
activation=(
ActivationType.Silu
if moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
doweight_stage1=False,
)

View File

@@ -0,0 +1,390 @@
# SPDX-License-Identifier: Apache-2.0
import fnmatch
import logging
from typing import Any, List, Optional, cast
import torch
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import ( # noqa: E501
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.quark.quark_moe import QuarkMoEMethod
from sglang.srt.layers.quantization.quark.schemes import QuarkScheme, QuarkW4A4MXFP4
from sglang.srt.layers.quantization.quark.utils import deep_compare, should_ignore_layer
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_device_capability
__all__ = ["QuarkLinearMethod"]
logger = logging.getLogger(__name__)
class QuarkConfig(QuantizationConfig):
def __init__(
self,
quant_config: dict[str, Any],
kv_cache_group: Optional[list[str]] = None,
kv_cache_config: Optional[dict[str, Any]] = None,
pack_method: str = "reorder",
):
super().__init__()
if kv_cache_group is None:
kv_cache_group = []
self.quant_config = quant_config
self.kv_cache_group = kv_cache_group
self.kv_cache_config = kv_cache_config
self.pack_method = pack_method
self.packed_modules_mapping = self.quant_config["packed_modules_mapping"]
def get_linear_method(self) -> "QuarkLinearMethod":
return QuarkLinearMethod(self)
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 70
def get_name(self) -> str:
return "quark"
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
# Check if the layer is skipped for quantization.
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
if should_ignore_layer(
prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
layer.scheme = scheme
return QuarkLinearMethod(self)
if isinstance(layer, RadixAttention):
return QuarkKVCacheMethod(self)
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
if isinstance(layer, FusedMoE):
return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
return None
@classmethod
def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
export_config = config.get("export")
if export_config is None:
raise ValueError(
"The export key should be included in "
"the configurations of Quark quantized model"
)
kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
pack_method = cast(str, export_config.get("pack_method"))
# In the export model of quark, the quantization configuration
# of kv_cache is stored in layer_quant_config. First, it is
# judged whether kv_cache_group exists, and then it is judged
# whether layer_quant_config has a quantization configuration
# that matches kv_cache.
if len(kv_cache_group) == 0:
kv_cache_config = None
else:
kv_cache_set = set(kv_cache_group)
layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config"))
layer_quant_names = list(layer_quant_config.keys())
layer_quant_set = set(layer_quant_names)
if not kv_cache_set.issubset(layer_quant_set):
raise ValueError(
"The Quark quantized model has the "
"kv_cache_group parameter setting, "
"but no kv_cache quantization settings "
"were found in the quantization "
"configuration."
)
q_configs = [
cast(dict[str, Any], layer_quant_config.get(name))
for name in kv_cache_group
]
if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs):
raise ValueError(
"The quantization method used for kv_cache should "
"be the same, but the quantization method for the "
"kv_cache layer in the config is different."
)
kv_cache_config = q_configs[0].get("output_tensors")
if kv_cache_config is None:
raise ValueError("The kv_cache quantization configuration is empty.")
# Since we have already set kv_cache quantization configurations,
# we will remove the quantization configuration for the
# output_tensors corresponding to the kv_cache layer.
for q_config in q_configs:
q_config["output_tensors"] = None
# In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency.
q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj"))
if q_proj_q_config is not None:
q_proj_q_config["output_tensors"] = None
return cls(
quant_config=config,
kv_cache_group=kv_cache_group,
kv_cache_config=kv_cache_config,
pack_method=pack_method,
)
@classmethod
def get_config_filenames(cls) -> list[str]:
return []
def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool:
capability_tuple = get_device_capability()
if capability_tuple is not None:
assert 0 <= capability_tuple[1] < 10
capability = capability_tuple[0] * 10 + capability_tuple[1]
supported = capability >= min_capability
if error and not supported:
raise RuntimeError(
"Quantization scheme is not supported for ",
f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.",
)
return supported
else:
return False
def _is_mx_fp4(
self,
weight_quant: Optional[dict[str, Any]],
input_quant: Optional[dict[str, Any]],
) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
logger.debug(
"Quark model is not in MX-FP4 format: "
"weight_quant or input_quant not set"
)
return False
# Input and weight dtype needs to be fp4.
if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4":
logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
return False
# Input and weight qscheme needs to be per group.
if (
weight_quant.get("qscheme") != "per_group"
or input_quant.get("qscheme") != "per_group"
):
logger.debug("Quark model is not in MX-FP4 format: not per_group")
return False
# Input and weight group size needs to be 32.
if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32:
logger.debug("Quark model is not in MX-FP4 format: not group_size=32")
return False
# Weights need to use static quantization.
if weight_quant.get("is_dynamic") is True:
logger.debug("Quark model is not in MX-FP4 format: not weight static")
return False
# Activations need to use dynamic quantization.
if input_quant.get("is_dynamic") is False:
logger.debug("Quark model is not in MX-FP4 format: not activation dynamic")
return False
# Activations and weight scales need to be in e8m0 format.
if (
weight_quant.get("scale_format") != "e8m0"
or input_quant.get("scale_format") != "e8m0"
):
logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0")
return False
return True
def _find_matched_config(
self, layer_name: str, module: torch.nn.Module
) -> dict[str, Any]:
proj_name = layer_name.split(".")[-1]
if proj_name in self.packed_modules_mapping:
shard_proj_names = self.packed_modules_mapping[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
layer_name.replace(proj_name, shard_proj_name)
for shard_proj_name in shard_proj_names
]
shard_configs = [
self._find_matched_config(shard_name, module)
for shard_name in shard_names
]
if not all(
deep_compare(q_config, shard_configs[0]) for q_config in shard_configs
):
raise ValueError(
f"Found a different quantization configuration for "
f"{shard_proj_names} in {layer_name}. vLLM "
"requires all to use the same scheme."
)
return shard_configs[0]
else:
layer_quant_config = cast(
dict[str, Any], self.quant_config.get("layer_quant_config")
)
for name_pattern in layer_quant_config:
if fnmatch.fnmatch(layer_name, name_pattern):
return layer_quant_config[name_pattern]
layer_type = type(module).__name__
layer_type_quant_config = cast(
dict[str, Any], self.quant_config.get("layer_type_quant_config")
)
if layer_type in layer_type_quant_config:
return layer_type_quant_config[layer_type]
global_quant_config = cast(
dict[str, Any], self.quant_config.get("global_quant_config")
)
return global_quant_config
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
if config.get("output_tensors") or config.get("bias"):
raise NotImplementedError(
"Currently, Quark models with output_tensors "
"and bias quantized are not supported"
)
weight_config = cast(dict[str, Any], config.get("weight"))
input_config = cast(dict[str, Any], config.get("input_tensors"))
if self._is_mx_fp4(weight_config, input_config):
return QuarkW4A4MXFP4(weight_config, input_config)
raise NotImplementedError(
"No quark compatible scheme was found. "
f"Weight config: {weight_config}, "
f"Input config: {input_config}"
)
def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
layer_quant_config = self._find_matched_config(layer_name, layer)
# Find the quant_scheme
scheme = self._get_scheme_from_config(layer_quant_config)
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
return scheme
def get_scaled_act_names(self) -> List[str]:
return []
class QuarkLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: QuarkConfig):
self.quantization_config = quantization_config
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.scheme.process_weights_after_loading(layer)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
"""
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param
details
"""
weight_loader = extra_weight_attrs.get("weight_loader")
layer.scheme.create_weights(
layer=layer,
input_size=input_size,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
output_size=output_size,
params_dtype=params_dtype,
weight_loader=weight_loader,
)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
layer input. See LinearMethodBase for param details
"""
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias)
class QuarkKVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from quark checkpoints.
"""
def __init__(self, quant_config: QuarkConfig):
self.validate_kv_cache_config(quant_config.kv_cache_config)
super().__init__(quant_config)
@staticmethod
def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
"""
Validator for the kv cache configuration. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM
:param kv_cache_config: the quark kv cache scheme
"""
if kv_cache_config is None:
return
dtype = kv_cache_config.get("dtype")
if dtype != "fp8_e4m3":
raise NotImplementedError(
"Currently supported kv cache quantization is "
f"dtype=fp8_e4m3, however received {dtype}"
)
qscheme = kv_cache_config.get("qscheme")
if qscheme != "per_tensor":
raise NotImplementedError(
"Only support per-tensor scaling factor "
"for quark KV cache. "
f"Expected qscheme: per_tensor, found qscheme: {qscheme}"
)

View File

@@ -0,0 +1,197 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Optional
import torch
from aiter import ActivationType, QuantType, biased_grouped_topk
from aiter.fused_moe import fused_moe
from aiter.utility.fp4_utils import e8m0_shuffle
from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
logger = logging.getLogger(__name__)
__all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
OCP_MX_BLOCK_SIZE = 32
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
class QuarkMoEMethod:
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
@staticmethod
def get_moe_method(
quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821
module: torch.nn.Module,
layer_name: str,
) -> "QuarkMoEMethod":
layer_quant_config = quant_config._find_matched_config(layer_name, module)
if layer_quant_config.get("output_tensors") or layer_quant_config.get("bias"):
raise NotImplementedError(
"Currently, Quark models with "
"output_tensors and bias "
"quantized are not supported"
)
weight_config = layer_quant_config.get("weight")
input_config = layer_quant_config.get("input_tensors")
if quant_config._is_mx_fp4(weight_config, input_config):
return QuarkW4A4MXFp4MoEMethod(weight_config, input_config)
else:
raise RuntimeError("Unsupported FusedMoe scheme")
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
def __init__(self, weight_config: dict[str, Any], input_config: dict[str, Any]):
self.weight_quant = weight_config
self.input_quant = input_config
weight_qscheme = self.weight_quant.get("qscheme")
input_qscheme = self.input_quant.get("qscheme")
if not (weight_qscheme == "per_group" and input_qscheme == "per_group"):
raise ValueError(
"For MX(FP4) Fused MoE layers, only per-group scales "
"for weights and activations are supported. Found "
f"{weight_qscheme}, {input_qscheme}"
) # noqa E501
self.static_input_scales = not self.input_quant.get("is_dynamic")
self.with_bias = False
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
)
params_dtype = torch.uint8
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // 2,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition // 2,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // OCP_MX_BLOCK_SIZE,
dtype=params_dtype,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
hidden_size,
intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
float_dtype = torch.get_default_dtype()
# Pre-shuffle weight scales
s0, s1, _ = layer.w13_weight_scale.shape
w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
w13_weight_scale = e8m0_shuffle(w13_weight_scale)
# layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, requires_grad=False)
layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
s0, s1, _ = layer.w2_weight_scale.shape
w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
w2_weight_scale = e8m0_shuffle(w2_weight_scale)
# layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
topk_weights, topk_ids, _ = topk_output
return fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
activation=(
ActivationType.Silu
if moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
doweight_stage1=False,
)

View File

@@ -33,6 +33,7 @@ from sglang.srt.utils import (
configure_ipv6,
get_device,
get_device_memory_capacity,
is_cuda,
is_flashinfer_available,
is_hip,
is_port_available,
@@ -2165,9 +2166,9 @@ class ServerArgs:
model_arch = hf_config.architectures[0]
if model_arch in ["GptOssForCausalLM"]:
if self.attention_backend is None:
if is_sm100_supported():
if is_cuda() and is_sm100_supported():
self.attention_backend = "trtllm_mha"
elif is_sm90_supported():
elif is_cuda() and is_sm90_supported():
self.attention_backend = "fa3"
else:
self.attention_backend = "triton"