fix: minor fix for modelopt weight load compatibility (#7953)

Author: Peng Zhang
Date: 2025-07-12 05:20:58 +08:00
Committed by: GitHub
Parent: 86044712c6
Commit: 191d836ff6


@@ -518,6 +518,7 @@ class FusedMoE(torch.nn.Module):
             self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
         assert self.quant_method is not None
+        self.quant_config = quant_config
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.local_num_experts,
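
The single added line above keeps the constructor's quant_config on the layer; the weight-loader change in the next hunk reads it back via self.quant_config.get_name(). A minimal sketch of that pattern, with hypothetical names rather than the exact FusedMoE signature, and not part of the diff:

import torch

class MoELayerSketch(torch.nn.Module):
    """Hypothetical stand-in showing why quant_config is stored on the module."""

    def __init__(self, quant_config=None):
        super().__init__()
        # Kept on the module so weight-loading helpers that run long after
        # __init__ can branch on the quantization backend, as the next hunk does.
        self.quant_config = quant_config

    def _is_modelopt(self) -> bool:
        # Mirrors the guard added in the second hunk below.
        return self.quant_config is not None and "modelopt" in self.quant_config.get_name()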
@@ -661,7 +662,11 @@ class FusedMoE(torch.nn.Module):
         ):
             raise ValueError("expert_data and loaded_weight must be torch.Tensor")
-        if expert_data.dim() != 2 or loaded_weight.dim() != 2:
+        if (
+            self.quant_config is not None
+            and "modelopt" in self.quant_config.get_name()
+            and (expert_data.dim() != 2 or loaded_weight.dim() != 2)
+        ):
             raise ValueError(
                 f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
             )
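
Net effect of the second hunk: the strict 2-D shape check now applies only when a modelopt quantization config is present, so other backends (or unquantized loads) can hand the loader tensors of other ranks, such as scalar scales, without tripping the error. A standalone sketch of the guarded check, assuming only a quant-config object that exposes get_name(); the helper and class names here are illustrative, not the project's API:

import torch

def check_moe_weight_shapes(expert_data, loaded_weight, quant_config=None):
    # Sketch of the guarded check: only modelopt configs require strictly 2-D tensors.
    if not isinstance(expert_data, torch.Tensor) or not isinstance(
        loaded_weight, torch.Tensor
    ):
        raise ValueError("expert_data and loaded_weight must be torch.Tensor")
    if (
        quant_config is not None
        and "modelopt" in quant_config.get_name()
        and (expert_data.dim() != 2 or loaded_weight.dim() != 2)
    ):
        raise ValueError(
            f"Expected 2D tensors, got expert_data shape {expert_data.shape} "
            f"and loaded_weight shape {loaded_weight.shape}"
        )

class _ModelOptConfig:
    # Hypothetical stand-in for a real quantization config object.
    def get_name(self) -> str:
        return "modelopt"

# A 1-D tensor passes without a modelopt config but would raise with one:
check_moe_weight_shapes(torch.ones(4), torch.ones(4))                        # ok
# check_moe_weight_shapes(torch.ones(4), torch.ones(4), _ModelOptConfig())   # would raise ValueError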