# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Mapping
from copy import deepcopy
from fractions import Fraction
from types import MappingProxyType
from typing import TYPE_CHECKING

import regex as re
import torch

from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    UnquantizedEmbeddingMethod,
)

if TYPE_CHECKING:
    from ..gptq import GPTQConfig
    from ..gptq_marlin import GPTQMarlinConfig
else:
    GPTQConfig = object
    GPTQMarlinConfig = object


# Match dynamic rules with module name (prefix) and override quantize
# config if module (prefix) matches a rule
def override_config(config: GPTQConfig | GPTQMarlinConfig, prefix: str):
    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
    if isinstance(weight_bits, int):
        config.weight_bits = weight_bits
    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
    if isinstance(group_size, int):
        config.group_size = group_size
    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
    if isinstance(desc_act, bool):
        config.desc_act = desc_act

    config.pack_factor = Fraction(32, config.weight_bits)  # packed into int32
    if config.get_name() == "gptq_marlin":
        assert isinstance(config, GPTQMarlinConfig)
        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
        if isinstance(is_sym, bool):
            config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError(
                "Unsupported quantization config: "
                f"bits={config.weight_bits}, sym={config.is_sym}"
            )

        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
    elif config.get_name() == "gptq":
        assert isinstance(config, GPTQConfig)
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits."
            )
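

# Illustrative note (hypothetical rule, not from this file): if `config.dynamic`
# held {r"+:model\.layers\.0\..*": {"bits": 8}}, then calling
# override_config(config, "model.layers.0.self_attn.q_proj") would set
# config.weight_bits = 8 and recompute config.pack_factor = Fraction(32, 8) = 4;
# group_size, desc_act and (for gptq_marlin) sym are overridden the same way.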


def get_dynamic_override(
    config: GPTQConfig | GPTQMarlinConfig,
    layer_name: str,
    key: str | None = None,
    default_value: int | bool | None = None,
) -> dict | int | bool | None:
    for pattern, pattern_dict in config.dynamic.items():
        # Negative match: matched modules are excluded from quantized init
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        # Positive match: matched modules get quant property overrides on top
        # of the base quant config
        elif re.match(pattern.removeprefix("+:"), layer_name):
            if key is None:
                return pattern_dict
            else:
                return pattern_dict.get(key, default_value)
    return default_value
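
# Sketch of the `config.dynamic` layout this helper expects (values are
# hypothetical, for illustration only):
#     dynamic = {
#         r"+:model\.layers\.0\..*": {"bits": 8, "group_size": 64},
#         r"-:model\.layers\.1\.mlp\..*": {},
#     }
# "+:" (or unprefixed) rules return per-key overrides for matching modules;
# "-:" rules return False so matching modules skip quantized initialization.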


def is_layer_gptq_quantized(
    prefix: str,
    quantized_layers: list[str],
    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
) -> bool:
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    # GPTQ's `modules_in_block_to_quantize` may list either:
    #   substrings: ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"]
    #   full prefixes: ["model.layers.0.self_attn.q_proj"]
    proj_name = prefix.split(".")[-1]

    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    if proj_name in fused_mapping:
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in fused_mapping[proj_name]
        ]

        is_quantized = None
        for shard_prefix in shard_prefixes:
            is_shard_quantized = any(
                layer in shard_prefix for layer in quantized_layers
            )

            if is_quantized is None:
                is_quantized = is_shard_quantized
            elif is_shard_quantized != is_quantized:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "must have the same precision."
                )
    else:
        is_quantized = any(layer in prefix for layer in quantized_layers)

    assert is_quantized is not None
    return is_quantized
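
# Illustrative example (hypothetical mapping, for clarity only): with
#     fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
# a prefix of "model.layers.0.self_attn.qkv_proj" expands to the three unfused
# shard prefixes, each of which is checked against `quantized_layers`; the
# shards must all agree, otherwise a ValueError is raised.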


def get_linear_quant_method(
    config: GPTQConfig | GPTQMarlinConfig,
    layer: torch.nn.Module,
    prefix: str,
    linear_method_cls: type,
):
    cloned_config = deepcopy(config)
    parallel_lm_head_quantized = (
        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
    )

    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
        is_layer_quantized = is_layer_gptq_quantized(
            prefix=prefix,
            quantized_layers=cloned_config.modules_in_block_to_quantize,
            fused_mapping=cloned_config.packed_modules_mapping,
        )
        # False = skip module, None = no override, else = positive match
        if get_dynamic_override(
            cloned_config,
            layer_name=prefix,
        ) == False or (not is_layer_quantized):  # noqa: E712
            if parallel_lm_head_quantized:
                return UnquantizedEmbeddingMethod()
            return UnquantizedLinearMethod()

        if prefix:
            # Dynamic per module/layer rules may override base config
            override_config(cloned_config, prefix=prefix)

        return linear_method_cls(cloned_config)
    return None
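
# A typical call site would look roughly like the sketch below (assumption: the
# GPTQ / GPTQ-Marlin config classes delegate their get_quant_method() here,
# passing their own linear-method class, e.g. a GPTQLinearMethod):
#     quant_method = get_linear_quant_method(
#         self, layer, prefix, linear_method_cls=GPTQLinearMethod
#     )
# A return value of None means the layer is not handled by this quantization
# scheme and the caller falls back to its default behavior.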