# enginex-bi_series-vllm/vllm/model_executor/layers/fused_moe/layer.py
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple
import torch
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
# from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import CompressionFormat, QuantizationStrategy
# from vllm.model_executor.layers.quantization.utils.w8a8_utils import create_per_channel_scale_param
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
GROUP = "group"
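# Granularity of the supported weight scales (illustrative shapes for a single
# [N, K] expert weight; the exact layouts used here are created by the quant
# methods below):
#   TENSOR:  one scale per logical weight matrix              -> scalar
#   CHANNEL: one scale per output channel                     -> [N, 1]
#   GROUP:   one scale per group of group_size input elements -> [N, K // group_size]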
class FusedMoEMethodBase(QuantizeMethodBase):
@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError
@abstractmethod
def apply(self, layer: torch.nn.Module, x: torch.Tensor,
router_logits: torch.Tensor, top_k: int, renormalize: bool,
use_grouped_topk: bool) -> torch.Tensor:
raise NotImplementedError
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
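# Per-expert shapes created above (E = num_experts, I = intermediate_size per
# partition, H = hidden_size): w13_weight is [E, 2*I, H] with the gate_proj
# rows first and the up_proj rows second (see FusedMoE._load_w13 below), and
# w2_weight is [E, H, I].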
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
return self.forward(x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
def forward_cuda(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True)
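# Conceptually (a sketch of the math, not the fused kernel itself): for each
# token x and each of its top_k selected experts e with routing weight w_e,
# fused_experts evaluates a SwiGLU expert MLP and a weighted sum over the
# selected experts,
#     y = sum_e w_e * (silu(x @ W1_e.T) * (x @ W3_e.T)) @ W2_e.T
# where W1_e / W3_e are the two halves of w13_weight[e] and W2_e is w2_weight[e].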
def forward_cpu(self, *args, **kwargs):
raise NotImplementedError(
"The CPU backend currently does not support MoE.")
def forward_tpu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
assert custom_routing_function is None
return fused_moe(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk=top_k,
gating_output=router_logits,
renormalize=renormalize)
class W8A8QuantizedFusedMoEMethod(FusedMoEMethodBase):
"""MoE method W8A8 quantization. This class is for compressed-tensors format loading"""
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
self.strategy = extra_weight_attrs['quant_config'].target_scheme_map['Linear']['weights'].strategy
self.is_static_input_scheme = not extra_weight_attrs['quant_config'].target_scheme_map['Linear']['input_activations'].dynamic
# assert self.is_static_input_scheme, "W8A8 int quantization only support static input activation for now"
self.quant_config = extra_weight_attrs["quant_config"]
self.weight_loader = extra_weight_attrs["weight_loader"]
self.logical_widths_13 = [intermediate_size * 2]
self.logical_widths_2 = [intermediate_size * 2]
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=torch.int8),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": self.weight_loader,
})
# WEIGHT SCALE
layer_kwargs = {"weight_loader": self.weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL:
scale = torch.nn.Parameter(torch.empty((num_experts, intermediate_size * 2, 1),
dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"input_dim": 1, "output_dim": 0, **layer_kwargs})
else:
assert self.strategy == QuantizationStrategy.TENSOR
scale = torch.nn.Parameter(torch.empty((num_experts, 2), dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**layer_kwargs
})
set_weight_attrs(scale, {"is_int8_weight_scale": True})
layer.register_parameter("w13_weight_scale", scale)
# INPUT SCALE
if self.is_static_input_scheme:
scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**layer_kwargs
})
set_weight_attrs(scale, {"is_int8_input_scale": True})
layer.register_parameter("w13_input_scale", scale)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=torch.int8),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": self.weight_loader,
})
# WEIGHT SCALE
if self.strategy == QuantizationStrategy.CHANNEL:
scale = torch.nn.Parameter(torch.empty((num_experts, hidden_size, 1),
dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"input_dim": 0, "output_dim": 1, **layer_kwargs})
else:
assert self.strategy == QuantizationStrategy.TENSOR
scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**layer_kwargs
})
set_weight_attrs(scale, {"is_int8_weight_scale": True})
layer.register_parameter("w2_weight_scale", scale)
# INPUT SCALE
if self.is_static_input_scheme:
scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**layer_kwargs
})
set_weight_attrs(scale, {"is_int8_input_scale": True})
layer.register_parameter("w2_input_scale", scale)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
assert False, "fused_moe w8a8 use ixformer.contrib.vllm.layers.mixtral_decoder_layer_forward"
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
if self.is_static_input_scheme:
return fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
top_k,
renormalize=renormalize,
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
use_int8_w8a8=True,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
else:
return fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
top_k,
renormalize=renormalize,
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
use_int8_w8a8=True,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=None,
a2_scale=None)
class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj / w2).
Note: Mixtral uses w1 for gate_proj, w3 for up_proj, and w2 for down_proj. We
copy that naming convention here and handle any remapping in the
load_weights function in each model implementation.
Args:
num_experts: Number of experts in the model
top_k: Number of experts selected for each token
hidden_size: Input hidden state size of the transformer
intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters.
reduce_results: Whether to apply all_reduce on the output of the layer
renormalize: Whether to renormalize the top-k routing weights in the fused_moe kernel
quant_config: Quantization configuration.
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
):
super().__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size())
self.top_k = top_k
self.num_experts = num_experts
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsConfig
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod())
elif (isinstance(quant_config, CompressedTensorsConfig)
and quant_config.quant_format == CompressionFormat.int_quantized.value
and quant_config.target_scheme_map['Linear']['input_activations'].num_bits == 8
and quant_config.target_scheme_map['Linear']['weights'].num_bits == 8):
self.quant_method: Optional[QuantizeMethodBase] = (
W8A8QuantizedFusedMoEMethod())
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=self.weight_loader,
quant_config=quant_config)
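# At this point the selected quant method has registered all expert parameters
# on this module (e.g. w13_weight / w2_weight and, for the int8 path above,
# their weight and input scales); weight_loader below fills them from the
# checkpoint.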
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
expert_id: int):
param_data = param.data
# for per tensor weight quantization
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
elif shard_id == "w2":
param_data[expert_id] = loaded_weight
def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int):
# Load grouped weight scales for group quantization
# or model weights
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
shard_dim: int, shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int):
# for per channel weight quantization
if shard_id == "w2":
expert_data.copy_(loaded_weight)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# w3, up_proj: Load into second logical weight of w13.
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
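# Illustrative example (hypothetical sizes): with intermediate_size=14336 and
# tp_size=2, one expert of w13_weight has 2*7168 rows, so shard_size=7168;
# tp_rank=1 takes rows [7168:14336] of the full gate_proj (w1) or up_proj (w3)
# checkpoint tensor and copies them into rows [0:7168] or [7168:14336] of
# expert_data, respectively.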
def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
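# Illustrative example (hypothetical sizes): with intermediate_size=14336 and
# tp_size=2, one expert of w2_weight has shape [hidden_size, 7168], so
# tp_rank=1 copies columns [7168:14336] of the full down_proj checkpoint
# tensor into expert_data.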
def _load_single_value(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, expert_id: int):
param_data = param.data
# Input scales can be loaded directly and should be equal.
param_data[expert_id] = loaded_weight
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
else:
assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight)
def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
loaded_weight = loaded_weight.t().contiguous() if (
self.quant_method.__class__.__name__
== "CompressedTensorsWNA16MoEMethod") else loaded_weight
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
# Special case for fp8 scales.
if getattr(param, "is_fp8_scale", False):
self._load_fp8_scale(param.data, loaded_weight, weight_name,
shard_id, expert_id)
return
elif getattr(param, "is_int8_input_scale", False):
self._load_int8_input_scale(param.data, loaded_weight,
shard_id, expert_id)
return
WEIGHT_SCALE_SUPPORTED = [
e.value for e in FusedMoeWeightScaleSupported
]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This is the dimension along
# which intermediate_size is sharded.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()
# is_transposed: whether the dim to shard the weight
# should be flipped. Required by GPTQ and compressed-tensors;
# the sharded dim is still whichever one holds intermediate_size.
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = ~shard_dim
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# this is needed for compressed-tensors only
loaded_weight = loaded_weight.to(param.data.device)
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
f"vs. {loaded_weight}")
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return
# Case g_idx
if "g_idx" in weight_name:
self._load_g_idx(shard_dim=0,
shard_id=shard_id,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
return
# Case weight scales and zero_points
if ("scale" in weight_name or "zero" in weight_name):
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
quant_method = getattr(param, "quant_method", FusedMoeWeightScaleSupported.CHANNEL.value)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
else:
raise ValueError(
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
return
# Case weight_shape
if "weight_shape" in weight_name:
# only required by compressed-tensors
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return
# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
return
@staticmethod
def select_experts(hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk)
# DeepSeek-V2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group)
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
return topk_weights, topk_ids
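# Both returned tensors have shape [num_tokens, top_k]: topk_weights holds the
# (optionally renormalized) routing weights and topk_ids the selected expert
# indices. For example (an illustrative sketch, not output of a real run),
# with top_k=2 a token routed to experts 3 and 7 with weights 0.6 / 0.4 yields
# topk_ids=[3, 7] and topk_weights=[0.6, 0.4].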
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function)
if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states
@classmethod
def make_expert_params_mapping(
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int) -> List[Tuple[str, str, int, str]]:
return [
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_" if weight_name
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
for expert_id in range(num_experts) for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]
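# Illustrative output (hypothetical checkpoint names): with
# ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3"
# and num_experts=2, the mapping contains, for expert 0,
#     ("experts.w13_", "experts.0.w1.", 0, "w1"),
#     ("experts.w2_",  "experts.0.w2.", 0, "w2"),
#     ("experts.w13_", "experts.0.w3.", 0, "w3"),
# and the corresponding entries for expert 1.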
def _load_fp8_scale(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
param_data = param.data
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
def _load_int8_input_scale(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
shard_id: str, expert_id: int) -> None:
param_data = param.data
# Input scales can be loaded directly and should be equal.
if param_data[expert_id] != 1 and (param_data[expert_id].to(loaded_weight.device) -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight