@@ -200,8 +200,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.quant_method.create_weights(
|
self.quant_method.create_weights(
|
||||||
layer=self,
|
layer=self,
|
||||||
num_experts=self.num_experts,
|
num_experts=self.num_local_experts,
|
||||||
num_local_experts=self.num_local_experts,
|
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
# FIXME: figure out which intermediate_size to use
|
# FIXME: figure out which intermediate_size to use
|
||||||
intermediate_size=self.intermediate_size_per_partition,
|
intermediate_size=self.intermediate_size_per_partition,
|
||||||
|
|||||||
@@ -752,7 +752,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
num_local_experts: int,
|
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size_per_partition: int,
|
intermediate_size_per_partition: int,
|
||||||
params_dtype: torch.dtype,
|
params_dtype: torch.dtype,
|
||||||
@@ -766,7 +765,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
# TODO(ch-wan): check if this is needed
|
# TODO(ch-wan): check if this is needed
|
||||||
layer.num_experts = num_experts
|
layer.num_experts = num_experts
|
||||||
layer.num_local_experts = num_local_experts
|
layer.num_local_experts = num_experts
|
||||||
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
layer.intermediate_size_per_partition = intermediate_size_per_partition
|
||||||
layer.params_dtype = params_dtype
|
layer.params_dtype = params_dtype
|
||||||
layer.quant_config = self.quant_config
|
layer.quant_config = self.quant_config
|
||||||
|
|||||||
Reference in New Issue
Block a user