diff --git a/vllm_kunlun/models/mimo_v2_flash.py b/vllm_kunlun/models/mimo_v2_flash.py index e381b24..033e1be 100644 --- a/vllm_kunlun/models/mimo_v2_flash.py +++ b/vllm_kunlun/models/mimo_v2_flash.py @@ -21,7 +21,7 @@ from vllm.distributed import ( tensor_model_parallel_all_gather, ) from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm_kunlun.ops.fused_moe.layer import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, diff --git a/vllm_kunlun/ops/fused_moe/layer.py b/vllm_kunlun/ops/fused_moe/layer.py index 8120868..b6fb2be 100644 --- a/vllm_kunlun/ops/fused_moe/layer.py +++ b/vllm_kunlun/ops/fused_moe/layer.py @@ -4,6 +4,9 @@ from contextlib import nullcontext from typing import Callable, Optional, Union, get_args import torch +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + should_ignore_layer, +) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod @@ -129,8 +132,26 @@ class VllmFusedMoE(FusedMoE): is_sequence_parallel=is_sequence_parallel, zero_expert_num=zero_expert_num, zero_expert_type=zero_expert_type) - self.has_bias=has_bias + self.has_bias = has_bias self.register_parameter("w13_bias", None) self.register_parameter("w2_bias", None) -FusedMoE=VllmFusedMoE \ No newline at end of file + if (self.quant_config is None) or ( + should_ignore_layer( + prefix, + ignore=self.quant_config.ignore, + fused_mapping=self.quant_config.packed_modules_mapping, + ) + ): + self.quant_method = UnquantizedFusedMoEMethod(self.moe_config) + moe_quant_params = { + "num_experts": self.local_num_experts, + "hidden_size": hidden_size, + "intermediate_size_per_partition": self.intermediate_size_per_partition, + "params_dtype": params_dtype, + "weight_loader": self.weight_loader, + } + self.quant_method.create_weights(layer=self, **moe_quant_params) + + +FusedMoE = VllmFusedMoE