check marlin format before attempting conversion (#4675)
This commit is contained in:
committed by
GitHub
parent
9f3bd2ad39
commit
4c7640079c
@@ -37,6 +37,14 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
    """Report whether a HF quantization config marks the checkpoint as Marlin.

    Two config spellings are recognized:

    * ``checkpoint_format == "marlin"`` (string field), or
    * a truthy ``is_marlin_format`` flag.

    Args:
        hf_quant_cfg: The quantization config dictionary to inspect.

    Returns:
        ``True`` if either spelling indicates Marlin format, else ``False``.
        The result is always a real ``bool``, even when ``is_marlin_format``
        is present with a null or non-boolean value.
    """
    # compat: gptqmodel and autogptq (eol) main use checkpoint_format: str
    if hf_quant_cfg.get("checkpoint_format") == "marlin":
        return True
    # compat: autogptq <=0.7.1 is_marlin_format: bool
    # Coerce explicitly: a bare `or` would leak None / non-bool values from
    # the config through a function annotated as returning bool.
    return bool(hf_quant_cfg.get("is_marlin_format", False))
|
||||
|
||||
|
||||
class GPTQConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ.
|
||||
|
||||
@@ -262,13 +270,15 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||
is_marlin_format = check_marlin_format(hf_quant_cfg)
|
||||
|
||||
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (
|
||||
user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
|
||||
)
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
if not is_marlin_format and can_convert and is_valid_user_quant:
|
||||
msg = (
|
||||
"The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name())
|
||||
@@ -276,7 +286,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "gptq":
|
||||
if not is_marlin_format and can_convert and user_quant == "gptq":
|
||||
logger.info(
|
||||
"Detected that the model can run with gptq_marlin"
|
||||
", however you specified quantization=gptq explicitly,"
|
||||
@@ -401,11 +411,7 @@ class MarlinConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
||||
# compat: autogptq <=0.7.1 is_marlin_format: bool
|
||||
is_marlin_format = hf_quant_cfg.get(
|
||||
"checkpoint_format"
|
||||
) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
|
||||
is_marlin_format = check_marlin_format(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (
|
||||
user_quant is None or user_quant == "gptq" or user_quant == "marlin"
|
||||
|
||||
Reference in New Issue
Block a user