[QUANT] Add GPTQModel Dynamic Quantization + lm_head Quantization (#3790)
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
parent 583d6af71b
commit 56a724eba3
@@ -1,5 +1,7 @@
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
from typing import Callable, Dict, Optional, Type
import re
from copy import deepcopy
from typing import Callable, Dict, Optional, Type, Union

import torch
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
@@ -16,8 +18,6 @@ from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.qqq import QQQConfig
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

@@ -61,19 +62,119 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    return QUANTIZATION_METHODS[quantization]


# Match dynamic rules against the module name (prefix) and override the quant
# config if the module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
    if isinstance(weight_bits, int):
        config.weight_bits = weight_bits
    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
    if isinstance(group_size, int):
        config.group_size = group_size
    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
    if isinstance(desc_act, bool):
        config.desc_act = desc_act

    config.pack_factor = 32 // config.weight_bits  # packed into int32
    if config.get_name() == "gptq_marlin":
        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
        if isinstance(is_sym, bool):
            config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError(
                "Unsupported quantization config: "
                f"bits={config.weight_bits}, sym={config.is_sym}"
            )

        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
    elif config.get_name() == "gptq":
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits."
            )

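# Editorial example (not part of the original patch): with a rule such as
# dynamic = {"+:.*down_proj.*": {"bits": 8, "group_size": 64}}, calling
# override_config(cloned, prefix="model.layers.3.mlp.down_proj") rewrites the
# clone to weight_bits=8, group_size=64 and recomputes
# pack_factor = 32 // 8 = 4 before the linear method is constructed;
# non-matching prefixes leave the clone untouched.
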
def get_dynamic_override(
    config: QuantizationConfig,
    layer_name: str,
    key: Optional[str] = None,
    default_value: Union[int, bool, None] = None,
) -> Union[Dict, int, bool, None]:
    for pattern, pattern_dict in config.dynamic.items():
        # Negative match: matched modules are excluded from quantized init
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        # Positive match: matched modules have quant properties that override
        # the base quant config
        elif re.match(pattern.removeprefix("+:"), layer_name):
            if key is None:
                return pattern_dict
            else:
                return pattern_dict.get(key, default_value)
    return default_value

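# --- Editorial usage sketch (not part of the original patch) ----------------
# Shows how the "+:"/"-:" regex rules resolve; `_DemoCfg` is a hypothetical
# stand-in for a QuantizationConfig carrying a `dynamic` dict, and the module
# names follow a Llama-style layer tree purely for illustration.
if __name__ == "__main__":  # pragma: no cover
    class _DemoCfg:
        dynamic = {
            r"-:model\.layers\.0\..*": {},      # negative: skip layer 0 entirely
            r"+:.*\.down_proj.*": {"bits": 8},  # positive: 8-bit for down_proj
        }

    _cfg = _DemoCfg()
    # Negative rule matches -> module is excluded from quantized init
    assert get_dynamic_override(_cfg, "model.layers.0.mlp.gate_proj") is False
    # Positive rule matches -> per-module override wins over the default (4)
    assert get_dynamic_override(_cfg, "model.layers.5.mlp.down_proj", "bits", 4) == 8
    # No rule matches -> the supplied default is returned
    assert get_dynamic_override(_cfg, "model.layers.5.self_attn.q_proj", "bits", 4) == 4
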
def get_linear_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    linear_method_cls: type,
):
    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
    from sglang.srt.layers.vocab_parallel_embedding import (
        ParallelLMHead,
        UnquantizedEmbeddingMethod,
    )

    cloned_config = deepcopy(config)
    parallel_lm_head_quantized = (
        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
    )

    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
        # False = skip module, None = no override, else = positive match
        if (
            get_dynamic_override(  # noqa: E712
                cloned_config, layer_name=prefix  # noqa: E712
            )
            == False
        ):  # noqa: E712
            if parallel_lm_head_quantized:
                return UnquantizedEmbeddingMethod()
            return UnquantizedLinearMethod()

        if prefix:
            # Dynamic per module/layer rules may override base config
            override_config(cloned_config, prefix=prefix)

        return linear_method_cls(cloned_config)
    return None

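# Editorial example: for a ParallelLMHead layer with lm_head_quantized=True
# and no negative rule matching its prefix, the function above returns
# linear_method_cls(cloned_config), i.e. the lm_head weights go through the
# same quantized path as regular linear layers; a negative "-:" rule on the
# lm_head prefix would instead fall back to UnquantizedEmbeddingMethod()
# (or UnquantizedLinearMethod() for plain LinearBase layers).
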
def gptq_get_quant_method(self, layer, prefix):
    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
        GPTQMarlinMoEMethod,
    )

    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

    if isinstance(layer, FusedMoE):
        return GPTQMarlinMoEMethod(self)

    if isinstance(self, GPTQConfig):
        return get_linear_quant_method(
            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
        )
    elif isinstance(self, GPTQMarlinConfig):
        return get_linear_quant_method(
            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
        )
    return None

@@ -155,6 +256,7 @@ def apply_monkey_patches():
    from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod

    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
    setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
    setattr(AWQMoEMethod, "apply", awq_moe_method_apply)

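# Editorial note: the two GPTQ setattr calls route both GPTQConfig and
# GPTQMarlinConfig through the shared gptq_get_quant_method above, which
# selects the MoE, GPTQ, or Marlin linear method based on the config's
# concrete type, so the dynamic "+:"/"-:" rules and lm_head quantization
# apply uniformly to both kernels.
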
python/sglang/srt/layers/quantization/gptq.py (new file, 416 lines)
@@ -0,0 +1,416 @@
import logging
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union

import torch
from vllm.scalar_type import scalar_types

from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

logger = logging.getLogger(__name__)

class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
        lm_head_quantized: bool,
        dynamic: Dict[str, Dict[str, Union[int, bool]]],
    ) -> None:
        # GPTQModel uses the `dynamic` config property to allow per-module
        # quantization config, so each module can be individually optimized.
        # Format is Dict[str, Dict] where the key is a regex string that can
        # perform both positive ("+:" prefixed) and negative ("-:" prefixed)
        # matching of a module. If no prefix is used, the rule defaults to
        # positive matching and overrides the base quant config. The value is
        # a dict of field keys and override values.
        # Negative matching will skip quantization init for this module
        # entirely: non-quantized inference. More details and quantization
        # examples can be found at: https://github.com/ModelCloud/GPTQModel
        # Example:
        # # layers 10-21 use 8-bit (vs 4-bit for layers 0-9)
        # # layers 16-21 additionally use group_size 64
        # dynamic = {
        #     # `.*\.` matches the layers_node prefix
        #     # positive match layers 10-15
        #     r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
        #     # positive match layers 16-21
        #     r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
        #     r"-:.*\.moe\..*": {},  # negative match (skip) all `moe` layers
        # }
        super().__init__()
        self.dynamic = dynamic

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.lm_head_quantized = lm_head_quantized
        self.pack_factor = Fraction(32, self.weight_bits)
        if self.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {self.weight_bits} bits."
            )

    def __repr__(self) -> str:
        return (
            f"GPTQConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, "
            f"desc_act={self.desc_act}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"dynamic={self.dynamic})"
        )

    def get_scaled_act_names(self) -> List[str]:
        """Returns the activation function names that should be post-scaled.

        For now, this is only used by AWQ.
        """
        raise NotImplementedError

    @classmethod
    def get_name(cls) -> str:
        return "gptq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Need to figure it out
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
        dynamic = {} if dynamic is None else dynamic

        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        return cls(weight_bits, group_size, desc_act, lm_head_quantized, dynamic)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["GPTQLinearMethod"]:
        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod

        from sglang.srt.layers.quantization import get_linear_quant_method

        return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)

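# --- Editorial sketch (not part of the original patch) ----------------------
# A minimal quantize_config.json payload, shown as the dict that
# GPTQConfig.from_config() consumes; the dynamic rules and module names below
# are hypothetical.
_EXAMPLE_QUANTIZE_CONFIG = {
    "bits": 4,
    "group_size": 128,
    "desc_act": False,
    "lm_head": True,  # parsed into lm_head_quantized=True
    "dynamic": {
        # 8-bit with group_size 64 for layers 16-21
        r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64},
        r"-:.*\.gate\..*": {},  # skip quantized init for router gates
    },
}
# GPTQConfig.from_config(_EXAMPLE_QUANTIZE_CONFIG) yields weight_bits=4,
# pack_factor=Fraction(32, 4)=8, lm_head_quantized=True, with the dynamic
# rules attached for per-module overrides.
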
class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    # (num_bits, is_sym) -> quant_type
    TYPE_MAP = {
        (4, True): scalar_types.uint4b8,
        (8, True): scalar_types.uint8b128,
    }

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
        is_sym: bool,
        lm_head_quantized: bool,
        dynamic: Dict[str, Dict[str, Union[int, bool]]],
        full_config: Dict[str, Any],
    ) -> None:
        super().__init__()
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        # GPTQModel uses the `dynamic` config property to allow per-module
        # quantization config, so each module can be individually optimized.
        # Format is Dict[str, Dict] where the key is a regex string that can
        # perform both positive ("+:" prefixed) and negative ("-:" prefixed)
        # matching of a module. If no prefix is used, the rule defaults to
        # positive matching and overrides the base quant config. The value is
        # a dict of field keys and override values.
        # Negative matching will skip quantization init for this module
        # entirely: non-quantized inference. More details and quantization
        # examples can be found at: https://github.com/ModelCloud/GPTQModel
        # Example:
        # # layers 10-21 use 8-bit (vs 4-bit for layers 0-9)
        # # layers 16-21 additionally use group_size 64
        # dynamic = {
        #     # `.*\.` matches the layers_node prefix
        #     # positive match layers 10-15
        #     r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
        #     # positive match layers 16-21
        #     r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
        #     r"-:.*\.moe\..*": {},  # negative match (skip) all `moe` layers
        # }
        self.dynamic = dynamic

        self.weight_bits = weight_bits
        self.is_sym = is_sym

        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.desc_act = desc_act
        self.lm_head_quantized = lm_head_quantized
        self.full_config = full_config

        if (weight_bits, is_sym) not in self.TYPE_MAP:
            raise ValueError(
                "Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
            )

        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

    def __repr__(self) -> str:
        return (
            f"GPTQMarlinConfig(quant_type={self.quant_type}, "
            f"group_size={self.group_size}, "
            f"desc_act={self.desc_act}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"dynamic={self.dynamic})"
        )

    def get_scaled_act_names(self) -> List[str]:
        """Returns the activation function names that should be post-scaled.

        For now, this is only used by AWQ.
        """
        raise NotImplementedError

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
        dynamic = {} if dynamic is None else dynamic

        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        return cls(
            weight_bits,
            group_size,
            desc_act,
            is_sym,
            lm_head_quantized,
            dynamic,
            config,
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

        is_valid_user_quant = (
            user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
        )

        if can_convert and is_valid_user_quant:
            msg = (
                "The model is convertible to {} during runtime."
                " Using {} kernel.".format(cls.get_name(), cls.get_name())
            )
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info(
                "Detected that the model can run with gptq_marlin"
                ", however you specified quantization=gptq explicitly,"
                " so forcing gptq. Use quantization=gptq_marlin for"
                " faster inference"
            )
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.model_executor.layers.quantization.gptq_marlin import (
            GPTQMarlinLinearMethod,
            GPTQMarlinMoEMethod,
        )

        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.quantization import get_linear_quant_method

        if isinstance(layer, FusedMoE):
            return GPTQMarlinMoEMethod(self)
        # TODO: re-enable after SGLang syncs with vllm >= 0.7.3
        # if layer.num_experts > 32:
        #     # For MoEs with many experts the moe_wna16 kernel is faster
        #     return MoeWNA16Config.from_config(self.full_config).get_quant_method(
        #         layer, prefix
        #     )
        # else:
        #     return GPTQMarlinMoEMethod(self)
        return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)

    @classmethod
    def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        sym = quant_config.get("sym")
        desc_act = quant_config.get("desc_act")

        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
            check_marlin_supported,
        )
        from vllm.platforms import current_platform

        if not current_platform.is_cuda():
            return False

        if quant_method != "gptq":
            return False

        # Marlin conversion is only valid if required properties are found
        if num_bits is None or group_size is None or sym is None or desc_act is None:
            return False

        if (num_bits, sym) not in cls.TYPE_MAP:
            return False

        return check_marlin_supported(
            quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
        )

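# Editorial example (hypothetical values): an HF quant config such as
#   {"quant_method": "gptq", "bits": 4, "group_size": 128,
#    "sym": True, "desc_act": False}
# passes is_gptq_marlin_compatible() on CUDA (when check_marlin_supported
# accepts uint4b8 at group_size 128), so override_quantization_method()
# upgrades quantization=None/"marlin"/"gptq_marlin" to the "gptq_marlin"
# kernel, while an explicit quantization="gptq" logs a hint and keeps the
# plain gptq kernel.
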
class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    """

    def __init__(
        self,
        group_size: int,
        lm_head_quantized: bool,
    ) -> None:
        # Group size for the quantization.
        self.group_size = group_size
        self.lm_head_quantized = lm_head_quantized
        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) "
                "is supported for Marlin, but got group_size of "
                f"{self.group_size}"
            )

        # 4 bits packed into a 32-bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = 64

        # Min in_features dim
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return (
            f"MarlinConfig(group_size={self.group_size}, "
            f"lm_head_quantized={self.lm_head_quantized})"
        )

    @classmethod
    def get_name(cls) -> str:
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Need to figure it out
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        group_size = cls.get_from_keys(config, ["group_size"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        return cls(group_size, lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        # compat: autogptq >= 0.8.0 uses checkpoint_format: str
        # compat: autogptq <= 0.7.1 uses is_marlin_format: bool
        is_marlin_format = hf_quant_cfg.get(
            "checkpoint_format"
        ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)

        is_valid_user_quant = (
            user_quant is None or user_quant == "gptq" or user_quant == "marlin"
        )

        if is_marlin_format and is_valid_user_quant:
            msg = "The model is serialized in {} format. Using {} kernel.".format(
                cls.get_name(), cls.get_name()
            )
            logger.info(msg)
            return cls.get_name()

        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["MarlinLinearMethod"]:
        from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod

        if isinstance(layer, LinearBase) or (
            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
        ):
            return MarlinLinearMethod(self)
        return None
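# Editorial example: a checkpoint serialized by autogptq >= 0.8.0 carries
# {"checkpoint_format": "marlin"} while <= 0.7.1 uses
# {"is_marlin_format": true}; either form makes override_quantization_method()
# above select the "marlin" kernel when the user passed quantization=None,
# "gptq", or "marlin".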