Add support for new autogptq quant_config.checkpoint_format (#332)
This commit is contained in:
@@ -19,7 +19,11 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
||||
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
||||
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
||||
|
||||
QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
|
||||
QUANTIZATION_CONFIG_MAPPING = {
|
||||
"awq": AWQConfig,
|
||||
"gptq": GPTQConfig,
|
||||
"marlin": MarlinConfig,
|
||||
}
|
||||
|
||||
logger = logging.getLogger("model_runner")
|
||||
|
||||
@@ -300,30 +304,31 @@ class ModelRunner:
|
||||
|
||||
# Load weights
|
||||
linear_method = None
|
||||
|
||||
quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
|
||||
if quant_cfg is not None:
|
||||
quant_method = quant_cfg.get("quant_method", "").lower()
|
||||
# compat: autogptq >=0.8.0 uses checkpoint_format: str
|
||||
# compat: autogptq <=0.7.1 uses is_marlin_format: bool
|
||||
is_format_marlin = quant_cfg.get(
|
||||
"checkpoint_format"
|
||||
) == "marlin" or quant_cfg.get("is_marlin_format", False)
|
||||
|
||||
# Use marlin if the GPTQ model is serialized in marlin format.
|
||||
if quant_method == "gptq" and is_format_marlin:
|
||||
quant_method = "marlin"
|
||||
|
||||
quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
|
||||
|
||||
if quant_config_class is None:
|
||||
raise ValueError(f"Unsupported quantization method: {quant_method}")
|
||||
|
||||
quant_config = quant_config_class.from_config(quant_cfg)
|
||||
logger.info(f"quant_config: {quant_config}")
|
||||
linear_method = quant_config.get_linear_method()
|
||||
|
||||
with _set_default_torch_dtype(torch.float16):
|
||||
with torch.device("cuda"):
|
||||
hf_quant_config = getattr(
|
||||
self.model_config.hf_config, "quantization_config", None
|
||||
)
|
||||
if hf_quant_config is not None:
|
||||
hf_quant_method = hf_quant_config["quant_method"]
|
||||
|
||||
# compat: autogptq uses is_marlin_format within quant config
|
||||
if (
|
||||
hf_quant_method == "gptq"
|
||||
and "is_marlin_format" in hf_quant_config
|
||||
and hf_quant_config["is_marlin_format"]
|
||||
):
|
||||
hf_quant_method = "marlin"
|
||||
quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
|
||||
|
||||
if quant_config_class is None:
|
||||
raise ValueError(
|
||||
f"Unsupported quantization method: {hf_quant_config['quant_method']}"
|
||||
)
|
||||
quant_config = quant_config_class.from_config(hf_quant_config)
|
||||
logger.info(f"quant_config: {quant_config}")
|
||||
linear_method = quant_config.get_linear_method()
|
||||
model = model_class(
|
||||
config=self.model_config.hf_config, linear_method=linear_method
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user