Add support for new autogptq quant_config.checkpoint_format (#332)
@@ -19,7 +19,11 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
-QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
+QUANTIZATION_CONFIG_MAPPING = {
+    "awq": AWQConfig,
+    "gptq": GPTQConfig,
+    "marlin": MarlinConfig,
+}
 
 logger = logging.getLogger("model_runner")
 
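For context: the renamed QUANTIZATION_CONFIG_MAPPING is a plain string-to-class dispatch table keyed by the lowercased quant_method. A minimal standalone sketch of the pattern (the stub config class and the sample payload below are illustrative only, not vLLM's real implementations):

class GPTQConfig:
    @classmethod
    def from_config(cls, cfg):
        # Real implementations read fields such as bits / group_size from cfg.
        return cls()

QUANTIZATION_CONFIG_MAPPING = {"gptq": GPTQConfig}

quant_cfg = {"quant_method": "GPTQ"}  # hypothetical quantization_config payload
quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_cfg["quant_method"].lower())
if quant_config_class is None:
    raise ValueError(f"Unsupported quantization method: {quant_cfg['quant_method']}")
quant_config = quant_config_class.from_config(quant_cfg)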
@@ -300,30 +304,31 @@ class ModelRunner:
 
         # Load weights
         linear_method = None
+
+        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = quant_cfg.get(
+                "checkpoint_format"
+            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
+
+            # Use marlin if the GPTQ model is serialized in marlin format.
+            if quant_method == "gptq" and is_format_marlin:
+                quant_method = "marlin"
+
+            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
+
+            if quant_config_class is None:
+                raise ValueError(f"Unsupported quantization method: {quant_method}")
+
+            quant_config = quant_config_class.from_config(quant_cfg)
+            logger.info(f"quant_config: {quant_config}")
+            linear_method = quant_config.get_linear_method()
+
         with _set_default_torch_dtype(torch.float16):
             with torch.device("cuda"):
-                hf_quant_config = getattr(
-                    self.model_config.hf_config, "quantization_config", None
-                )
-                if hf_quant_config is not None:
-                    hf_quant_method = hf_quant_config["quant_method"]
-
-                    # compat: autogptq uses is_marlin_format within quant config
-                    if (
-                        hf_quant_method == "gptq"
-                        and "is_marlin_format" in hf_quant_config
-                        and hf_quant_config["is_marlin_format"]
-                    ):
-                        hf_quant_method = "marlin"
-                    quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
-
-                    if quant_config_class is None:
-                        raise ValueError(
-                            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
-                        )
-                    quant_config = quant_config_class.from_config(hf_quant_config)
-                    logger.info(f"quant_config: {quant_config}")
-                    linear_method = quant_config.get_linear_method()
                 model = model_class(
                     config=self.model_config.hf_config, linear_method=linear_method
                 )
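The key behavioral change is the compat shim: autogptq <=0.7.1 marked marlin-serialized checkpoints with a boolean is_marlin_format, while autogptq >=0.8.0 writes a checkpoint_format string instead. A self-contained sketch of the check this diff introduces (the two dict payloads mimic quantization_config entries from config.json and are illustrative, not taken from real checkpoints):

def is_format_marlin(quant_cfg: dict) -> bool:
    # True for both the new string field and the legacy boolean flag.
    return (quant_cfg.get("checkpoint_format") == "marlin"
            or quant_cfg.get("is_marlin_format", False))

old_style = {"quant_method": "gptq", "is_marlin_format": True}       # autogptq <=0.7.1
new_style = {"quant_method": "gptq", "checkpoint_format": "marlin"}  # autogptq >=0.8.0

assert is_format_marlin(old_style)
assert is_format_marlin(new_style)
assert not is_format_marlin({"quant_method": "gptq"})  # plain GPTQ checkpoint

In either case quant_method is then rewritten to "marlin", so the lookup in QUANTIZATION_CONFIG_MAPPING selects MarlinConfig. A side effect of the move out of the torch.device("cuda") context is that the quantization config is now resolved before model instantiation rather than inside it.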