From cd4da1f19b1422ae960ab330a6d0e5d780fc5de5 Mon Sep 17 00:00:00 2001 From: Mohammad Miadh Angkad Date: Fri, 26 Sep 2025 01:32:15 +0800 Subject: [PATCH] Refactor kv_cache_scheme handling for quantization (#10132) --- .../srt/layers/quantization/modelopt_quant.py | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index d72526a61..b5e8d5a35 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -140,11 +140,21 @@ class ModelOptFp8Config(QuantizationConfig): # Flat format (config.json quantization_config) # For kv_cache, check if kv_cache_scheme exists and extract algo kv_cache_scheme = config.get("kv_cache_scheme") - if ( - kv_cache_scheme - and kv_cache_scheme.get("type") == "float" - and kv_cache_scheme.get("num_bits") == 8 - ): + + kv_cache_type = None + kv_cache_bits = None + if isinstance(kv_cache_scheme, dict): + # Handles the expected format: {"type": "float", "num_bits": 8} + kv_cache_type = kv_cache_scheme.get("type") + kv_cache_bits = kv_cache_scheme.get("num_bits") + elif isinstance(kv_cache_scheme, str): + # Handles the shorthand format: "FP8" + if kv_cache_scheme.upper() == "FP8": + kv_cache_type = "float" + kv_cache_bits = 8 + + # Now, safely use the extracted values + if kv_cache_type == "float" and kv_cache_bits == 8: kv_cache_quant_method = "FP8" # Map 'ignore' field to 'exclude_modules' @@ -594,11 +604,22 @@ class ModelOptFp4Config(QuantizationConfig): if not kv_cache_quant_algo: # For config.json format, derive from kv_cache_scheme if available kv_cache_scheme = config.get("kv_cache_scheme") - if ( - kv_cache_scheme - and kv_cache_scheme.get("type") == "float" - and kv_cache_scheme.get("num_bits") == 8 - ): + + kv_cache_type = None + kv_cache_bits = None + if isinstance(kv_cache_scheme, dict): + # Handles the expected format: {"type": "float", "num_bits": 8} + kv_cache_type = kv_cache_scheme.get("type") + kv_cache_bits = kv_cache_scheme.get("num_bits") + elif isinstance(kv_cache_scheme, str): + # Handles the shorthand format: "FP8" + # We can infer the properties from the string. + if kv_cache_scheme.upper() == "FP8": + kv_cache_type = "float" + kv_cache_bits = 8 + + # Now, safely use the extracted values in the original logic + if kv_cache_type == "float" and kv_cache_bits == 8: kv_cache_quant_algo = "FP8" else: kv_cache_quant_algo = "auto"