From 1228f7ca69e6ee3f5076f2381c3a187120e0de00 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 3 Dec 2024 07:12:33 -0800
Subject: [PATCH] Fix gptq for moe layers (#2300)

Co-authored-by: root
---
Notes (review commentary; text in this region is not part of the commit):
    * quantization/__init__.py: adds gptq_get_quant_method() and
      awq_get_quant_method(). Each returns the Marlin linear method when
      `layer` is a vllm LinearBase, the matching MoE method when `layer`
      is sglang's FusedMoE, and None otherwise. `prefix` is accepted but
      not read. apply_monkey_patches() installs them as
      `get_quant_method` on GPTQMarlinConfig and AWQMarlinConfig.
    * models/mixtral.py (MixtralForCausalLM): the two existing "skip
      extra bias" checks now also skip names ending in "_bias" that are
      missing from params_dict, and the same check is added before the
      params_dict lookup in the branch that passes expert_id.
    * NOTE(review): line breaks and indentation were restored from a
      whitespace-collapsed copy of this patch. Context-line indentation
      is inferred from the hunk line counts; confirm with
      `git apply --check` against the parent commit.

 .../srt/layers/quantization/__init__.py       | 34 +++++++++++++++++++
 python/sglang/srt/models/mixtral.py           | 12 +++++--
 2 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index a1bacdce0..f34a581d6 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix):
     return None
 
 
+def gptq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return GPTQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return GPTQMarlinMoEMethod(self)
+    return None
+
+
+def awq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinLinearMethod,
+        AWQMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return AWQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return AWQMoEMethod(self)
+    return None
+
+
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
     setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
 
 
 # Apply patches when module is imported
diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index b222387a7..e75dc1288 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -339,7 +339,9 @@ class MixtralForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if (
+                    name.endswith(".bias") or name.endswith("_bias")
+                ) and name not in params_dict:
                     continue
 
                 param = params_dict[name]
@@ -353,6 +355,10 @@ class MixtralForCausalLM(nn.Module):
                         continue
                     name = name.replace(weight_name, param_name)
 
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -365,7 +371,9 @@ class MixtralForCausalLM(nn.Module):
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
                         continue
                     # Skip loading kv_scale from ckpts towards new design.
                     if name.endswith(".kv_scale") and name not in params_dict: