Sync from v0.13
This commit is contained in:
158
vllm/model_executor/layers/quantization/utils/gptq_utils.py
Normal file
158
vllm/model_executor/layers/quantization/utils/gptq_utils.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from fractions import Fraction
|
||||
from types import MappingProxyType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
UnquantizedEmbeddingMethod,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..gptq import GPTQConfig
|
||||
from ..gptq_marlin import GPTQMarlinConfig
|
||||
else:
|
||||
GPTQConfig = object
|
||||
GPTQMarlinConfig = object
|
||||
|
||||
|
||||
# Match dynamic rules with module name (prefix) and override quantize
|
||||
# config if module (prefix) matches a rule
|
||||
def override_config(config: GPTQConfig | GPTQMarlinConfig, prefix: str):
|
||||
weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
|
||||
if isinstance(weight_bits, int):
|
||||
config.weight_bits = weight_bits
|
||||
group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
|
||||
if isinstance(group_size, int):
|
||||
config.group_size = group_size
|
||||
desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
|
||||
if isinstance(desc_act, bool):
|
||||
config.desc_act = desc_act
|
||||
|
||||
config.pack_factor = Fraction(32, config.weight_bits) # packed into int32
|
||||
if config.get_name() == "gptq_marlin":
|
||||
assert isinstance(config, GPTQMarlinConfig)
|
||||
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
|
||||
if isinstance(is_sym, bool):
|
||||
config.is_sym = is_sym
|
||||
|
||||
if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
|
||||
raise ValueError(
|
||||
"Unsupported quantization config: "
|
||||
f"bits={config.weight_bits}, sym={config.is_sym}"
|
||||
)
|
||||
|
||||
config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
|
||||
elif config.get_name() == "gptq":
|
||||
assert isinstance(config, GPTQConfig)
|
||||
if config.weight_bits not in [2, 3, 4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 2/3/4/8-bit weight quantization is "
|
||||
f"supported for GPTQ, but got {config.weight_bits} bits."
|
||||
)
|
||||
|
||||
|
||||
def get_dynamic_override(
|
||||
config: GPTQConfig | GPTQMarlinConfig,
|
||||
layer_name: str,
|
||||
key: str | None = None,
|
||||
default_value: int | bool | None = None,
|
||||
) -> dict | int | bool | None:
|
||||
for pattern, pattern_dict in config.dynamic.items():
|
||||
# Negative match: matched modules are excluded from quantized init
|
||||
if pattern.startswith("-:"):
|
||||
if re.match(pattern.removeprefix("-:"), layer_name):
|
||||
return False
|
||||
# Positive match: matched modules have quant properties overrides
|
||||
# base quant config
|
||||
elif re.match(pattern.removeprefix("+:"), layer_name):
|
||||
if key is None:
|
||||
return pattern_dict
|
||||
else:
|
||||
return pattern_dict.get(key, default_value)
|
||||
return default_value
|
||||
|
||||
|
||||
def is_layer_gptq_quantized(
|
||||
prefix: str,
|
||||
quantized_layers: list[str],
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
|
||||
) -> bool:
|
||||
# prefix: model.layers.0.self_attn.q_proj
|
||||
# proj_name: q_proj
|
||||
|
||||
# GPTQ's `modules_in_block_to_quantize`:
|
||||
# Substr: ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"]
|
||||
# Full prefix ["model.layers.0.self_attn.q_proj"]
|
||||
|
||||
proj_name = prefix.split(".")[-1]
|
||||
|
||||
# Fused layers like gate_up_proj or qkv_proj will not be fused
|
||||
# in the safetensors checkpoint. So, we convert the name
|
||||
# from the fused version to unfused + check to make sure that
|
||||
# each shard of the fused layer has the same scheme.
|
||||
if proj_name in fused_mapping:
|
||||
shard_prefixes = [
|
||||
prefix.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in fused_mapping[proj_name]
|
||||
]
|
||||
|
||||
is_quantized = None
|
||||
for shard_prefix in shard_prefixes:
|
||||
is_shard_quantized = any(
|
||||
layer in shard_prefix for layer in quantized_layers
|
||||
)
|
||||
|
||||
if is_quantized is None:
|
||||
is_quantized = is_shard_quantized
|
||||
elif is_shard_quantized != is_quantized:
|
||||
raise ValueError(
|
||||
f"Detected some but not all shards of {prefix} "
|
||||
"are quantized. All shards of fused layers "
|
||||
"to have the same precision."
|
||||
)
|
||||
else:
|
||||
is_quantized = any(layer in prefix for layer in quantized_layers)
|
||||
|
||||
assert is_quantized is not None
|
||||
return is_quantized
|
||||
|
||||
|
||||
def get_linear_quant_method(
|
||||
config: GPTQConfig | GPTQMarlinConfig,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
linear_method_cls: type,
|
||||
):
|
||||
cloned_config = deepcopy(config)
|
||||
parallel_lm_head_quantized = (
|
||||
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
|
||||
)
|
||||
if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
|
||||
is_layer_quantized = is_layer_gptq_quantized(
|
||||
prefix=prefix,
|
||||
quantized_layers=cloned_config.modules_in_block_to_quantize,
|
||||
fused_mapping=cloned_config.packed_modules_mapping,
|
||||
)
|
||||
# False = skip module, None = no override, else = Positive match
|
||||
if get_dynamic_override( # noqa: E712
|
||||
cloned_config, # noqa: E712
|
||||
layer_name=prefix,
|
||||
) == False or (not is_layer_quantized): # noqa: E712
|
||||
if parallel_lm_head_quantized:
|
||||
return UnquantizedEmbeddingMethod()
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
if prefix:
|
||||
# Dynamic per module/layer rules may override base config
|
||||
override_config(cloned_config, prefix=prefix)
|
||||
|
||||
return linear_method_cls(cloned_config)
|
||||
return None
|
||||
Reference in New Issue
Block a user