@@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def gptq_get_quant_method(self, layer, prefix):
    """Monkey-patch replacement for ``GPTQMarlinConfig.get_quant_method``.

    Dispatches on the layer type so that sglang's ``FusedMoE`` layers are
    recognized in addition to vllm's linear layers.

    Args:
        layer: The model layer a quantization method is requested for.
        prefix: Parameter-name prefix of the layer (accepted for interface
            compatibility; not used in the dispatch below).

    Returns:
        ``GPTQMarlinLinearMethod(self)`` for ``LinearBase`` layers,
        ``GPTQMarlinMoEMethod(self)`` for sglang ``FusedMoE`` layers,
        otherwise ``None``.
    """
    # NOTE(review): imports are function-local, presumably to defer the
    # heavy vllm/sglang imports until patch time — confirm against module
    # import order.
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
        GPTQMarlinMoEMethod,
    )

    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE

    if isinstance(layer, LinearBase):
        return GPTQMarlinLinearMethod(self)
    elif isinstance(layer, FusedMoE):
        return GPTQMarlinMoEMethod(self)
    return None
def awq_get_quant_method(self, layer, prefix):
    """Monkey-patch replacement for ``AWQMarlinConfig.get_quant_method``.

    Dispatches on the layer type so that sglang's ``FusedMoE`` layers are
    recognized in addition to vllm's linear layers.

    Args:
        layer: The model layer a quantization method is requested for.
        prefix: Parameter-name prefix of the layer (accepted for interface
            compatibility; not used in the dispatch below).

    Returns:
        ``AWQMarlinLinearMethod(self)`` for ``LinearBase`` layers,
        ``AWQMoEMethod(self)`` for sglang ``FusedMoE`` layers,
        otherwise ``None``.
    """
    # NOTE(review): imports are function-local, presumably to defer the
    # heavy vllm/sglang imports until patch time — confirm against module
    # import order.
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.quantization.awq_marlin import (
        AWQMarlinLinearMethod,
        AWQMoEMethod,
    )

    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE

    if isinstance(layer, LinearBase):
        return AWQMarlinLinearMethod(self)
    elif isinstance(layer, FusedMoE):
        return AWQMoEMethod(self)
    return None
def apply_monkey_patches():
    """Apply all monkey patches in one place.

    Rebinds the ``apply`` / ``get_quant_method`` entry points of the vllm
    quantization config classes to the sglang-aware replacements defined
    in this module, so sglang ``FusedMoE`` layers are handled correctly.
    """
    # FP8: patched MoE apply and quant-method dispatch.
    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
    setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
    # GPTQ-Marlin and AWQ-Marlin: patched quant-method dispatch.
    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
    setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
# Apply patches when module is imported
|
# Apply patches when module is imported
|
||||||
|
|||||||
@@ -339,7 +339,9 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if (
|
||||||
|
name.endswith(".bias") or name.endswith("_bias")
|
||||||
|
) and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
@@ -353,6 +355,10 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
|
if (
|
||||||
|
name.endswith(".bias") or name.endswith("_bias")
|
||||||
|
) and name not in params_dict:
|
||||||
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(
|
weight_loader(
|
||||||
@@ -365,7 +371,9 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if (
|
||||||
|
name.endswith(".bias") or name.endswith("_bias")
|
||||||
|
) and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
# Skip loading kv_scale from ckpts towards new design.
|
# Skip loading kv_scale from ckpts towards new design.
|
||||||
if name.endswith(".kv_scale") and name not in params_dict:
|
if name.endswith(".kv_scale") and name not in params_dict:
|
||||||
|
|||||||
Reference in New Issue
Block a user