# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from copy import deepcopy
from fractions import Fraction
from types import MappingProxyType
from typing import Optional, Union

import regex as re
import torch

from vllm.config import QuantizationConfig
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, UnquantizedEmbeddingMethod)


# Match dynamic rules against the module name (prefix) and override the
# quant config when the module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
    weight_bits = get_dynamic_override(config, prefix, "bits",
                                       config.weight_bits)
    if isinstance(weight_bits, int):
        config.weight_bits = weight_bits
    group_size = get_dynamic_override(config, prefix, "group_size",
                                      config.group_size)
    if isinstance(group_size, int):
        config.group_size = group_size
    desc_act = get_dynamic_override(config, prefix, "desc_act",
                                    config.desc_act)
    if isinstance(desc_act, bool):
        config.desc_act = desc_act

    config.pack_factor = Fraction(32, config.weight_bits)  # packed into int32
    if config.get_name() == "gptq_marlin":
        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
        if isinstance(is_sym, bool):
            config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError(
                "Unsupported quantization config: "
                f"bits={config.weight_bits}, sym={config.is_sym}")

        config.quant_type = config.TYPE_MAP[(config.weight_bits,
                                             config.is_sym)]
    elif config.get_name() == "gptq":
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits.")


def get_dynamic_override(
        config: QuantizationConfig,
        layer_name: str,
        key: Optional[str] = None,
        default_value: Union[int, bool, None] = None
) -> Union[dict, int, bool, None]:
    for pattern, pattern_dict in config.dynamic.items():
        # Negative match: matched modules are excluded from quantized init
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        # Positive match: matched modules have quant properties that
        # override the base quant config
        elif re.match(pattern.removeprefix("+:"), layer_name):
            if key is None:
                return pattern_dict
            else:
                return pattern_dict.get(key, default_value)
    return default_value
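

# Illustrative sketch (not part of this module): ``config.dynamic`` maps
# "+:"/"-:"-prefixed regex patterns to per-module overrides, typically taken
# from a GPTQModel-style checkpoint config. The patterns and values below are
# hypothetical and only show how get_dynamic_override resolves them:
#
#     dynamic = {
#         # negative rule: matching modules stay unquantized
#         "-:model.layers.0.*": {},
#         # positive rule: matching modules get these overrides
#         "+:.*down_proj": {"bits": 8, "group_size": 64},
#     }
#
#     get_dynamic_override(config, "model.layers.0.mlp.gate_proj")
#         -> False (negative match: skip quantized init for this module)
#     get_dynamic_override(config, "model.layers.1.mlp.down_proj", "bits", 4)
#         -> 8 (positive match: the per-rule override wins)
#     get_dynamic_override(config, "model.layers.1.self_attn.q_proj", "bits", 4)
#         -> 4 (no rule matches: the supplied default is returned)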


def is_layer_gptq_quantized(
        prefix: str,
        quantized_layers: list[str],
        fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    # GPTQ's `modules_in_block_to_quantize` may list either:
    #   substrings:    ["self_attn.k_proj", "self_attn.v_proj",
    #                   "self_attn.q_proj"]
    #   full prefixes: ["model.layers.0.self_attn.q_proj"]
    proj_name = prefix.split(".")[-1]

    # Fused layers like gate_up_proj or qkv_proj are stored unfused in the
    # safetensors checkpoint, so convert the fused name to its unfused shard
    # names and check that every shard of the fused layer uses the same
    # scheme.
    if proj_name in fused_mapping:
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in fused_mapping[proj_name]
        ]

        is_quantized = None
        for shard_prefix in shard_prefixes:
            is_shard_quantized = any(layer in shard_prefix
                                     for layer in quantized_layers)

            if is_quantized is None:
                is_quantized = is_shard_quantized
            elif is_shard_quantized != is_quantized:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "must have the same precision.")
    else:
        is_quantized = any(layer in prefix for layer in quantized_layers)

    assert is_quantized is not None
    return is_quantized


def get_linear_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    linear_method_cls: type,
):
    cloned_config = deepcopy(config)
    parallel_lm_head_quantized = isinstance(
        layer, ParallelLMHead) and cloned_config.lm_head_quantized

    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
        is_layer_quantized = is_layer_gptq_quantized(
            prefix=prefix,
            quantized_layers=cloned_config.modules_in_block_to_quantize,
            fused_mapping=cloned_config.packed_modules_mapping)
        # False = skip module, None = no override, else = positive match
        if get_dynamic_override(  # noqa: E712
                cloned_config,  # noqa: E712
                layer_name=prefix) == False or (  # noqa: E712
                    not is_layer_quantized):
            if parallel_lm_head_quantized:
                return UnquantizedEmbeddingMethod()
            return UnquantizedLinearMethod()

        if prefix:
            # Dynamic per-module/per-layer rules may override the base config
            override_config(cloned_config, prefix=prefix)

        return linear_method_cls(cloned_config)
    return None
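

# Illustrative sketch (hypothetical prefixes and mapping, for exposition only):
# with fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} and
# quantized_layers = ["self_attn.q_proj", "self_attn.k_proj",
# "self_attn.v_proj"],
#
#     is_layer_gptq_quantized("model.layers.0.self_attn.qkv_proj",
#                             quantized_layers, fused_mapping)
#
# expands the fused prefix into the three unfused shard prefixes that appear
# in the checkpoint and returns True because every shard matches; had only
# some shards matched, the ValueError above would be raised instead.
# get_linear_quant_method then either returns an unquantized method (negative
# dynamic match or unquantized layer) or applies override_config to a deep
# copy of the base config before constructing linear_method_cls.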