Refactor kv_cache_scheme handling for quantization (#10132)
This commit is contained in:
Committed by: GitHub
Parent commit: 916784746b
This commit: cd4da1f19b
@@ -140,11 +140,21 @@ class ModelOptFp8Config(QuantizationConfig):
 # Flat format (config.json quantization_config)
 # For kv_cache, check if kv_cache_scheme exists and extract algo
 kv_cache_scheme = config.get("kv_cache_scheme")
-if (
-    kv_cache_scheme
-    and kv_cache_scheme.get("type") == "float"
-    and kv_cache_scheme.get("num_bits") == 8
-):
+
+kv_cache_type = None
+kv_cache_bits = None
+if isinstance(kv_cache_scheme, dict):
+    # Handles the expected format: {"type": "float", "num_bits": 8}
+    kv_cache_type = kv_cache_scheme.get("type")
+    kv_cache_bits = kv_cache_scheme.get("num_bits")
+elif isinstance(kv_cache_scheme, str):
+    # Handles the shorthand format: "FP8"
+    if kv_cache_scheme.upper() == "FP8":
+        kv_cache_type = "float"
+        kv_cache_bits = 8
+
+# Now, safely use the extracted values
+if kv_cache_type == "float" and kv_cache_bits == 8:
     kv_cache_quant_method = "FP8"
 
 # Map 'ignore' field to 'exclude_modules'
@@ -594,11 +604,22 @@ class ModelOptFp4Config(QuantizationConfig):
 if not kv_cache_quant_algo:
     # For config.json format, derive from kv_cache_scheme if available
     kv_cache_scheme = config.get("kv_cache_scheme")
-    if (
-        kv_cache_scheme
-        and kv_cache_scheme.get("type") == "float"
-        and kv_cache_scheme.get("num_bits") == 8
-    ):
+
+    kv_cache_type = None
+    kv_cache_bits = None
+    if isinstance(kv_cache_scheme, dict):
+        # Handles the expected format: {"type": "float", "num_bits": 8}
+        kv_cache_type = kv_cache_scheme.get("type")
+        kv_cache_bits = kv_cache_scheme.get("num_bits")
+    elif isinstance(kv_cache_scheme, str):
+        # Handles the shorthand format: "FP8"
+        # We can infer the properties from the string.
+        if kv_cache_scheme.upper() == "FP8":
+            kv_cache_type = "float"
+            kv_cache_bits = 8
+
+    # Now, safely use the extracted values in the original logic
+    if kv_cache_type == "float" and kv_cache_bits == 8:
         kv_cache_quant_algo = "FP8"
     else:
         kv_cache_quant_algo = "auto"
Reference in New Issue
Block a user