From bcbeed714f377d46365132f900b075dc914f0010 Mon Sep 17 00:00:00 2001 From: jingyu-ml <108295447+jingyu-ml@users.noreply.github.com> Date: Tue, 2 Sep 2025 22:56:03 -0500 Subject: [PATCH] Qwen FP8/NVFP4 ModelOPT Quantization support (#7912) Co-authored-by: Jingyu Xin --- .../srt/layers/quantization/modelopt_quant.py | 37 ++++++++++++++++++- python/sglang/srt/models/qwen3.py | 10 ++++- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index b8e02c792..bd4367234 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -517,6 +517,39 @@ class ModelOptFp4Config(QuantizationConfig): def get_config_filenames(cls) -> List[str]: return ["hf_quant_config.json"] + @staticmethod + def common_group_size(cfg: dict) -> int: + """Return the unique group_size across the config; raise if missing/mismatched.""" + sizes = set() + + # Top-level and 'quantization' block + v = cfg.get("group_size") + if isinstance(v, int): + sizes.add(v) + q = cfg.get("quantization") + if isinstance(q, dict): + v = q.get("group_size") + if isinstance(v, int): + sizes.add(v) + + # config_groups: accept group-level or nested dicts (e.g., weights/input_activations) + for g in (cfg.get("config_groups") or {}).values(): + if isinstance(g, dict): + v = g.get("group_size") + if isinstance(v, int): + sizes.add(v) + for sub in g.values(): + if isinstance(sub, dict): + v = sub.get("group_size") + if isinstance(v, int): + sizes.add(v) + + if not sizes: + raise ValueError("No group_size found in config.") + if len(sizes) > 1: + raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}") + return next(iter(sizes)) + @classmethod def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config: # Handle two different config formats: @@ -549,7 +582,7 @@ class ModelOptFp4Config(QuantizationConfig): else: kv_cache_quant_algo = "auto" - group_size = config.get("group_size") + group_size = ModelOptFp4Config.common_group_size(config) exclude_modules = config.get("ignore", []) else: # Fall back to nested format (hf_quant_config.json - legacy format) @@ -559,7 +592,7 @@ class ModelOptFp4Config(QuantizationConfig): kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo") if not kv_cache_quant_algo: kv_cache_quant_algo = "auto" - group_size = quant_config.get("group_size") + group_size = ModelOptFp4Config.common_group_size(config) exclude_modules = quant_config.get("exclude_modules", []) except (ValueError, KeyError): raise ValueError( diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 042159a50..bc5f054d7 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -24,7 +24,10 @@ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP from sglang.srt.models.qwen2 import Qwen2Model from sglang.srt.utils import add_prefix, is_cuda @@ -458,7 +461,10 @@ class Qwen3ForCausalLM(nn.Module): continue if name.startswith("model.vision_tower") and name not in params_dict: continue - + if "scale" in name: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue