Qwen FP8/NVFP4 ModelOPT Quantization support (#7912)

Co-authored-by: Jingyu Xin <jingyux@nvidia.com>
This commit is contained in:
jingyu-ml
2025-09-02 22:56:03 -05:00
committed by GitHub
parent cc9a31c662
commit bcbeed714f
2 changed files with 43 additions and 4 deletions

View File

@@ -517,6 +517,39 @@ class ModelOptFp4Config(QuantizationConfig):
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]
@staticmethod
def common_group_size(cfg: dict) -> int:
"""Return the unique group_size across the config; raise if missing/mismatched."""
sizes = set()
# Top-level and 'quantization' block
v = cfg.get("group_size")
if isinstance(v, int):
sizes.add(v)
q = cfg.get("quantization")
if isinstance(q, dict):
v = q.get("group_size")
if isinstance(v, int):
sizes.add(v)
# config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
for g in (cfg.get("config_groups") or {}).values():
if isinstance(g, dict):
v = g.get("group_size")
if isinstance(v, int):
sizes.add(v)
for sub in g.values():
if isinstance(sub, dict):
v = sub.get("group_size")
if isinstance(v, int):
sizes.add(v)
if not sizes:
raise ValueError("No group_size found in config.")
if len(sizes) > 1:
raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
return next(iter(sizes))
@classmethod
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
# Handle two different config formats:
@@ -549,7 +582,7 @@ class ModelOptFp4Config(QuantizationConfig):
else:
kv_cache_quant_algo = "auto"
group_size = config.get("group_size")
group_size = ModelOptFp4Config.common_group_size(config)
exclude_modules = config.get("ignore", [])
else:
# Fall back to nested format (hf_quant_config.json - legacy format)
@@ -559,7 +592,7 @@ class ModelOptFp4Config(QuantizationConfig):
kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
if not kv_cache_quant_algo:
kv_cache_quant_algo = "auto"
group_size = quant_config.get("group_size")
group_size = ModelOptFp4Config.common_group_size(config)
exclude_modules = quant_config.get("exclude_modules", [])
except (ValueError, KeyError):
raise ValueError(