fix: minor fix for modelopt weight load compatibility (#7953)
This commit is contained in:
@@ -518,6 +518,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
|
self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
|
self.quant_config = quant_config
|
||||||
self.quant_method.create_weights(
|
self.quant_method.create_weights(
|
||||||
layer=self,
|
layer=self,
|
||||||
num_experts=self.local_num_experts,
|
num_experts=self.local_num_experts,
|
||||||
@@ -661,7 +662,11 @@ class FusedMoE(torch.nn.Module):
|
|||||||
):
|
):
|
||||||
raise ValueError("expert_data and loaded_weight must be torch.Tensor")
|
raise ValueError("expert_data and loaded_weight must be torch.Tensor")
|
||||||
|
|
||||||
if expert_data.dim() != 2 or loaded_weight.dim() != 2:
|
if (
|
||||||
|
self.quant_config is not None
|
||||||
|
and "modelopt" in self.quant_config.get_name()
|
||||||
|
and (expert_data.dim() != 2 or loaded_weight.dim() != 2)
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
|
f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user