Support NVFP4 quantized dense models on AMD CDNA2/CDNA3 GPUs (#7302)

Co-authored-by: HAI <hixiao@gmail.com>
Co-authored-by: Sai Enduri <saimanas.enduri@amd.com>
This commit is contained in:
Haohui Mai
2025-07-18 19:59:39 -07:00
committed by GitHub
parent 3964b352c3
commit d918ab7985
7 changed files with 361 additions and 0 deletions

View File

@@ -79,6 +79,7 @@ blackwell = [
srt_hip = [
"sglang[runtime_common]",
"torch",
"petit_kernel",
]
# xpu is not enabled in public vllm and torch whl,

View File

@@ -391,6 +391,7 @@ class ModelConfig:
"compressed-tensors",
"fbgemm_fp8",
"w8a8_fp8",
"petit_nvfp4",
]
optimized_quantization_methods = [
"fp8",
@@ -408,9 +409,11 @@ class ModelConfig:
"moe_wna16",
"qoq",
"w4afp8",
"petit_nvfp4",
]
compatible_quantization_methods = {
"modelopt_fp4": ["modelopt"],
"petit_nvfp4": ["modelopt"],
"w8a8_int8": ["compressed-tensors", "compressed_tensors"],
"w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
}

View File

@@ -53,6 +53,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"ModelOptFp8LinearMethod",
"ModelOptFp4LinearMethod",
"IPEXAWQLinearMethod",
"PetitNvFp4LinearMethod",
]
_is_cpu = is_cpu()

View File

@@ -58,6 +58,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
ModelOptFp8Config,
)
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.petit import PetitNvFp4Config
from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.utils import get_linear_quant_method
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
@@ -76,6 +77,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"compressed-tensors": CompressedTensorsConfig,
"qoq": QoQConfig,
"w4afp8": W4AFp8Config,
"petit_nvfp4": PetitNvFp4Config,
}
# VLLM-dependent quantization methods

View File

@@ -0,0 +1,249 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
import logging
from typing import Any, Callable, Dict, List, Optional
import regex as re
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.petit_utils import (
apply_petit_nvfp4_linear,
prepare_nvfp4_layer_for_petit,
verify_petit_nvfp4_supported,
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
# Initialize logger for the module
logger = logging.getLogger(__name__)
# Configuration class to support the NVFP4 quantized model generated by the ModelOpt quantization tool
class PetitNvFp4Config(QuantizationConfig):
"""Config class for Petit FP4."""
def __init__(
self,
is_checkpoint_nvfp4_serialized: bool = False,
kv_cache_quant_algo: str = None,
group_size: int = None,
exclude_modules: List[str] = None,
) -> None:
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning(
"Detected nvfp4 checkpoint. Please note that the "
"format is experimental and subject to change."
)
self.group_size = group_size
self.kv_cache_quant_algo = kv_cache_quant_algo
self.exclude_modules = exclude_modules
@classmethod
def get_name(cls) -> str:
return "petit_nvfp4"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
# Petit supports the gfx90a and gfx942 GPUs
return 90
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "PetitNvFp4Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
group_size = quant_config.get("group_size", None)
verify_petit_nvfp4_supported(quant_method, group_size)
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
if not kv_cache_quant_algo:
kv_cache_quant_algo = "auto"
exclude_modules = quant_config.get("exclude_modules", None)
if not (group_size and kv_cache_quant_algo and (exclude_modules is not None)):
logger.warning(
f"group_size: {group_size},"
f"kv_cache_quant_algo: {kv_cache_quant_algo},"
f"exclude_modules: {exclude_modules}"
)
raise ValueError(
"NVFP4 quantization requires group size and "
"kv_cache_quant_algo specified in "
"hf_quant_config.json"
)
return cls(
is_checkpoint_nvfp4_serialized,
kv_cache_quant_algo,
group_size,
exclude_modules,
)
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
can_convert = cls.is_petit_nvfp4_compatible(hf_quant_cfg)
if can_convert:
return cls.get_name()
return None
@classmethod
def is_petit_nvfp4_compatible(cls, quant_config: Dict[str, Any]) -> bool:
quant_method = quant_config.get("quant_method", "").lower()
return quant_method == "modelopt"
def is_layer_excluded(self, prefix: str, exclude_modules: list):
for pattern in exclude_modules:
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
if re.fullmatch(regex_str, prefix):
return True
return False
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
prefix, self.exclude_modules
):
return UnquantizedLinearMethod()
return PetitNvFp4LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class PetitNvFp4LinearMethod(LinearMethodBase):
"""Linear method for NVFP4.
Supports loading NVFP4 checkpoints with the following structure:
|Tensor Name | datatype | shape |
|----------------------------------------------------|
|input_scale | torch.float32 | scalar |
|weight | NVFP4(SE2M1) | [1, X, y/2] |
|weight_scale | FP8-E4M3 | [X, Y] |
|weight_scale_2 | torch.float32 | scalar |
The weights are quantized per block of 16 elements.
Args: quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: PetitNvFp4Config):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
if not self.quant_config.is_checkpoint_nvfp4_serialized:
raise ValueError(
"NVFP4 quantization was selected, "
" dynamic quantization is not supported."
)
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
if input_size_per_partition % 16 != 0:
raise ValueError(
"Unsupported model when in features size is " "not multiple of 16"
)
weight_dtype = (
torch.float8_e4m3fn
if self.quant_config.is_checkpoint_nvfp4_serialized
else params_dtype
)
weight = ModelWeightParameter(
data=torch.empty(
# 2 fp4 data is packed in one uint8 in the input dimension
output_size_per_partition,
input_size_per_partition // 2,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
input_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("input_scale", input_scale)
weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale_2", weight_scale_2)
weight_scale = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // self.quant_config.group_size,
dtype=weight_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
input_scale_2 = layer.input_scale.max().to(torch.float32)
weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
layer.input_scale = Parameter(input_scale_2, requires_grad=False)
layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
layer.alpha = Parameter(
layer.input_scale * layer.weight_scale_2, requires_grad=False
)
prepare_nvfp4_layer_for_petit(layer)
del layer.input_scale
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_petit_nvfp4_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale_2=layer.weight_scale_2,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)

View File

@@ -0,0 +1,104 @@
from typing import Optional
import torch
try:
from petit_kernel import mul_nvfp4_a16, process_nvfp4_scales, repack_nvfp4
except ImportError:
def _check_petit_nvfp4_supported(
quant_method: str, group_size: Optional[int]
) -> tuple[bool, Optional[str]]:
return (
False,
"Petit is not installed. Please install it with `pip install petit-kernel`.",
)
def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
raise ValueError(
"Petit is not installed. Please install it with `pip install petit-kernel`."
)
def apply_petit_nvfp4_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_scale_2: torch.Tensor,
size_n: int,
size_k: int,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise ValueError(
"Petit is not installed. Please install it with `pip install petit-kernel`."
)
def _check_petit_nvfp4_supported(
quant_method: str, group_size: Optional[int]
) -> tuple[bool, Optional[str]]:
if quant_method != "NVFP4":
return (
False,
"Petit currently only supports: NVFP4"
" quantizations in sglang. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration.",
)
if group_size is not None and group_size != 16:
return (
False,
"Petit currently only supports: group_size=16" " quantizations.",
)
return (True, None)
def verify_petit_nvfp4_supported(quant_method: str, group_size: Optional[int]) -> None:
supported, error_msg = _check_petit_nvfp4_supported(quant_method, group_size)
if not supported:
raise ValueError(error_msg)
def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
# Repack weights to petit format
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
qweight = layer.weight.view(torch.int32).contiguous()
petit_qweight = repack_nvfp4(qweight, size_n=part_size_n, size_k=part_size_k)
layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False)
# Permute scales
weight_scale = process_nvfp4_scales(
scales=layer.weight_scale, size_k=part_size_k, size_n=part_size_n
)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
return
def apply_petit_nvfp4_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_scale_2: torch.Tensor,
size_n: int,
size_k: int,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (size_n,)
# TODO: Use auto-tuning to find the performant solution_id
output = mul_nvfp4_a16(
a=reshaped_x,
b=weight,
s=weight_scale,
global_scale=weight_scale_2,
size_m=reshaped_x.size(0),
size_n=size_n,
size_k=size_k,
solution_id=-1,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)

View File

@@ -766,6 +766,7 @@ class ServerArgs:
"gguf",
"modelopt",
"modelopt_fp4",
"petit_nvfp4",
"w8a8_int8",
"w8a8_fp8",
"moe_wna16",