164 lines
6.3 KiB
Python
164 lines
6.3 KiB
Python
from typing import Any, Dict, List, Optional
|
|
|
|
import torch
|
|
from torch.nn import Module
|
|
from torch.nn.parameter import Parameter
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig, QuantizeMethodBase)
|
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
|
|
from vllm.model_executor.parameter import (ModelWeightParameter,
|
|
PerTensorScaleParameter)
|
|
|
|
# Module-level logger, named after this module via vLLM's init_logger.
logger = init_logger(__name__)
|
|
|
|
# Activation quantization schemes listed for ModelOpt; only "static" is present.
# NOTE(review): not referenced by any code visible in this file — confirm
# whether something elsewhere still imports it before removing.
ACTIVATION_SCHEMES = ["static"]
|
|
|
|
|
|
class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8.

    Built from the ``quantization`` section of ``hf_quant_config.json``.
    ``from_config`` only accepts checkpoints whose ``quant_algo`` mentions
    FP8, and always produces a config with
    ``is_checkpoint_fp8_serialized=True``.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
    ) -> None:
        """Create the config.

        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint's weights
                are stored as FP8. When True, a warning is logged because
                the format is experimental.
        """
        # Run the base initializer so any state QuantizationConfig sets up
        # also exists on this instance.
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                           " the format is experimental and could change.")

    @classmethod
    def get_name(cls) -> str:
        """Return the method name used to identify this config ("modelopt")."""
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes supported: bfloat16 and float16."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (89).

        Presumably GPU compute capability 8.9 in vLLM's major*10+minor
        encoding — confirm against QuantizationConfig.
        """
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the file names this config is read from."""
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
        """Build a config from the parsed ``hf_quant_config.json`` contents.

        Args:
            config: Parsed JSON whose ``quantization`` section holds a
                ``quant_algo`` entry.

        Returns:
            A config with ``is_checkpoint_fp8_serialized=True``.

        Raises:
            ValueError: If ``quant_algo`` is missing, null, or does not
                contain "FP8".
        """
        quant_config = cls.get_from_keys(config, ["quantization"])
        # Use .get so a missing/null quant_algo is reported through the same
        # descriptive ValueError below instead of a bare KeyError/TypeError.
        quant_method = quant_config.get("quant_algo")
        if quant_method is None or "FP8" not in quant_method:
            raise ValueError("ModelOpt currently only supports static FP8 "
                             "quantization in vLLM. Please check the "
                             "`hf_quant_config.json` file for your model's "
                             "quant configuration (got quant_algo="
                             f"{quant_method!r}).")
        return cls(is_checkpoint_fp8_serialized=True)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: The layer's name prefix (unused here).

        Returns:
            ModelOptFp8LinearMethod for LinearBase layers,
            ModelOptFp8KVCacheMethod for Attention layers, otherwise None.
        """
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return an empty list; this config reports no scaled activations."""
        return []
|
|
|
|
|
|
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache quantize method for ModelOpt FP8 checkpoints.

    Handles loading of kv-cache scaling factors from FP8 checkpoints. This
    subclass adds no behavior of its own; everything is inherited from
    BaseKVCacheMethod.
    """

    def __init__(self, quant_config: ModelOptFp8Config) -> None:
        # Pure delegation: the base class does all of the setup.
        super(ModelOptFp8KVCacheMethod, self).__init__(quant_config)
|
|
|
|
|
|
class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt (Model Optimizer) static FP8 quantization.

    Loads FP8 checkpoints that carry a static weight scale and a static
    activation (input) scale. Dynamic scales may be supported later.

    Limitations:
    1. Only per-tensor quantization, because that is what torch._scaled_mm
       supports.
    2. Only the float8_e4m3fn datatype.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        self.quant_config = quant_config
        # Probed once here and forwarded to apply_fp8_linear on every call.
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    @staticmethod
    def _new_scale_param(num_shards: int, weight_loader):
        """Build a float32 scale parameter with one slot per output shard.

        Every slot starts at the most negative finite float32 value,
        presumably so a shard the loader never writes cannot win the
        max-reduction done in ``process_weights_after_loading``.
        """
        scale_param = PerTensorScaleParameter(
            data=torch.empty(num_shards, dtype=torch.float32),
            weight_loader=weight_loader,
        )
        scale_param[:] = torch.finfo(torch.float32).min
        return scale_param

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and, for FP8 checkpoints, its scale parameters.

        Records ``logical_widths``, ``input_size_per_partition`` and
        ``output_size_per_partition`` on ``layer``. It then registers
        ``weight`` with shape
        (sum(output_partition_sizes), input_size_per_partition). For
        FP8-serialized checkpoints the weight is float8_e4m3fn, and two
        float32 parameters with one entry per output shard are registered
        as well: ``weight_scale`` and ``input_scale``. Otherwise the weight
        uses ``params_dtype`` and no scales are created.

        NOTE(review): in the non-FP8 case no scale parameters exist, yet
        ``process_weights_after_loading`` and ``apply`` read
        ``layer.weight_scale`` / ``layer.input_scale``. ``from_config`` only
        builds FP8-serialized configs — confirm nothing else reaches here.
        """
        # The full (unpartitioned) sizes are not needed by this method.
        del input_size, output_size
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        loader = extra_weight_attrs.get("weight_loader")
        total_out = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out

        if fp8_serialized:
            storage_dtype = torch.float8_e4m3fn
        else:
            storage_dtype = params_dtype
        weight_buffer = torch.empty(total_out,
                                    input_size_per_partition,
                                    dtype=storage_dtype)
        layer.register_parameter(
            "weight",
            ModelWeightParameter(data=weight_buffer,
                                 input_dim=1,
                                 output_dim=0,
                                 weight_loader=loader),
        )

        if not fp8_serialized:
            return

        shard_count = len(output_partition_sizes)
        layer.register_parameter("weight_scale",
                                 self._new_scale_param(shard_count, loader))
        layer.register_parameter("input_scale",
                                 self._new_scale_param(shard_count, loader))

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize loaded parameters into the form ``apply`` consumes.

        ``requantize_with_max_scale`` returns a single max weight scale and
        the (presumably requantized) weight. The weight is then stored
        transposed, and ``input_scale`` is reduced to its maximum. All three
        are re-wrapped as non-trainable Parameters.
        """
        merged_scale, new_weight = requantize_with_max_scale(
            layer.weight, layer.weight_scale, layer.logical_widths)
        layer.weight = Parameter(new_weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(merged_scale, requires_grad=False)
        peak_input_scale = layer.input_scale.max()
        layer.input_scale = Parameter(peak_input_scale, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run ``x`` (plus optional ``bias``) through ``apply_fp8_linear``.

        Uses this layer's FP8 weight, its static weight/input scales, and
        the CUTLASS capability probed at construction.
        """
        return apply_fp8_linear(input=x,
                                weight=layer.weight,
                                weight_scale=layer.weight_scale,
                                input_scale=layer.input_scale,
                                bias=bias,
                                cutlass_fp8_supported=self.cutlass_fp8_supported)
|