From d918ab7985580cebea03216a5e309058df449821 Mon Sep 17 00:00:00 2001
From: Haohui Mai
Date: Fri, 18 Jul 2025 19:59:39 -0700
Subject: [PATCH] Support NVFP4 quantized dense models on AMD CDNA2/CDNA3 GPUs
 (#7302)

Co-authored-by: HAI
Co-authored-by: Sai Enduri
---
 python/pyproject.toml                              |   1 +
 python/sglang/srt/configs/model_config.py          |   3 +
 python/sglang/srt/layers/linear.py                 |   1 +
 python/sglang/srt/layers/quantization/__init__.py  |   2 +
 .../sglang/srt/layers/quantization/petit.py        | 249 ++++++++++++++++++
 .../srt/layers/quantization/petit_utils.py         | 104 ++++++++
 python/sglang/srt/server_args.py                   |   1 +
 7 files changed, 361 insertions(+)
 create mode 100644 python/sglang/srt/layers/quantization/petit.py
 create mode 100644 python/sglang/srt/layers/quantization/petit_utils.py

diff --git a/python/pyproject.toml b/python/pyproject.toml
index 7afb3581a..5b6501afd 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -79,6 +79,7 @@ blackwell = [
 srt_hip = [
     "sglang[runtime_common]",
     "torch",
+    "petit_kernel",
 ]
 
 # xpu is not enabled in public vllm and torch whl,
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 1a62178b9..7d7f2eb95 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -391,6 +391,7 @@ class ModelConfig:
             "compressed-tensors",
             "fbgemm_fp8",
             "w8a8_fp8",
+            "petit_nvfp4",
         ]
         optimized_quantization_methods = [
             "fp8",
@@ -408,9 +409,11 @@ class ModelConfig:
             "moe_wna16",
             "qoq",
             "w4afp8",
+            "petit_nvfp4",
         ]
         compatible_quantization_methods = {
             "modelopt_fp4": ["modelopt"],
+            "petit_nvfp4": ["modelopt"],
             "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
             "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
         }
diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py
index 1c770193f..07be9a3c6 100644
--- a/python/sglang/srt/layers/linear.py
+++ b/python/sglang/srt/layers/linear.py
@@ -53,6 +53,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "ModelOptFp8LinearMethod",
     "ModelOptFp4LinearMethod",
     "IPEXAWQLinearMethod",
+    "PetitNvFp4LinearMethod",
 ]
 
 _is_cpu = is_cpu()
diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 9995b72d0..d51186465 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -58,6 +58,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp8Config,
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
 from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
@@ -76,6 +77,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
     "w4afp8": W4AFp8Config,
+    "petit_nvfp4": PetitNvFp4Config,
 }
 
 # VLLM-dependent quantization methods
diff --git a/python/sglang/srt/layers/quantization/petit.py b/python/sglang/srt/layers/quantization/petit.py
new file mode 100644
index 000000000..e7ee3239f
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/petit.py
@@ -0,0 +1,249 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
+
+import logging
+from typing import Any, Callable, Dict, List, Optional
+
+import regex as re
+import torch
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+from sglang.srt.layers.quantization.base_config import (
+    LinearMethodBase,
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.petit_utils import (
+    apply_petit_nvfp4_linear,
+    prepare_nvfp4_layer_for_petit,
+    verify_petit_nvfp4_supported,
+)
+from sglang.srt.layers.quantization.utils import is_layer_skipped
+
+logger = logging.getLogger(__name__)
+
+
+class PetitNvFp4Config(QuantizationConfig):
+    """Config class for Petit NVFP4.
+
+    Supports NVFP4 checkpoints generated by the ModelOpt quantization tool.
+    """
+
+    def __init__(
+        self,
+        is_checkpoint_nvfp4_serialized: bool = False,
+        kv_cache_quant_algo: Optional[str] = None,
+        group_size: Optional[int] = None,
+        exclude_modules: Optional[List[str]] = None,
+    ) -> None:
+        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
+        if is_checkpoint_nvfp4_serialized:
+            logger.warning(
+                "Detected NVFP4 checkpoint. Please note that the format is "
+                "experimental and subject to change."
+            )
+        self.group_size = group_size
+        self.kv_cache_quant_algo = kv_cache_quant_algo
+        self.exclude_modules = exclude_modules
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "petit_nvfp4"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # Petit supports the gfx90a (CDNA2) and gfx942 (CDNA3) GPUs.
+        return 90
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "PetitNvFp4Config":
+        quant_config = cls.get_from_keys(config, ["quantization"])
+        quant_method = quant_config["quant_algo"]
+        group_size = quant_config.get("group_size", None)
+        verify_petit_nvfp4_supported(quant_method, group_size)
+
+        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
+        kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo") or "auto"
+        exclude_modules = quant_config.get("exclude_modules", None)
+        if not (group_size and kv_cache_quant_algo and (exclude_modules is not None)):
+            logger.warning(
+                f"group_size: {group_size}, "
+                f"kv_cache_quant_algo: {kv_cache_quant_algo}, "
+                f"exclude_modules: {exclude_modules}"
+            )
+            raise ValueError(
+                "NVFP4 quantization requires group_size, kv_cache_quant_algo, "
+                "and exclude_modules to be specified in hf_quant_config.json."
+            )
+        return cls(
+            is_checkpoint_nvfp4_serialized,
+            kv_cache_quant_algo,
+            group_size,
+            exclude_modules,
+        )
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        if cls.is_petit_nvfp4_compatible(hf_quant_cfg):
+            return cls.get_name()
+        return None
+
+    @classmethod
+    def is_petit_nvfp4_compatible(cls, quant_config: Dict[str, Any]) -> bool:
+        quant_method = quant_config.get("quant_method", "").lower()
+        return quant_method == "modelopt"
+
+    def is_layer_excluded(self, prefix: str, exclude_modules: List[str]) -> bool:
+        # Translate the glob-style patterns from the checkpoint config into
+        # regular expressions and match them against the full layer prefix.
+        for pattern in exclude_modules:
+            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            if re.fullmatch(regex_str, prefix):
+                return True
+        return False
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(
+                prefix, self.exclude_modules
+            ) or self.is_layer_excluded(prefix, self.exclude_modules):
+                return UnquantizedLinearMethod()
+            return PetitNvFp4LinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class PetitNvFp4LinearMethod(LinearMethodBase):
+    """Linear method for NVFP4.
+
+    Supports loading NVFP4 checkpoints with the following structure:
+
+    | Tensor name    | Data type     | Shape               |
+    |----------------|---------------|---------------------|
+    | input_scale    | torch.float32 | scalar              |
+    | weight         | NVFP4 (SE2M1) | [1, X, Y / 2]       |
+    | weight_scale   | FP8-E4M3      | [X, Y / group_size] |
+    | weight_scale_2 | torch.float32 | scalar              |
+
+    X is the number of output features and Y the number of input features.
+    The weights are quantized per block of 16 elements, and two FP4 values
+    are packed into each uint8 of `weight`.
+
+    Args:
+        quant_config: The ModelOpt quantization config.
+    """
+
+    def __init__(self, quant_config: PetitNvFp4Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del input_size, output_size
+        if not self.quant_config.is_checkpoint_nvfp4_serialized:
+            raise ValueError(
+                "NVFP4 quantization was selected, but dynamic quantization "
+                "is not supported."
+            )
+
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        if input_size_per_partition % 16 != 0:
+            raise ValueError(
+                "Unsupported model: the input feature size must be a "
+                "multiple of 16."
+            )
+
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_nvfp4_serialized
+            else params_dtype
+        )
+
+        # Two FP4 values are packed into each uint8 along the input dimension.
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // 2,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        input_scale = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("input_scale", input_scale)
+
+        weight_scale_2 = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale_2", weight_scale_2)
+
+        # Per-block scales (one per group_size = 16 elements), stored as FP8-E4M3.
+        weight_scale = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // self.quant_config.group_size,
+                dtype=weight_dtype,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Collapse the per-partition scales into single per-tensor scales.
+        input_scale_2 = layer.input_scale.max().to(torch.float32)
+        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
+        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
+        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
+        layer.alpha = Parameter(
+            layer.input_scale * layer.weight_scale_2, requires_grad=False
+        )
+
+        prepare_nvfp4_layer_for_petit(layer)
+        del layer.input_scale
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return apply_petit_nvfp4_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            weight_scale_2=layer.weight_scale_2,
+            size_n=layer.output_size_per_partition,
+            size_k=layer.input_size_per_partition,
+            bias=bias,
+        )
diff --git a/python/sglang/srt/layers/quantization/petit_utils.py b/python/sglang/srt/layers/quantization/petit_utils.py
new file mode 100644
index 000000000..529869f24
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/petit_utils.py
@@ -0,0 +1,104 @@
+from typing import Optional
+
+import torch
+
+# The Petit kernels are optional at import time so that sglang can still be
+# imported on systems without them; availability is checked before use.
+try:
+    from petit_kernel import mul_nvfp4_a16, process_nvfp4_scales, repack_nvfp4
+
+    _PETIT_AVAILABLE = True
+except ImportError:
+    _PETIT_AVAILABLE = False
+
+_PETIT_INSTALL_HINT = (
+    "Petit is not installed. Please install it with `pip install petit-kernel`."
+)
+
+
+def _check_petit_nvfp4_supported(
+    quant_method: str, group_size: Optional[int]
+) -> tuple[bool, Optional[str]]:
+    if not _PETIT_AVAILABLE:
+        return (False, _PETIT_INSTALL_HINT)
+    if quant_method != "NVFP4":
+        return (
+            False,
+            "Petit currently only supports NVFP4 quantization in sglang. "
+            "Please check the `hf_quant_config.json` file for your model's "
+            "quant configuration.",
+        )
+    if group_size is not None and group_size != 16:
+        return (False, "Petit currently only supports group_size=16 quantization.")
+    return (True, None)
+
+
+def verify_petit_nvfp4_supported(quant_method: str, group_size: Optional[int]) -> None:
+    supported, error_msg = _check_petit_nvfp4_supported(quant_method, group_size)
+    if not supported:
+        raise ValueError(error_msg)
+
+
+def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
+    if not _PETIT_AVAILABLE:
+        raise ValueError(_PETIT_INSTALL_HINT)
+
+    # Repack the weights into the layout expected by the Petit kernel.
+    part_size_n = layer.output_size_per_partition
+    part_size_k = layer.input_size_per_partition
+    qweight = layer.weight.view(torch.int32).contiguous()
+    petit_qweight = repack_nvfp4(qweight, size_n=part_size_n, size_k=part_size_k)
+    layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False)
+
+    # Permute the per-block scales to match the repacked weight layout.
+    weight_scale = process_nvfp4_scales(
+        scales=layer.weight_scale, size_k=part_size_k, size_n=part_size_n
+    )
+    layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
+
+
+def apply_petit_nvfp4_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_scale_2: torch.Tensor,
+    size_n: int,
+    size_k: int,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if not _PETIT_AVAILABLE:
+        raise ValueError(_PETIT_INSTALL_HINT)
+
+    # Flatten all leading dimensions into a single GEMM M dimension.
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (size_n,)
+
+    # TODO: Use auto-tuning to find a performant solution_id.
+    output = mul_nvfp4_a16(
+        a=reshaped_x,
+        b=weight,
+        s=weight_scale,
+        global_scale=weight_scale_2,
+        size_m=reshaped_x.size(0),
+        size_n=size_n,
+        size_k=size_k,
+        solution_id=-1,
+    )
+    if bias is not None:
+        output.add_(bias)  # In-place bias add
+
+    return output.reshape(out_shape)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 20db0b4b9..4f9e17e05 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -766,6 +766,7 @@ class ServerArgs:
             "gguf",
             "modelopt",
             "modelopt_fp4",
+            "petit_nvfp4",
             "w8a8_int8",
             "w8a8_fp8",
             "moe_wna16",
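
For reference, the checkpoint layout that `PetitNvFp4LinearMethod.create_weights` expects can be verified with a few lines of shape arithmetic. A minimal sketch; the 4096/14336 feature sizes are illustrative and not taken from the patch:

```python
# Shape arithmetic for the NVFP4 layout used above
# (group_size = 16, two FP4 values packed per uint8).
N, K, GROUP_SIZE = 4096, 14336, 16  # illustrative output/input feature sizes

assert K % 16 == 0, "create_weights() rejects input sizes not divisible by 16"

weight_shape = (N, K // 2)  # packed SE2M1 values, dtype torch.uint8
weight_scale_shape = (N, K // GROUP_SIZE)  # per-block FP8-E4M3 scales
# input_scale and weight_scale_2 are scalar float32 tensors.

print(weight_shape, weight_scale_shape)  # (4096, 7168) (4096, 896)
```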
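
With the patch applied, serving a ModelOpt NVFP4 checkpoint on a CDNA2/CDNA3 GPU reduces to selecting the new method via `--quantization petit_nvfp4` (or the equivalent engine argument). A minimal sketch using the offline engine, assuming `petit-kernel` is installed; the model path is a hypothetical placeholder, not a checkpoint named in this patch:

```python
# Offline-engine sketch; "petit_nvfp4" is the quantization method
# registered by this patch, and the model path is illustrative.
import sglang as sgl

llm = sgl.Engine(
    model_path="/path/to/modelopt-nvfp4-checkpoint",  # hypothetical path
    quantization="petit_nvfp4",
)
out = llm.generate(
    "The capital of France is",
    {"temperature": 0, "max_new_tokens": 8},
)
print(out["text"])
llm.shutdown()
```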