diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index ccc6ebffb..db0bf3ab7 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -111,18 +111,52 @@ class ModelOptFp8Config(QuantizationConfig): @classmethod def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config: - quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo") - kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get( - "kv_cache_quant_algo" - ) - exclude_modules = cls.get_from_keys(config, ["quantization"]).get( - "exclude_modules" - ) + # Handle two different config formats: + # 1. hf_quant_config.json format: {"quantization": {"quant_algo": "FP8", ...}} + # 2. config.json quantization_config format: {"quant_algo": "FP8", ...} + # In future modelopt will deprecate hf_quant_config.json, and only keep config.json. + # For legacy reasons, we keep hf_quant_config.json for now. 
+ # Initialize variables + kv_cache_quant_method = None + exclude_modules = None + + # Try flat format first (config.json quantization_config - preferred format) + quant_method = config.get("quant_algo") + if quant_method is not None: + # Flat format (config.json quantization_config) + # For kv_cache, check if kv_cache_scheme exists and extract algo + kv_cache_scheme = config.get("kv_cache_scheme") + if ( + kv_cache_scheme + and kv_cache_scheme.get("type") == "float" + and kv_cache_scheme.get("num_bits") == 8 + ): + kv_cache_quant_method = "FP8" + + # Map 'ignore' field to 'exclude_modules' + exclude_modules = config.get("ignore") + else: + # Fall back to nested format (hf_quant_config.json - legacy format) + try: + quantization_section = cls.get_from_keys(config, ["quantization"]) + quant_method = quantization_section.get("quant_algo") + kv_cache_quant_method = quantization_section.get("kv_cache_quant_algo") + exclude_modules = quantization_section.get("exclude_modules") + except ValueError: + raise ValueError( + "Cannot find 'quant_algo' in the model's quantization config. " + "Expected either flat format (config.json) or nested format (hf_quant_config.json)." + ) + if quant_method is None: + raise ValueError( + "Cannot find 'quant_algo' in the model's quantization config. " + ) if "FP8" not in quant_method: raise ValueError( - "ModelOpt only supports static FP8 quantization in SGLang. " - "Check the `hf_quant_config.json` file for your model's configuration." + "ModelOptFp8Config only supports static FP8 quantization in SGLang. " + "For FP4 quantization, use ModelOptFp4Config. " + "Check the quantization config for your model's configuration." ) return cls( @@ -485,22 +519,63 @@ class ModelOptFp4Config(QuantizationConfig): @classmethod def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config: - quant_config = cls.get_from_keys(config, ["quantization"]) - quant_method = quant_config["quant_algo"] + # Handle two different config formats: + # 1. 
hf_quant_config.json format: {"quantization": {"quant_algo": "NVFP4", ...}} + # 2. config.json quantization_config format: {"quant_algo": "NVFP4", ...} + # In future modelopt will deprecate hf_quant_config.json, and only keep config.json. + # For legacy reasons, we keep hf_quant_config.json for now. + + # Initialize variables + kv_cache_quant_algo = None + group_size = None + exclude_modules = [] + + # Try flat format first (config.json quantization_config - preferred format) + quant_method = config.get("quant_algo") + if quant_method is not None: + # Flat format (config.json quantization_config) + # Note: FP4 models in config.json format may not have all the detailed fields + # that are present in hf_quant_config.json, so we need to handle defaults + kv_cache_quant_algo = config.get("kv_cache_quant_algo") + if not kv_cache_quant_algo: + # For config.json format, derive from kv_cache_scheme if available + kv_cache_scheme = config.get("kv_cache_scheme") + if ( + kv_cache_scheme + and kv_cache_scheme.get("type") == "float" + and kv_cache_scheme.get("num_bits") == 8 + ): + kv_cache_quant_algo = "FP8" + else: + kv_cache_quant_algo = "auto" + + group_size = config.get("group_size") + exclude_modules = config.get("ignore", []) + else: + # Fall back to nested format (hf_quant_config.json - legacy format) + try: + quant_config = cls.get_from_keys(config, ["quantization"]) + quant_method = quant_config["quant_algo"] + kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo") + if not kv_cache_quant_algo: + kv_cache_quant_algo = "auto" + group_size = quant_config.get("group_size") + exclude_modules = quant_config.get("exclude_modules", []) + except (ValueError, KeyError): + raise ValueError( + "Cannot find 'quant_algo' in the model's quantization config. " + "Expected either flat format (config.json) or nested format (hf_quant_config.json)." 
+ ) + if quant_method not in ["FP8", "NVFP4"]: raise ValueError( f"ModelOpt currently only supports: FP8, NVFP4" " quantizations in sglang. Please check the " - "`hf_quant_config.json` file for your model's " - "quant configuration." + "quantization config for your model's configuration." ) is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method - kv_cache_quant_algo = quant_config["kv_cache_quant_algo"] - if not kv_cache_quant_algo: - kv_cache_quant_algo = "auto" - group_size = quant_config["group_size"] - exclude_modules = quant_config["exclude_modules"] - if not (group_size and kv_cache_quant_algo and exclude_modules): + + if not (group_size and kv_cache_quant_algo) or exclude_modules is None: logger.warning( f"group_size: {group_size}," f"kv_cache_quant_algo: {kv_cache_quant_algo}," @@ -508,8 +583,7 @@ class ModelOptFp4Config(QuantizationConfig): ) raise ValueError( "NVFP4 quantization requires group size and " - "kv_cache_quant_algo specified in " - "hf_quant_config.json" + "kv_cache_quant_algo specified in the quantization config" ) return cls( is_checkpoint_nvfp4_serialized,