diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index d72526a61..b5e8d5a35 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -140,11 +140,21 @@ class ModelOptFp8Config(QuantizationConfig): # Flat format (config.json quantization_config) # For kv_cache, check if kv_cache_scheme exists and extract algo kv_cache_scheme = config.get("kv_cache_scheme") - if ( - kv_cache_scheme - and kv_cache_scheme.get("type") == "float" - and kv_cache_scheme.get("num_bits") == 8 - ): + + kv_cache_type = None + kv_cache_bits = None + if isinstance(kv_cache_scheme, dict): + # Handles the expected format: {"type": "float", "num_bits": 8} + kv_cache_type = kv_cache_scheme.get("type") + kv_cache_bits = kv_cache_scheme.get("num_bits") + elif isinstance(kv_cache_scheme, str): + # Handles the shorthand format: "FP8" + if kv_cache_scheme.upper() == "FP8": + kv_cache_type = "float" + kv_cache_bits = 8 + + # Now, safely use the extracted values + if kv_cache_type == "float" and kv_cache_bits == 8: kv_cache_quant_method = "FP8" # Map 'ignore' field to 'exclude_modules' @@ -594,11 +604,22 @@ class ModelOptFp4Config(QuantizationConfig): if not kv_cache_quant_algo: # For config.json format, derive from kv_cache_scheme if available kv_cache_scheme = config.get("kv_cache_scheme") - if ( - kv_cache_scheme - and kv_cache_scheme.get("type") == "float" - and kv_cache_scheme.get("num_bits") == 8 - ): + + kv_cache_type = None + kv_cache_bits = None + if isinstance(kv_cache_scheme, dict): + # Handles the expected format: {"type": "float", "num_bits": 8} + kv_cache_type = kv_cache_scheme.get("type") + kv_cache_bits = kv_cache_scheme.get("num_bits") + elif isinstance(kv_cache_scheme, str): + # Handles the shorthand format: "FP8" + # We can infer the properties from the string. + if kv_cache_scheme.upper() == "FP8": + kv_cache_type = "float" + kv_cache_bits = 8 + + # Now, safely use the extracted values in the original logic + if kv_cache_type == "float" and kv_cache_bits == 8: kv_cache_quant_algo = "FP8" else: kv_cache_quant_algo = "auto"