fix: minor fix for modelopt weight load compatibility (#7953)
This commit is contained in:
@@ -518,6 +518,7 @@ class FusedMoE(torch.nn.Module):
|
||||
self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
num_experts=self.local_num_experts,
|
||||
@@ -661,7 +662,11 @@ class FusedMoE(torch.nn.Module):
|
||||
):
|
||||
raise ValueError("expert_data and loaded_weight must be torch.Tensor")
|
||||
|
||||
if expert_data.dim() != 2 or loaded_weight.dim() != 2:
|
||||
if (
|
||||
self.quant_config is not None
|
||||
and "modelopt" in self.quant_config.get_name()
|
||||
and (expert_data.dim() != 2 or loaded_weight.dim() != 2)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user