Modelopt quant config adaptation (#8829)
This commit is contained in:
@@ -111,18 +111,52 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
|
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
|
||||||
quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
|
# Handle two different config formats:
|
||||||
kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
|
# 1. hf_quant_config.json format: {"quantization": {"quant_algo": "FP8", ...}}
|
||||||
"kv_cache_quant_algo"
|
# 2. config.json quantization_config format: {"quant_algo": "FP8", ...}
|
||||||
)
|
# In future modelopt will deprecate hf_quant_config.json, and only keep config.json.
|
||||||
exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
|
# For legacy reasons, we keep hf_quant_config.json for now.
|
||||||
"exclude_modules"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Initialize variables
|
||||||
|
kv_cache_quant_method = None
|
||||||
|
exclude_modules = None
|
||||||
|
|
||||||
|
# Try flat format first (config.json quantization_config - preferred format)
|
||||||
|
quant_method = config.get("quant_algo")
|
||||||
|
if quant_method is not None:
|
||||||
|
# Flat format (config.json quantization_config)
|
||||||
|
# For kv_cache, check if kv_cache_scheme exists and extract algo
|
||||||
|
kv_cache_scheme = config.get("kv_cache_scheme")
|
||||||
|
if (
|
||||||
|
kv_cache_scheme
|
||||||
|
and kv_cache_scheme.get("type") == "float"
|
||||||
|
and kv_cache_scheme.get("num_bits") == 8
|
||||||
|
):
|
||||||
|
kv_cache_quant_method = "FP8"
|
||||||
|
|
||||||
|
# Map 'ignore' field to 'exclude_modules'
|
||||||
|
exclude_modules = config.get("ignore")
|
||||||
|
else:
|
||||||
|
# Fall back to nested format (hf_quant_config.json - legacy format)
|
||||||
|
try:
|
||||||
|
quantization_section = cls.get_from_keys(config, ["quantization"])
|
||||||
|
quant_method = quantization_section.get("quant_algo")
|
||||||
|
kv_cache_quant_method = quantization_section.get("kv_cache_quant_algo")
|
||||||
|
exclude_modules = quantization_section.get("exclude_modules")
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot find 'quant_algo' in the model's quantization config. "
|
||||||
|
"Expected either flat format (config.json) or nested format (hf_quant_config.json)."
|
||||||
|
)
|
||||||
|
if quant_method is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot find 'quant_algo' in the model's quantization config. "
|
||||||
|
)
|
||||||
if "FP8" not in quant_method:
|
if "FP8" not in quant_method:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"ModelOpt only supports static FP8 quantization in SGLang. "
|
"ModelOptFp8Config only supports static FP8 quantization in SGLang. "
|
||||||
"Check the `hf_quant_config.json` file for your model's configuration."
|
"For FP4 quantization, use ModelOptFp4Config. "
|
||||||
|
"Check the quantization config for your model's configuration."
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
@@ -485,22 +519,63 @@ class ModelOptFp4Config(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
|
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
|
||||||
quant_config = cls.get_from_keys(config, ["quantization"])
|
# Handle two different config formats:
|
||||||
quant_method = quant_config["quant_algo"]
|
# 1. hf_quant_config.json format: {"quantization": {"quant_algo": "NVFP4", ...}}
|
||||||
|
# 2. config.json quantization_config format: {"quant_algo": "NVFP4", ...}
|
||||||
|
# In future modelopt will deprecate hf_quant_config.json, and only keep config.json.
|
||||||
|
# For legacy reasons, we keep hf_quant_config.json for now.
|
||||||
|
|
||||||
|
# Initialize variables
|
||||||
|
kv_cache_quant_algo = None
|
||||||
|
group_size = None
|
||||||
|
exclude_modules = []
|
||||||
|
|
||||||
|
# Try flat format first (config.json quantization_config - preferred format)
|
||||||
|
quant_method = config.get("quant_algo")
|
||||||
|
if quant_method is not None:
|
||||||
|
# Flat format (config.json quantization_config)
|
||||||
|
# Note: FP4 models in config.json format may not have all the detailed fields
|
||||||
|
# that are present in hf_quant_config.json, so we need to handle defaults
|
||||||
|
kv_cache_quant_algo = config.get("kv_cache_quant_algo")
|
||||||
|
if not kv_cache_quant_algo:
|
||||||
|
# For config.json format, derive from kv_cache_scheme if available
|
||||||
|
kv_cache_scheme = config.get("kv_cache_scheme")
|
||||||
|
if (
|
||||||
|
kv_cache_scheme
|
||||||
|
and kv_cache_scheme.get("type") == "float"
|
||||||
|
and kv_cache_scheme.get("num_bits") == 8
|
||||||
|
):
|
||||||
|
kv_cache_quant_algo = "FP8"
|
||||||
|
else:
|
||||||
|
kv_cache_quant_algo = "auto"
|
||||||
|
|
||||||
|
group_size = config.get("group_size")
|
||||||
|
exclude_modules = config.get("ignore", [])
|
||||||
|
else:
|
||||||
|
# Fall back to nested format (hf_quant_config.json - legacy format)
|
||||||
|
try:
|
||||||
|
quant_config = cls.get_from_keys(config, ["quantization"])
|
||||||
|
quant_method = quant_config["quant_algo"]
|
||||||
|
kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
|
||||||
|
if not kv_cache_quant_algo:
|
||||||
|
kv_cache_quant_algo = "auto"
|
||||||
|
group_size = quant_config.get("group_size")
|
||||||
|
exclude_modules = quant_config.get("exclude_modules", [])
|
||||||
|
except (ValueError, KeyError):
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot find 'quant_algo' in the model's quantization config. "
|
||||||
|
"Expected either flat format (config.json) or nested format (hf_quant_config.json)."
|
||||||
|
)
|
||||||
|
|
||||||
if not quant_method in ["FP8", "NVFP4"]:
|
if not quant_method in ["FP8", "NVFP4"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"ModelOpt currently only supports: FP8, NVFP4"
|
f"ModelOpt currently only supports: FP8, NVFP4"
|
||||||
" quantizations in sglang. Please check the "
|
" quantizations in sglang. Please check the "
|
||||||
"`hf_quant_config.json` file for your model's "
|
"quantization config for your model's configuration."
|
||||||
"quant configuration."
|
|
||||||
)
|
)
|
||||||
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
|
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
|
||||||
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
|
|
||||||
if not kv_cache_quant_algo:
|
if not (group_size and kv_cache_quant_algo) or exclude_modules is None:
|
||||||
kv_cache_quant_algo = "auto"
|
|
||||||
group_size = quant_config["group_size"]
|
|
||||||
exclude_modules = quant_config["exclude_modules"]
|
|
||||||
if not (group_size and kv_cache_quant_algo and exclude_modules):
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"group_size: {group_size},"
|
f"group_size: {group_size},"
|
||||||
f"kv_cache_quant_algo: {kv_cache_quant_algo},"
|
f"kv_cache_quant_algo: {kv_cache_quant_algo},"
|
||||||
@@ -508,8 +583,7 @@ class ModelOptFp4Config(QuantizationConfig):
|
|||||||
)
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"NVFP4 quantization requires group size and "
|
"NVFP4 quantization requires group size and "
|
||||||
"kv_cache_quant_algo specified in "
|
"kv_cache_quant_algo specified in the quantization config"
|
||||||
"hf_quant_config.json"
|
|
||||||
)
|
)
|
||||||
return cls(
|
return cls(
|
||||||
is_checkpoint_nvfp4_serialized,
|
is_checkpoint_nvfp4_serialized,
|
||||||
|
|||||||
Reference in New Issue
Block a user