From fcd72bd100b5bdad4b304e2c76b82e657edf9502 Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Fri, 29 Aug 2025 17:13:52 -0700 Subject: [PATCH] [ModelOpt] Fix Weight Loading for DSR1-FP4 Quantization (#9712) Signed-off-by: Pavani Majety --- python/sglang/srt/layers/linear.py | 5 +++-- python/sglang/srt/layers/quantization/modelopt_quant.py | 7 +++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index df2b77e08..47dfc7324 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -235,8 +235,9 @@ class ReplicatedLinear(LinearBase): loaded_weight = loaded_weight[:1] else: raise ValueError(f"{loaded_weight} are not all equal") - - assert param.size() == loaded_weight.size() + assert ( + param.size() == loaded_weight.size() + ), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}" param.data.copy_(loaded_weight) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index aff18fa2b..b8e02c792 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -599,6 +599,13 @@ class ModelOptFp4Config(QuantizationConfig): regex_str = pattern.replace(".", r"\.").replace("*", r".*") if re.fullmatch(regex_str, prefix): return True + + # Check if the last part of the excluded pattern is contained in the last part of the prefix + # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa + pattern_last_part = pattern.split(".")[-1] + prefix_last_part = prefix.split(".")[-1] + if pattern_last_part in prefix_last_part: + return True return False def get_quant_method(