# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.kernels.linear import (
    init_fp8_linear_kernel,
)
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.layers.quantization.fp8 import (
    Fp8Config,
    Fp8KVCacheMethod,
    Fp8LinearMethod,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped,
    kFp8DynamicTokenSym,
)
from vllm.platforms import current_platform


class PTPCFp8Config(Fp8Config):
    """Config class for Per-Token-Per-Channel Dynamic Quantization Fp8."""

    def __init__(
        self,
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
    ) -> None:
        """Validate the platform and configure the parent FP8 config.

        Args:
            activation_scheme: Activation quantization scheme. ``"static"``
                is rejected.
            ignored_layers: Forwarded to ``Fp8Config``; consulted through
                ``is_layer_skipped`` in ``get_quant_method``.

        Raises:
            ValueError: If the current platform is not ROCm, if the device
                does not satisfy ``has_device_capability(94)``, or if
                ``activation_scheme`` is ``"static"``.
        """
        if not current_platform.is_rocm():
            raise ValueError("ptpc_fp8 quantization is supported only on ROCm.")

        if not current_platform.has_device_capability(94):
            raise ValueError(
                "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer."  # noqa: E501
            )
        if activation_scheme == "static":
            raise ValueError("ptpc_fp8 as of now only support dynamic quantization.")

        # The checkpoint is declared as not FP8-serialized; weights are
        # quantized by PTPCFp8LinearMethod.process_weights_after_loading.
        super().__init__(
            is_checkpoint_fp8_serialized=False,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name of this quantization method."""
        return "ptpc_fp8"

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config":
        """Build a config from a quantization config dictionary.

        Args:
            config: Mapping read for ``"activation_scheme"`` (looked up via
                ``get_from_keys``, no default supplied) and
                ``"ignored_layers"`` (defaults to ``None``).

        Returns:
            A new ``PTPCFp8Config`` built from those two values.
        """
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        return cls(activation_scheme=activation_scheme, ignored_layers=ignored_layers)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Select the quantization method for ``layer``.

        Args:
            layer: The module being configured.
            prefix: Name passed to ``is_layer_skipped`` together with
                ``self.ignored_layers``.

        Returns:
            ``UnquantizedLinearMethod`` for a skipped ``LinearBase`` layer,
            ``PTPCFp8LinearMethod`` for any other ``LinearBase`` layer,
            ``Fp8KVCacheMethod`` for an ``Attention`` layer, and ``None``
            for every other module type.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return PTPCFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None


class PTPCFp8LinearMethod(Fp8LinearMethod):
    """Linear method for Per-Token and Per-Channel FP8 Quantization.

    BF16 weights are quantized to FP8 after the model weights are loaded,
    which is also when the weight scaling factor is initialized. Weights
    that are already in the platform FP8 dtype are accepted as long as the
    layer carries a ``weight_scale``. FP16 and FP32 weights are rejected:
    to load FP16 model checkpoints, the user must specify that the FP16
    weights be converted to BF16 during loading.

    Limitations:
    1. Only support float8_e4m3fnuz data type due to the limitation of
       torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: PTPCFp8Config):
        """Initialize the parent method and install the PTPC FP8 kernel.

        Args:
            quant_config: The quantization config.

        Raises:
            AssertionError: If the current platform is not ROCm.
        """
        assert current_platform.is_rocm(), (
            "PTPCFp8LinearMethod is only supported on ROCm."
        )
        super().__init__(quant_config=quant_config)
        # Force weight quantization
        # This assignment runs after the parent constructor, so it replaces
        # whatever ``fp8_linear`` the parent may have set up.
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8DynamicTokenSym,
            weight_quant_key=kFp8DynamicTokenSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Quantize the loaded weight of ``layer`` to FP8 when needed.

        A BF16 weight is quantized with ``ops.scaled_fp8_quant`` and stored
        back transposed together with its scale. Any other weight must
        already be in the platform FP8 dtype and come with a
        ``weight_scale``.

        Args:
            layer: Module whose ``weight`` (and, on the pre-quantized path,
                ``weight_scale``) is read, and whose ``weight``,
                ``weight_scale`` and ``input_scale`` are written.

        Raises:
            AssertionError: If the weight is FP16 or FP32, or if a non-BF16
                weight is not in the platform FP8 dtype or has no
                ``weight_scale``.
        """
        assert layer.weight.data.dtype not in (torch.float16, torch.float32), (
            "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support "
            f"output dtype of bfloat16. {layer.weight.data.dtype} is specified."
        )

        if layer.weight.data.dtype == torch.bfloat16:
            # Quantize the weights.
            qweight, weight_scale = ops.scaled_fp8_quant(
                layer.weight, scale=None, use_per_token_if_dynamic=True
            )

            # Update the layer with the new values.
            layer.weight = Parameter(
                qweight.t(), requires_grad=False
            )  # Pretranspose the weight
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        else:
            # Pre-quantized path: nothing to compute, only sanity checks.
            assert layer.weight.data.dtype == current_platform.fp8_dtype()
            assert getattr(layer, "weight_scale", None) is not None

        # Activations use the dynamic per-token key (kFp8DynamicTokenSym),
        # so no static input scale is kept on the layer.
        layer.input_scale = None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the linear layer through the configured FP8 kernel.

        Args:
            layer: Module holding the processed weights.
            x: Input activations.
            bias: Optional bias forwarded to the kernel.

        Returns:
            The result of ``self.fp8_linear.apply_weights(layer, x, bias)``.
        """
        return self.fp8_linear.apply_weights(layer, x, bias)