@@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix):
|
||||
return None
|
||||
|
||||
|
||||
def gptq_get_quant_method(self, layer, prefix):
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return GPTQMarlinLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
def awq_get_quant_method(self, layer, prefix):
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinLinearMethod,
|
||||
AWQMoEMethod,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return AWQMarlinLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return AWQMoEMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
def apply_monkey_patches():
|
||||
"""Apply all monkey patches in one place."""
|
||||
setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
|
||||
setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
|
||||
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
||||
setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
|
||||
|
||||
|
||||
# Apply patches when module is imported
|
||||
|
||||
@@ -339,7 +339,9 @@ class MixtralForCausalLM(nn.Module):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
if (
|
||||
name.endswith(".bias") or name.endswith("_bias")
|
||||
) and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
@@ -353,6 +355,10 @@ class MixtralForCausalLM(nn.Module):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
if (
|
||||
name.endswith(".bias") or name.endswith("_bias")
|
||||
) and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(
|
||||
@@ -365,7 +371,9 @@ class MixtralForCausalLM(nn.Module):
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
if (
|
||||
name.endswith(".bias") or name.endswith("_bias")
|
||||
) and name not in params_dict:
|
||||
continue
|
||||
# Skip loading kv_scale from ckpts towards new design.
|
||||
if name.endswith(".kv_scale") and name not in params_dict:
|
||||
|
||||
Reference in New Issue
Block a user