@@ -200,8 +200,7 @@ class FusedMoE(torch.nn.Module):
|
||||
self.quant_config = quant_config
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
num_experts=self.num_experts,
|
||||
num_local_experts=self.num_local_experts,
|
||||
num_experts=self.num_local_experts,
|
||||
hidden_size=hidden_size,
|
||||
# FIXME: figure out which intermediate_size to use
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
|
||||
@@ -752,7 +752,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
num_local_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
@@ -766,7 +765,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
# TODO(ch-wan): check if this is needed
|
||||
layer.num_experts = num_experts
|
||||
layer.num_local_experts = num_local_experts
|
||||
layer.num_local_experts = num_experts
|
||||
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
layer.params_dtype = params_dtype
|
||||
layer.quant_config = self.quant_config
|
||||
|
||||
Reference in New Issue
Block a user