check marlin format before attempting conversion (#4675)
This commit is contained in:
committed by
GitHub
parent
9f3bd2ad39
commit
4c7640079c
@@ -37,6 +37,14 @@ except ImportError:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
    """Report whether a quantization config marks the checkpoint as Marlin.

    Two config layouts are recognised:

    * ``checkpoint_format == "marlin"`` (string field), and
    * a truthy ``is_marlin_format`` (legacy boolean field).

    Args:
        hf_quant_cfg: The quantization config dictionary to inspect.

    Returns:
        ``True`` if either layout marks the checkpoint as Marlin,
        ``False`` otherwise. The result is always a real ``bool``.
    """
    # compat: gptqmodel and autogptq (eol) main use checkpoint_format: str
    if hf_quant_cfg.get("checkpoint_format") == "marlin":
        return True
    # compat: autogptq <=0.7.1 is_marlin_format: bool
    # Coerce explicitly: the stored value may be None or another non-bool,
    # and the declared return type is bool.
    return bool(hf_quant_cfg.get("is_marlin_format", False))
|
||||||
|
|
||||||
|
|
||||||
class GPTQConfig(QuantizationConfig):
|
class GPTQConfig(QuantizationConfig):
|
||||||
"""Config class for GPTQ.
|
"""Config class for GPTQ.
|
||||||
|
|
||||||
@@ -262,13 +270,15 @@ class GPTQMarlinConfig(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||||
|
is_marlin_format = check_marlin_format(hf_quant_cfg)
|
||||||
|
|
||||||
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
|
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
|
||||||
|
|
||||||
is_valid_user_quant = (
|
is_valid_user_quant = (
|
||||||
user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
|
user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_convert and is_valid_user_quant:
|
if not is_marlin_format and can_convert and is_valid_user_quant:
|
||||||
msg = (
|
msg = (
|
||||||
"The model is convertible to {} during runtime."
|
"The model is convertible to {} during runtime."
|
||||||
" Using {} kernel.".format(cls.get_name(), cls.get_name())
|
" Using {} kernel.".format(cls.get_name(), cls.get_name())
|
||||||
@@ -276,7 +286,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
|||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
return cls.get_name()
|
return cls.get_name()
|
||||||
|
|
||||||
if can_convert and user_quant == "gptq":
|
if not is_marlin_format and can_convert and user_quant == "gptq":
|
||||||
logger.info(
|
logger.info(
|
||||||
"Detected that the model can run with gptq_marlin"
|
"Detected that the model can run with gptq_marlin"
|
||||||
", however you specified quantization=gptq explicitly,"
|
", however you specified quantization=gptq explicitly,"
|
||||||
@@ -401,11 +411,7 @@ class MarlinConfig(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||||
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
is_marlin_format = check_marlin_format(hf_quant_cfg)
|
||||||
# compat: autogptq <=0.7.1 is_marlin_format: bool
|
|
||||||
is_marlin_format = hf_quant_cfg.get(
|
|
||||||
"checkpoint_format"
|
|
||||||
) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
|
|
||||||
|
|
||||||
is_valid_user_quant = (
|
is_valid_user_quant = (
|
||||||
user_quant is None or user_quant == "gptq" or user_quant == "marlin"
|
user_quant is None or user_quant == "gptq" or user_quant == "marlin"
|
||||||
|
|||||||
Reference in New Issue
Block a user