Support mxfp4 for GPT-OSS (#8843)

Co-authored-by: Co-author fzyzcjy <ch271828n@outlook.com>
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Co-authored-by: zhuofan1123 <zhuofanl@nvidia.com>
Co-authored-by: liz-badada <jinyanc@nvidia.com>
Co-authored-by: xutizhou <xutingz@nvidia.com>
Co-authored-by: linhu-nv <linhu@nvidia.com>
This commit is contained in:
Ying Sheng
2025-08-06 00:05:25 -07:00
committed by GitHub
parent cbbb738371
commit 168033d5fb
9 changed files with 791 additions and 325 deletions

View File

@@ -50,315 +50,50 @@ use_dynamic_mxfp4_linear = get_bool_env_var("SGLANG_USE_DYNAMIC_MXFP4_linear")
OCP_MX_BLOCK_SIZE = 32
class MxFp4Config(QuantizationConfig):
class Mxfp4Config(QuantizationConfig):
def __init__(
self,
is_checkpoint_fp4_serialized: bool = False,
quant_config: dict[str, Any] = None,
kv_cache_group: Optional[list[str]] = None,
kv_cache_config: Optional[dict[str, Any]] = None,
pack_method: str = "reorder",
ignored_layers: Optional[List[str]] = None,
):
def __init__(self, ignored_layers: Optional[list[str]] = None):
super().__init__()
if kv_cache_group is None:
kv_cache_group = []
self.ignored_layers = ignored_layers
self.is_checkpoint_fp4_serialized = is_checkpoint_fp4_serialized
self.quant_config = quant_config
self.kv_cache_group = kv_cache_group
self.kv_cache_config = kv_cache_config
self.pack_method = pack_method
self.packed_modules_mapping = (
self.quant_config["packed_modules_mapping"]
if is_checkpoint_fp4_serialized
else None
)
self.ignored_layers = ignored_layers or []
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def from_config(cls, config):
return cls()
@classmethod
def get_min_capability(cls) -> int:
return 70
def get_name(self) -> str:
return "fp4"
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
# Check if the layer is skipped for quantization.
if len(self.ignored_layers) > 0 and should_ignore_layer(
prefix,
ignore=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
if self.is_checkpoint_fp4_serialized:
scheme = self.get_scheme(layer=layer, layer_name=prefix)
layer.scheme = scheme
return MxFp4LinearMethod(self)
elif use_dynamic_mxfp4_linear:
return MxFp4LinearMethod(self)
else:
return UnquantizedLinearMethod()
if isinstance(layer, RadixAttention):
return MxFp4KVCacheMethod(self)
if isinstance(layer, FusedMoE):
return MxFp4MoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
return None
return 80
@classmethod
def from_config(cls, config: dict[str, Any]) -> "MxFp4Config":
if not mxfp_supported():
platform = torch.cuda.get_device_properties(0).gcnArchName
raise ValueError(
f"Current platform {platform} not support mxfp4 computation"
)
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp4_serialized = (
True if quant_method else False
) # "quark" in quant_method
def get_name(cls) -> QuantizationMethods:
return "mxfp4"
kv_cache_group = []
pack_method = None
if is_checkpoint_fp4_serialized:
export_config = config.get("export")
if export_config is None:
raise ValueError(
"The export key should be included in "
"the configurations of Quark quantized model"
)
kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
pack_method = cast(str, export_config.get("pack_method"))
# In the export model of quark, the quantization configuration
# of kv_cache is stored in layer_quant_config. First, it is
# judged whether kv_cache_group exists, and then it is judged
# whether layer_quant_config has a quantization configuration
# that matches kv_cache.
if len(kv_cache_group) == 0:
kv_cache_config = None
else:
kv_cache_set = set(kv_cache_group)
layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config"))
layer_quant_names = list(layer_quant_config.keys())
layer_quant_set = set(layer_quant_names)
if not kv_cache_set.issubset(layer_quant_set):
raise ValueError(
"The Quark quantized model has the "
"kv_cache_group parameter setting, "
"but no kv_cache quantization settings "
"were found in the quantization "
"configuration."
)
q_configs = [
cast(dict[str, Any], layer_quant_config.get(name))
for name in kv_cache_group
]
if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs):
raise ValueError(
"The quantization method used for kv_cache should "
"be the same, but the quantization method for the "
"kv_cache layer in the config is different."
)
kv_cache_config = q_configs[0].get("output_tensors")
if kv_cache_config is None:
raise ValueError("The kv_cache quantization configuration is empty.")
# Since we have already set kv_cache quantization configurations,
# we will remove the quantization configuration for the
# output_tensors corresponding to the kv_cache layer.
for q_config in q_configs:
q_config["output_tensors"] = None
# In case q_proj output is also quantized, remove the configuration
# to keep qkv consistency.
q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj"))
if q_proj_q_config is not None:
q_proj_q_config["output_tensors"] = None
ignored_layers = cls.get_from_keys_or(config, ["exclude"], None)
return cls(
is_checkpoint_fp4_serialized=is_checkpoint_fp4_serialized,
quant_config=config,
kv_cache_group=kv_cache_group,
kv_cache_config=kv_cache_config,
pack_method=pack_method,
ignored_layers=ignored_layers,
)
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_config_filenames(cls) -> list[str]:
return []
def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool:
capability_tuple = get_device_capability()
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if capability_tuple is not None:
assert 0 <= capability_tuple[1] < 10
capability = capability_tuple[0] * 10 + capability_tuple[1]
supported = capability >= min_capability
if error and not supported:
raise RuntimeError(
"Quantization scheme is not supported for ",
f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.",
)
return supported
else:
return False
def _is_mx_fp4(
self,
weight_quant: Optional[dict[str, Any]],
input_quant: Optional[dict[str, Any]],
) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
logger.debug(
"Quark model is not in MX-FP4 format: "
"weight_quant or input_quant not set"
)
return False
# Input and weight dtype needs to be fp4.
if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4":
logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
return False
# Input and weight qscheme needs to be per group.
if (
weight_quant.get("qscheme") != "per_group"
or input_quant.get("qscheme") != "per_group"
):
logger.debug("Quark model is not in MX-FP4 format: not per_group")
return False
# Input and weight group size needs to be 32.
if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32:
logger.debug("Quark model is not in MX-FP4 format: not group_size=32")
return False
# Weights need to use static quantization.
if weight_quant.get("is_dynamic") is True:
logger.debug("Quark model is not in MX-FP4 format: not weight static")
return False
# Activations need to use dynamic quantization.
if input_quant.get("is_dynamic") is False:
logger.debug("Quark model is not in MX-FP4 format: not activation dynamic")
return False
# Activations and weight scales need to be in e8m0 format.
if (
weight_quant.get("scale_format") != "e8m0"
or input_quant.get("scale_format") != "e8m0"
):
logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0")
return False
return True
def _find_matched_config(
self, layer_name: str, module: torch.nn.Module
) -> dict[str, Any]:
proj_name = layer_name.split(".")[-1]
if proj_name in self.packed_modules_mapping:
shard_proj_names = self.packed_modules_mapping[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
layer_name.replace(proj_name, shard_proj_name)
for shard_proj_name in shard_proj_names
]
shard_configs = [
self._find_matched_config(shard_name, module)
for shard_name in shard_names
]
if not all(
deep_compare(q_config, shard_configs[0]) for q_config in shard_configs
if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
):
raise ValueError(
f"Found a different quantization configuration for "
f"{shard_proj_names=} in {layer_name=}. vLLM "
"requires all to use the same scheme."
)
return shard_configs[0]
else:
layer_quant_config = cast(
dict[str, Any], self.quant_config.get("layer_quant_config")
)
for name_pattern in layer_quant_config:
if fnmatch.fnmatch(layer_name, name_pattern):
return layer_quant_config[name_pattern]
layer_type = cast(str, type(module))
layer_type_quant_config = cast(
dict[str, Any], self.quant_config.get("layer_type_quant_config")
)
if layer_type in layer_type_quant_config:
return layer_type_quant_config[layer_type]
global_quant_config = cast(
dict[str, Any], self.quant_config.get("global_quant_config")
)
return global_quant_config
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
if config.get("output_tensors") or config.get("bias"):
raise NotImplementedError(
"Currently, Quark models with output_tensors "
"and bias quantized are not supported"
)
weight_config = cast(dict[str, Any], config.get("weight"))
input_config = cast(dict[str, Any], config.get("input_tensors"))
if self._is_mx_fp4(weight_config, input_config):
return QuarkW4A4MXFP4(weight_config, input_config)
raise NotImplementedError(
"No quark compatible scheme was found. "
f"{weight_config=}, "
f"{input_config=}"
)
def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
layer_quant_config = self._find_matched_config(layer_name, layer)
# Find the quant_scheme
scheme = self._get_scheme_from_config(layer_quant_config)
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
return scheme
def get_scaled_act_names(self) -> List[str]:
return []
return UnquantizedLinearMethod()
raise NotImplementedError("Mxfp4 linear layer is not implemented")
elif isinstance(layer, FusedMoE):
return Mxfp4MoEMethod(layer.moe_config)
elif isinstance(layer, Attention):
raise NotImplementedError("Mxfp4 attention layer is not implemented")
return None
class MxFp4LinearMethod(LinearMethodBase):