# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from fractions import Fraction
from typing import TYPE_CHECKING, Any

import regex as re
import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import (
    QuantizationConfig,
    QuantizationMethods,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

logger = init_logger(__name__)


class AutoRoundConfig(QuantizationConfig):
    """Config class for AutoRound.

    Holds the global quantization settings (bits, group size, symmetry,
    packing format, backend) plus optional per-layer overrides in
    ``extra_config``, and dispatches each layer to a GPTQ, AWQ or IPEX
    quantization method.

    Reference: https://arxiv.org/pdf/2309.05516
    """

    SUPPORTED_BITS = {2, 3, 4, 8}
    SUPPORTED_DTYPES = {"int"}
    SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
    SUPPORTED_BACKENDS = {
        "auto",
        "gptq",
        "gptq:marlin",
        "awq",
        "awq:marlin",
        "marlin",
        "ipex",
    }

    # An ``extra_config`` key is only tried as a regular expression when it
    # contains at least one of these characters.
    _REGEX_SPECIAL_CHARS = frozenset(r"*+?^$()[]{}|\\")

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        sym: bool = True,
        packing_format: str = "auto_round:auto_gptq",
        block_name_to_quantize: str | list[str] | None = None,
        extra_config: dict[str, Any] | None = None,
        data_type: str = "int",
        backend: str = "auto",
    ) -> None:
        """Validate and store the AutoRound quantization settings.

        Args:
            weight_bits: Default weight bit-width; must be in
                ``SUPPORTED_BITS``.
            group_size: Default quantization group size.
            sym: Whether the default quantization is symmetric.
            packing_format: Checkpoint packing format; must be in
                ``SUPPORTED_FORMATS``.
            block_name_to_quantize: Layer-name prefixes to quantize, either
                as a list or as a comma-separated string.
            extra_config: Per-layer overrides keyed by layer name or regex
                pattern; each value may define ``bits``, ``group_size``
                and ``sym``.
            data_type: Quantized data type; must be in ``SUPPORTED_DTYPES``.
            backend: Requested backend; must be in ``SUPPORTED_BACKENDS``.

        Raises:
            ValueError: If any of the validated arguments is unsupported.
        """
        super().__init__()
        if weight_bits not in self.SUPPORTED_BITS:
            raise ValueError(
                f"Unsupported weight_bits: {weight_bits}, "
                f"currently only support {self.SUPPORTED_BITS}"
            )
        if data_type not in self.SUPPORTED_DTYPES:
            raise ValueError(
                f"Unsupported data_type: {data_type},"
                f" currently only support {self.SUPPORTED_DTYPES}"
            )
        if packing_format not in self.SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported packing_format: {packing_format}, "
                f"currently only support {self.SUPPORTED_FORMATS}"
            )
        if backend not in self.SUPPORTED_BACKENDS:
            raise ValueError(
                f"Unsupported backend: {backend}, "
                f"currently only support {self.SUPPORTED_BACKENDS}"
            )

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.sym = sym
        self.packing_format = packing_format
        if isinstance(block_name_to_quantize, str):
            # Strip whitespace and drop empty entries: an empty prefix would
            # make ``str.startswith`` match every layer name, and a padded
            # one ("a, b") would never match anything.
            block_name_to_quantize = [
                name.strip()
                for name in block_name_to_quantize.split(",")
                if name.strip()
            ]
        self.block_name_to_quantize = block_name_to_quantize
        self.extra_config = extra_config
        self.data_type = data_type
        self.backend = backend
        self.pack_factor = Fraction(32, weight_bits)

    def __repr__(self) -> str:
        return (
            f"AutoRoundConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, sym={self.sym})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registered name of this quantization method."""
        return "auto-round"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this method supports."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability value (60) for this method."""
        return 60

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the file names searched for the quantization config."""
        return ["quantization_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
        """Build an ``AutoRoundConfig`` from a parsed quantization config dict.

        ``bits``, ``group_size`` and ``sym`` are read with
        ``get_from_keys``; every other field is read with
        ``get_from_keys_or`` and falls back to the default given here.
        """
        return cls(
            weight_bits=cls.get_from_keys(config, ["bits"]),
            group_size=cls.get_from_keys(config, ["group_size"]),
            sym=cls.get_from_keys(config, ["sym"]),
            packing_format=cls.get_from_keys_or(
                config, ["packing_format"], "auto_round:auto_gptq"
            ),
            block_name_to_quantize=cls.get_from_keys_or(
                config, ["block_name_to_quantize", "to_quant_block_names"], None
            ),
            extra_config=cls.get_from_keys_or(config, ["extra_config"], None),
            data_type=cls.get_from_keys_or(config, ["data_type"], "int"),
            backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], "auto"),
        )

    def _default_layer_config(self, quantized: bool) -> tuple[int, int, bool]:
        """Return the ``(bits, group_size, sym)`` defaults.

        Quantized layers use the global settings; non-quantized layers use
        ``(16, -1, True)``.
        """
        if quantized:
            return self.weight_bits, self.group_size, self.sym
        return 16, -1, True

    def _merge_layer_config(
        self, cfg: dict[str, Any], quantized: bool
    ) -> tuple[int, int, bool]:
        """Overlay a per-layer ``extra_config`` entry on top of the defaults."""
        bits, group_size, sym = self._default_layer_config(quantized)
        return (
            cfg.get("bits", bits),
            cfg.get("group_size", group_size),
            cfg.get("sym", sym),
        )

    def _lookup_layer_config(
        self, name: str, quantized: bool = True
    ) -> tuple[int, int, bool]:
        """Resolve ``(bits, group_size, sym)`` for a single layer name.

        Lookup order: exact key in ``extra_config``, then the first
        regex-like key that matches ``name``, then the defaults.
        """
        if not self.extra_config:
            return self._default_layer_config(quantized)

        # Exact match first.
        if name in self.extra_config:
            return self._merge_layer_config(self.extra_config[name], quantized)

        for pattern, cfg in self.extra_config.items():
            # Plain names were already handled by the exact-match lookup, so
            # only keys that look like regular expressions are tried here.
            if not isinstance(pattern, str) or not any(
                c in self._REGEX_SPECIAL_CHARS for c in pattern
            ):
                continue
            try:
                matched = re.search(pattern, name) is not None
            except re.error:
                # Invalid regex, ignore.
                continue
            if matched:
                return self._merge_layer_config(cfg, quantized)

        return self._default_layer_config(quantized)

    def get_layer_config(self, layer, layer_name: str) -> tuple[int, int, bool]:
        """Return ``(bits, group_size, sym)`` to use for ``layer``.

        Args:
            layer: The module being configured.
            layer_name: Its fully qualified name (prefix).

        Returns:
            The resolved bit-width, group size and symmetry flag. A
            bit-width of 16 marks the layer as not quantized.

        Raises:
            ValueError: If the sub-layers of a fused MoE or other fused
                module resolve to different configs.
        """
        # 1. Exact match from config
        if self.extra_config and layer_name in self.extra_config:
            return self._lookup_layer_config(layer_name)

        # 2. Determine whether layer should be quantized
        quantized = not isinstance(layer, ParallelLMHead)
        if self.block_name_to_quantize:
            quantized = any(
                layer_name.startswith(name) for name in self.block_name_to_quantize
            )

        # 3. Handle fused MoE
        if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower():
            moe_configs = [
                self._lookup_layer_config(name, quantized)
                for name in self.extra_config
                if name.startswith(layer_name)
            ]
            if moe_configs:
                if len(set(moe_configs)) == 1:
                    return moe_configs[0]
                raise ValueError(
                    f"Fused MoE layer '{layer_name}' requires "
                    f"consistent quant config for all sub-layers"
                )

        # 4. Handle fused QKV or other patterns
        if self.extra_config:
            for fusion_key, sub_keys in self.packed_modules_mapping.items():
                if fusion_key in layer_name and layer_name.count(fusion_key) == 1:
                    sub_names = [
                        layer_name.replace(fusion_key, sub_key) for sub_key in sub_keys
                    ]
                    sub_configs = [
                        self._lookup_layer_config(name, quantized)
                        for name in sub_names
                    ]
                    if len(set(sub_configs)) == 1:
                        return sub_configs[0]
                    raise ValueError(
                        f"Fused module '{layer_name}' requires "
                        f"consistent quant config for {sub_names}"
                    )

        # 5. Fallback or try a regular expression match
        return self._lookup_layer_config(layer_name, quantized)

    def check_quantized(self, weight_bits: int) -> bool:
        """Return True when ``weight_bits`` denotes a quantized layer (< 16)."""
        return weight_bits < 16

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        """Rewrite stored layer names with ``hf_to_vllm_mapper``.

        Applies the mapper to ``block_name_to_quantize`` and to the keys of
        ``extra_config`` when they are set.
        """
        if self.block_name_to_quantize is not None:
            self.block_name_to_quantize = hf_to_vllm_mapper.apply_list(
                self.block_name_to_quantize
            )
        if self.extra_config is not None:
            self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config)

    @staticmethod
    def _unquantized_method(layer):
        """Return the method for a layer that is not quantized.

        Linear layers and the LM head get ``UnquantizedLinearMethod``;
        every other layer type gets ``None``.
        """
        if isinstance(layer, (LinearBase, ParallelLMHead)):
            return UnquantizedLinearMethod()
        return None

    def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
        """Select the AWQ-format quantization method for ``layer``.

        Args:
            layer: The module to quantize.
            prefix: Fully qualified layer name, used to resolve its config.
            backend: Marlin is only considered when this is ``"auto"`` or
                contains ``"marlin"``.

        Returns:
            A quantization method instance, or ``None`` for layer types
            that are not handled here.
        """
        from vllm.model_executor.layers.fused_moe import FusedMoE
        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
            check_marlin_supported,
            check_moe_marlin_supports_layer,
        )

        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
        if not self.check_quantized(weight_bits):
            return self._unquantized_method(layer)

        logger.debug(
            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
            prefix,
            layer.__class__.__name__,
            weight_bits,
            group_size,
            sym,
        )
        if backend == "auto" or "marlin" in backend:
            AWQ_TYPE_MAP = {
                4: scalar_types.uint4,
                8: scalar_types.uint8,
            }
            use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported(
                AWQ_TYPE_MAP[weight_bits], group_size, not sym
            )
            if isinstance(layer, FusedMoE):
                use_marlin = use_marlin and check_moe_marlin_supports_layer(
                    layer, group_size
                )
        else:
            use_marlin = False

        if use_marlin:
            from vllm.model_executor.layers.quantization.awq_marlin import (
                AWQMarlinConfig,
                AWQMarlinLinearMethod,
                AWQMarlinMoEMethod,
            )

            quant_args_marlin = AWQMarlinConfig(
                weight_bits=weight_bits,
                group_size=group_size,
                zero_point=not sym,
                lm_head_quantized=False,
                full_config={},
                modules_to_not_convert=[],
            )
        else:
            from vllm.model_executor.layers.quantization.awq import (
                AWQConfig,
                AWQLinearMethod,
            )

            quant_args = AWQConfig(
                weight_bits=weight_bits,
                group_size=group_size,
                zero_point=not sym,
            )

        if isinstance(layer, FusedMoE):
            if use_marlin:
                # Pass the layer's ``moe_config``, matching what the GPTQ
                # Marlin MoE path in this class passes for the same role.
                return AWQMarlinMoEMethod(quant_args_marlin, layer.moe_config)
            from vllm.model_executor.layers.quantization.moe_wna16 import (
                MoeWNA16Config,
            )

            config = {
                "quant_method": "awq",
                "bits": weight_bits,
                "group_size": group_size,
                "zero_point": not sym,
                "lm_head": False,
            }
            return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)

        if isinstance(layer, (LinearBase, ParallelLMHead)):
            if use_marlin:
                return AWQMarlinLinearMethod(quant_args_marlin)
            return AWQLinearMethod(quant_args)
        return None

    def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
        """Select the GPTQ-format quantization method for ``layer``.

        Args:
            layer: The module to quantize.
            prefix: Fully qualified layer name, used to resolve its config.
            backend: Marlin is only considered when this is ``"auto"`` or
                contains ``"marlin"``.

        Returns:
            A quantization method instance, or ``None`` for layer types
            that are not handled here.
        """
        from vllm.model_executor.layers.fused_moe import FusedMoE
        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
            check_marlin_supported,
            check_moe_marlin_supports_layer,
        )

        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
        if not self.check_quantized(weight_bits):
            return self._unquantized_method(layer)

        logger.debug(
            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
            prefix,
            layer.__class__.__name__,
            weight_bits,
            group_size,
            sym,
        )
        if backend == "auto" or "marlin" in backend:
            # Only symmetric 4-bit and 8-bit entries exist in this map, so
            # any other combination never selects Marlin.
            GPTQ_TYPE_MAP = {
                (4, True): scalar_types.uint4b8,
                (8, True): scalar_types.uint8b128,
            }
            use_marlin = (weight_bits, sym) in GPTQ_TYPE_MAP and check_marlin_supported(
                GPTQ_TYPE_MAP[(weight_bits, sym)], group_size, has_zp=not sym
            )
            if isinstance(layer, FusedMoE):
                use_marlin = use_marlin and check_moe_marlin_supports_layer(
                    layer, group_size
                )
        else:
            use_marlin = False

        if use_marlin:
            from vllm.model_executor.layers.quantization.gptq_marlin import (
                GPTQMarlinConfig,
                GPTQMarlinLinearMethod,
                GPTQMarlinMoEMethod,
            )

            quant_args_marlin = GPTQMarlinConfig(
                weight_bits=weight_bits,
                group_size=group_size,
                is_sym=sym,
                lm_head_quantized=False,
                desc_act=False,
                dynamic={},
                full_config={},
            )
        else:
            from vllm.model_executor.layers.quantization.gptq import (
                GPTQConfig,
                GPTQLinearMethod,
            )

            quant_args = GPTQConfig(
                weight_bits=weight_bits,
                group_size=group_size,
                lm_head_quantized=False,
                desc_act=False,
                dynamic={},
            )

        if isinstance(layer, FusedMoE):
            if use_marlin:
                return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe_config)
            from vllm.model_executor.layers.quantization.moe_wna16 import (
                MoeWNA16Config,
            )

            config = {
                "quant_method": "gptq",
                "bits": weight_bits,
                "group_size": group_size,
                "sym": sym,
                "lm_head": False,
            }
            return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)

        if isinstance(layer, (LinearBase, ParallelLMHead)):
            if use_marlin:
                return GPTQMarlinLinearMethod(quant_args_marlin)
            return GPTQLinearMethod(quant_args)
        return None

    def apply_ipex_quant_layer(self, layer, prefix: str):
        """Select the IPEX quantization method for ``layer``.

        Returns:
            An IPEX AWQ or GPTQ linear method (chosen from
            ``packing_format``) for linear layers and the LM head,
            ``UnquantizedLinearMethod`` when such a layer is not quantized,
            and ``None`` for any other layer type.

        Raises:
            ValueError: If ``packing_format`` names neither awq nor gptq.
        """
        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
        if not self.check_quantized(weight_bits):
            return self._unquantized_method(layer)

        from vllm.model_executor.layers.quantization.ipex_quant import (
            IPEXAWQLinearMethod,
            IPEXConfig,
            IPEXGPTQLinearMethod,
        )

        if not isinstance(layer, (LinearBase, ParallelLMHead)):
            return None

        if "awq" in self.packing_format:
            config = IPEXConfig(
                method="awq", weight_bits=weight_bits, group_size=group_size
            )
            return IPEXAWQLinearMethod(config)
        if "gptq" in self.packing_format:
            config = IPEXConfig(
                method="gptq", weight_bits=weight_bits, group_size=group_size
            )
            return IPEXGPTQLinearMethod(config)
        raise ValueError(
            f"ipex backend only supports awq "
            f"and gptq format, but got {self.packing_format}"
        )

    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
        """Return the quantization method to use for ``layer``.

        Args:
            layer: The module being built.
            prefix: Its fully qualified name.

        Returns:
            ``UnquantizedLinearMethod`` when ``extra_config`` marks the
            layer as 16 bits or more; otherwise the result of the IPEX,
            GPTQ or AWQ dispatch, or ``None`` when nothing applies.
        """
        if prefix and self.extra_config:
            # The override may be keyed by the bare prefix or by the
            # "model."-prefixed name; an entry without "bits" counts as 16.
            for layer_name in (prefix, f"model.{prefix}"):
                if (
                    layer_name in self.extra_config
                    and self.extra_config[layer_name].get("bits", 16) >= 16
                ):
                    return UnquantizedLinearMethod()

        if (
            current_platform.is_cpu()
            or current_platform.is_xpu()
            or self.backend == "ipex"
        ):
            return self.apply_ipex_quant_layer(layer, prefix)

        # NOTE(review): ``self.backend`` is not forwarded to the apply_*
        # methods below, so they always run with backend="auto" and the
        # configured backend does not influence Marlin selection. Confirm
        # whether that is intended before changing it.
        if "gptq" in self.packing_format or "gptq" in self.backend:
            return self.apply_gptq_quant_layer(layer, prefix)
        if "awq" in self.packing_format or "awq" in self.backend:
            return self.apply_awq_quant_layer(layer, prefix)
        return None