Modelopt quant config adaptation (#8829)
This commit is contained in:
@@ -111,18 +111,52 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
|
||||
quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
|
||||
kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
|
||||
"kv_cache_quant_algo"
|
||||
)
|
||||
exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
|
||||
"exclude_modules"
|
||||
)
|
||||
# Handle two different config formats:
|
||||
# 1. hf_quant_config.json format: {"quantization": {"quant_algo": "FP8", ...}}
|
||||
# 2. config.json quantization_config format: {"quant_algo": "FP8", ...}
|
||||
# In future modelopt will deprecate hf_quant_config.json, and only keep config.json.
|
||||
# For legacy reasons, we keep hf_quant_config.json for now.
|
||||
|
||||
# Initialize variables
|
||||
kv_cache_quant_method = None
|
||||
exclude_modules = None
|
||||
|
||||
# Try flat format first (config.json quantization_config - preferred format)
|
||||
quant_method = config.get("quant_algo")
|
||||
if quant_method is not None:
|
||||
# Flat format (config.json quantization_config)
|
||||
# For kv_cache, check if kv_cache_scheme exists and extract algo
|
||||
kv_cache_scheme = config.get("kv_cache_scheme")
|
||||
if (
|
||||
kv_cache_scheme
|
||||
and kv_cache_scheme.get("type") == "float"
|
||||
and kv_cache_scheme.get("num_bits") == 8
|
||||
):
|
||||
kv_cache_quant_method = "FP8"
|
||||
|
||||
# Map 'ignore' field to 'exclude_modules'
|
||||
exclude_modules = config.get("ignore")
|
||||
else:
|
||||
# Fall back to nested format (hf_quant_config.json - legacy format)
|
||||
try:
|
||||
quantization_section = cls.get_from_keys(config, ["quantization"])
|
||||
quant_method = quantization_section.get("quant_algo")
|
||||
kv_cache_quant_method = quantization_section.get("kv_cache_quant_algo")
|
||||
exclude_modules = quantization_section.get("exclude_modules")
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"Cannot find 'quant_algo' in the model's quantization config. "
|
||||
"Expected either flat format (config.json) or nested format (hf_quant_config.json)."
|
||||
)
|
||||
if quant_method is None:
|
||||
raise ValueError(
|
||||
"Cannot find 'quant_algo' in the model's quantization config. "
|
||||
)
|
||||
if "FP8" not in quant_method:
|
||||
raise ValueError(
|
||||
"ModelOpt only supports static FP8 quantization in SGLang. "
|
||||
"Check the `hf_quant_config.json` file for your model's configuration."
|
||||
"ModelOptFp8Config only supports static FP8 quantization in SGLang. "
|
||||
"For FP4 quantization, use ModelOptFp4Config. "
|
||||
"Check the quantization config for your model's configuration."
|
||||
)
|
||||
|
||||
return cls(
|
||||
@@ -485,22 +519,63 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
|
||||
quant_config = cls.get_from_keys(config, ["quantization"])
|
||||
quant_method = quant_config["quant_algo"]
|
||||
# Handle two different config formats:
|
||||
# 1. hf_quant_config.json format: {"quantization": {"quant_algo": "NVFP4", ...}}
|
||||
# 2. config.json quantization_config format: {"quant_algo": "NVFP4", ...}
|
||||
# In future modelopt will deprecate hf_quant_config.json, and only keep config.json.
|
||||
# For legacy reasons, we keep hf_quant_config.json for now.
|
||||
|
||||
# Initialize variables
|
||||
kv_cache_quant_algo = None
|
||||
group_size = None
|
||||
exclude_modules = []
|
||||
|
||||
# Try flat format first (config.json quantization_config - preferred format)
|
||||
quant_method = config.get("quant_algo")
|
||||
if quant_method is not None:
|
||||
# Flat format (config.json quantization_config)
|
||||
# Note: FP4 models in config.json format may not have all the detailed fields
|
||||
# that are present in hf_quant_config.json, so we need to handle defaults
|
||||
kv_cache_quant_algo = config.get("kv_cache_quant_algo")
|
||||
if not kv_cache_quant_algo:
|
||||
# For config.json format, derive from kv_cache_scheme if available
|
||||
kv_cache_scheme = config.get("kv_cache_scheme")
|
||||
if (
|
||||
kv_cache_scheme
|
||||
and kv_cache_scheme.get("type") == "float"
|
||||
and kv_cache_scheme.get("num_bits") == 8
|
||||
):
|
||||
kv_cache_quant_algo = "FP8"
|
||||
else:
|
||||
kv_cache_quant_algo = "auto"
|
||||
|
||||
group_size = config.get("group_size")
|
||||
exclude_modules = config.get("ignore", [])
|
||||
else:
|
||||
# Fall back to nested format (hf_quant_config.json - legacy format)
|
||||
try:
|
||||
quant_config = cls.get_from_keys(config, ["quantization"])
|
||||
quant_method = quant_config["quant_algo"]
|
||||
kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
|
||||
if not kv_cache_quant_algo:
|
||||
kv_cache_quant_algo = "auto"
|
||||
group_size = quant_config.get("group_size")
|
||||
exclude_modules = quant_config.get("exclude_modules", [])
|
||||
except (ValueError, KeyError):
|
||||
raise ValueError(
|
||||
"Cannot find 'quant_algo' in the model's quantization config. "
|
||||
"Expected either flat format (config.json) or nested format (hf_quant_config.json)."
|
||||
)
|
||||
|
||||
if not quant_method in ["FP8", "NVFP4"]:
|
||||
raise ValueError(
|
||||
f"ModelOpt currently only supports: FP8, NVFP4"
|
||||
" quantizations in sglang. Please check the "
|
||||
"`hf_quant_config.json` file for your model's "
|
||||
"quant configuration."
|
||||
"quantization config for your model's configuration."
|
||||
)
|
||||
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
|
||||
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
|
||||
if not kv_cache_quant_algo:
|
||||
kv_cache_quant_algo = "auto"
|
||||
group_size = quant_config["group_size"]
|
||||
exclude_modules = quant_config["exclude_modules"]
|
||||
if not (group_size and kv_cache_quant_algo and exclude_modules):
|
||||
|
||||
if not (group_size and kv_cache_quant_algo) or exclude_modules is None:
|
||||
logger.warning(
|
||||
f"group_size: {group_size},"
|
||||
f"kv_cache_quant_algo: {kv_cache_quant_algo},"
|
||||
@@ -508,8 +583,7 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
)
|
||||
raise ValueError(
|
||||
"NVFP4 quantization requires group size and "
|
||||
"kv_cache_quant_algo specified in "
|
||||
"hf_quant_config.json"
|
||||
"kv_cache_quant_algo specified in the quantization config"
|
||||
)
|
||||
return cls(
|
||||
is_checkpoint_nvfp4_serialized,
|
||||
|
||||
Reference in New Issue
Block a user