Fix gptq for moe layers (#2300)

Co-authored-by: root <me@zhyncs.com>
This commit is contained in:
Lianmin Zheng
2024-12-03 07:12:33 -08:00
committed by GitHub
parent fda628d8f2
commit 1228f7ca69
2 changed files with 44 additions and 2 deletions

View File

@@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix):
return None
def gptq_get_quant_method(self, layer, prefix):
    """Pick the GPTQ-Marlin quantization method that matches ``layer``.

    This function is installed as ``GPTQMarlinConfig.get_quant_method`` by
    ``apply_monkey_patches``, so ``self`` is the object the method is called
    on and is handed unchanged to the selected method's constructor.

    Args:
        layer: The layer object whose type decides which method is built.
        prefix: Accepted for signature compatibility; not used here.

    Returns:
        ``GPTQMarlinLinearMethod(self)`` when ``layer`` is a ``LinearBase``,
        ``GPTQMarlinMoEMethod(self)`` when it is a ``FusedMoE``, and ``None``
        for any other layer type.
    """
    # Imports are resolved at call time rather than at module import time.
    # NOTE(review): presumably to avoid an import cycle / eager vllm import
    # -- confirm before hoisting them to the top of the file.
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
        GPTQMarlinMoEMethod,
    )

    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE

    # Ordered (layer type, method factory) pairs; the first match wins,
    # so a linear layer is checked before a fused-MoE layer.
    dispatch = (
        (LinearBase, GPTQMarlinLinearMethod),
        (FusedMoE, GPTQMarlinMoEMethod),
    )
    for layer_type, method_factory in dispatch:
        if isinstance(layer, layer_type):
            return method_factory(self)

    # Unrecognised layer types get no quantization method.
    return None
def awq_get_quant_method(self, layer, prefix):
    """Pick the AWQ-Marlin quantization method that matches ``layer``.

    This function is installed as ``AWQMarlinConfig.get_quant_method`` by
    ``apply_monkey_patches``, so ``self`` is the object the method is called
    on and is handed unchanged to the selected method's constructor.

    Args:
        layer: The layer object whose type decides which method is built.
        prefix: Accepted for signature compatibility; not used here.

    Returns:
        ``AWQMarlinLinearMethod(self)`` when ``layer`` is a ``LinearBase``,
        ``AWQMoEMethod(self)`` when it is a ``FusedMoE``, and ``None`` for
        any other layer type.
    """
    # Imports are resolved at call time rather than at module import time.
    # NOTE(review): presumably to avoid an import cycle / eager vllm import
    # -- confirm before hoisting them to the top of the file.
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.quantization.awq_marlin import (
        AWQMarlinLinearMethod,
        AWQMoEMethod,
    )

    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE

    # Ordered (layer type, method factory) pairs; the first match wins,
    # so a linear layer is checked before a fused-MoE layer.
    dispatch = (
        (LinearBase, AWQMarlinLinearMethod),
        (FusedMoE, AWQMoEMethod),
    )
    for layer_type, method_factory in dispatch:
        if isinstance(layer, layer_type):
            return method_factory(self)

    # Unrecognised layer types get no quantization method.
    return None
def apply_monkey_patches():
    """Install every replacement method on its target class in one place.

    Rebinds class attributes only; takes no arguments and returns ``None``.
    """
    # Replace the FP8 MoE method's ``apply`` implementation.
    Fp8MoEMethod.apply = fp8_moe_apply

    # Replace ``get_quant_method`` on each quantization config class.
    Fp8Config.get_quant_method = fp8_get_quant_method
    GPTQMarlinConfig.get_quant_method = gptq_get_quant_method
    AWQMarlinConfig.get_quant_method = awq_get_quant_method
# Apply patches when module is imported