Files
2025-09-15 14:58:11 +08:00

411 lines
17 KiB
Python

# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Union
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
logger = init_logger(__name__)
QUANT_ALGOS = ["FP8", "NVFP4"]
KV_CACHE_QUANT_ALGOS = ["FP8"]
class ModelOptFp8Config(QuantizationConfig):
"""Config class for ModelOpt FP8."""
def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
" the format is experimental and could change.")
@classmethod
def get_name(cls) -> str:
return "modelopt"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 89
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
if quant_method not in QUANT_ALGOS:
raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
" quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration.")
is_checkpoint_fp8_serialized = ("FP8" in quant_method)
return cls(is_checkpoint_fp8_serialized)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
return ModelOptFp8LinearMethod(self)
elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self)
return None
class ModelOptFp8LinearMethod(LinearMethodBase):
"""Linear method for Model Optimizer static quantization.
Supports loading FP8 checkpoints with static weight scale and
activation scale. Future support might be added for dynamic
scales.
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn datatype
Args: quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptFp8Config):
self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp()
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
def process_weights_after_loading(self, layer: Module) -> None:
weight = layer.weight
max_w_scale = layer.weight_scale.max()
if not (layer.weight_scale == layer.weight_scale[0]).all():
max_w_scale, weight = requantize_with_max_scale(
layer.weight, layer.weight_scale, layer.logical_widths)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias)
class ModelOptNvFp4Config(QuantizationConfig):
"""Config class for ModelOpt FP4."""
def __init__(
self,
is_checkpoint_nvfp4_serialized: bool,
kv_cache_quant_algo: str,
exclude_modules: List[str],
group_size: int = 16,
) -> None:
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning(
"Detected ModelOpt NVFP4 checkpoint. Please note that"
" the format is experimental and could change in future.")
self.group_size = group_size
self.kv_cache_quant_algo = kv_cache_quant_algo
self.exclude_modules = exclude_modules
@classmethod
def get_name(cls) -> str:
return "modelopt_nvfp4"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
@classmethod
def get_min_capability(cls) -> int:
return 100
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptNvFp4Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
if quant_method not in QUANT_ALGOS:
raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
" quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration.")
is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
group_size = quant_config["group_size"]
exclude_modules = quant_config["exclude_modules"]
if not (group_size and kv_cache_quant_algo and exclude_modules):
raise ValueError("NVFP4 quantization requires group size and "
"kv_cache_quant_algo specified in "
"hf_quant_config.json")
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
exclude_modules, group_size)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.exclude_modules):
return UnquantizedLinearMethod()
return ModelOptNvFp4LinearMethod(self)
elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self)
return None
def cutlass_fp4_supported() -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability)
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: Union[ModelOptFp8Config,
ModelOptNvFp4Config]):
super().__init__(quant_config)
class ModelOptNvFp4LinearMethod(LinearMethodBase):
"""Linear method for Model Optimizer NVFP4.
Supports loading NVFP4 checkpoints with the following structure:
input_scale: torch.float32, scalar ,
weight: NVFP4(represented as byte) Shape: [1, X, y/2]
weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
weight_scale_2: torch.float32, scalar,
Args: quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptNvFp4Config):
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
if not self.cutlass_nvfp4_supported:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and above.")
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
if not self.quant_config.is_checkpoint_nvfp4_serialized:
raise ValueError("NVFP4 quantization was selected, "
" dynamic quantization is not supported.")
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
if (input_size_per_partition % 16 != 0):
raise ValueError("Unsupported model when in features size is "
"not multiple of 16")
# The nvfp4 weight is still represented as
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_nvfp4_serialized
else params_dtype)
# Weight
weight = ModelWeightParameter(
data=torch.empty(
# 2 fp4 items are packed in the input dimension
layer.output_size_per_partition,
layer.input_size_per_partition // 2,
dtype=torch.uint8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# Input Weight Scale
input_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
# Global Weight Scale
weight_scale_2 = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_scale_2", weight_scale_2)
# Per Block Weight Scale
weight_scale = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition // self.quant_config.group_size,
dtype=weight_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
def swizzle_blockscale(self, scale: torch.tensor):
assert (scale.dtype == torch.float8_e4m3fn)
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
if scale.ndim == 2:
scale = scale.unsqueeze(0)
assert scale.ndim == 3
B, M, K = scale.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
padded_scale[:B, :M, :K] = scale
batches, rows, cols = padded_scale.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
return (swizzled_scale.reshape(M, K)
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
def process_weights_after_loading(self, layer: Module) -> None:
# global scales:
input_scale_2 = layer.input_scale.max().to(torch.float32)
layer.input_scale = Parameter(input_scale_2, requires_grad=False)
weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
requires_grad=False)
# Swizzle the weight blockscale.
# contracting dimension is input dimension
# block_size = 16;
assert (layer.weight_scale.shape[1] % 16 == 0), (
"Expected weight_scale.dim(1) to be divisible by 16")
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Block scale must be represented as FP8-E4M3")
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
output_dtype = x.dtype
# for input only the contracting dimension has a constraint.
x_m, _ = x.shape
w_n, _ = layer.weight.shape
output_shape = [x_m, w_n]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
s_quant = 1 / layer.input_scale
x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)
# validate dtypes of quantized input, input block scale,
# weight and weight_blockscale
assert (x_fp4.dtype == torch.uint8)
assert (layer.weight.dtype == torch.uint8)
assert (x_blockscale.dtype == torch.float8_e4m3fn)
assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
assert (layer.alpha.dtype == torch.float32)
out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
layer.weight_scale_swizzled, layer.alpha,
output_dtype)
if bias is not None:
out = out + bias
return out.view(*output_shape)