Set num_fused_shared_experts as num_shared_experts when shared_experts fusion is not disabled (#6736)
@@ -400,7 +400,7 @@ def main(args: argparse.Namespace):
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
         E = (
-            config.n_routed_experts + 1
+            config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
             if config.architectures[0] in ["DeepseekV3ForCausalLM"]
             else config.n_routed_experts
         )
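For context, here is a minimal standalone sketch of the expert-count logic this hunk patches. The expert_count helper and the SimpleNamespace config stub are illustrative, not part of the commit: when shared-experts fusion is left enabled, the shared expert is folded into the routed-expert table, so the tuner must size E one larger for DeepSeek-V3.

from types import SimpleNamespace

# Illustrative helper (not from the patch): mirrors the patched expression.
def expert_count(config, disable_shared_experts_fusion: bool) -> int:
    # DeepSeek-V3 fuses the shared expert into the routed-expert table,
    # so the kernel sees one extra expert unless fusion is disabled.
    return (
        config.n_routed_experts + (0 if disable_shared_experts_fusion else 1)
        if config.architectures[0] in ["DeepseekV3ForCausalLM"]
        else config.n_routed_experts
    )

config = SimpleNamespace(architectures=["DeepseekV3ForCausalLM"], n_routed_experts=256)
assert expert_count(config, disable_shared_experts_fusion=False) == 257
assert expert_count(config, disable_shared_experts_fusion=True) == 256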
@@ -408,7 +408,9 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] == "Llama4ForConditionalGeneration":
-        E = config.text_config.num_local_experts + 1
+        E = config.text_config.num_local_experts + (
+            0 if args.disable_shared_experts_fusion else 1
+        )
         topk = config.text_config.num_experts_per_tok
         intermediate_size = config.text_config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
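The Llama 4 branch applies the same rule, and the surrounding context lines show how the per-rank shard width is derived. A worked example with placeholder values (the numbers below are illustrative, not taken from any real config):

# Same rule for Llama 4 (values are illustrative):
num_local_experts = 16          # routed experts per MoE layer
disable_shared_experts_fusion = False
E = num_local_experts + (0 if disable_shared_experts_fusion else 1)  # -> 17

# From the context lines: gate/up projections are fused (factor of 2)
# and the width is split across tensor-parallel ranks.
intermediate_size = 8192
tp_size = 2
shard_intermediate_size = 2 * intermediate_size // tp_size
assert (E, shard_intermediate_size) == (17, 8192)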
@@ -558,7 +560,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
     )
-    parser.add_argument("--tp-size", "-tp", type=int, default=2)
+    parser.add_argument("--tp-size", "--tp", type=int, default=2)
     parser.add_argument(
         "--dtype",
         type=str,
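This hunk swaps the single-dash alias -tp for the double-dash --tp; both spellings store into the same args.tp_size destination, which argparse derives from the first long option string. A minimal standalone check of the new spelling:

import argparse

parser = argparse.ArgumentParser()
# dest is derived from the first long option ("--tp-size" -> tp_size);
# "--tp" is just an extra alias for the same destination.
parser.add_argument("--tp-size", "--tp", type=int, default=2)

assert parser.parse_args([]).tp_size == 2
assert parser.parse_args(["--tp", "4"]).tp_size == 4
assert parser.parse_args(["--tp-size", "8"]).tp_size == 8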
@@ -568,6 +570,7 @@ if __name__ == "__main__":
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--batch-size", type=int, required=False)
     parser.add_argument("--tune", action="store_true")
+    parser.add_argument("--disable-shared-experts-fusion", action="store_true")
     args = parser.parse_args()

     main(args)
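End to end, the new switch is a plain store_true flag, so args.disable_shared_experts_fusion defaults to False and the fused shared expert is counted unless the flag is passed. A minimal sketch with the parser trimmed to the one relevant argument (the n_routed_experts value is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--disable-shared-experts-fusion", action="store_true")

# Default: fusion stays enabled, so tuning sizes E with the extra expert.
args = parser.parse_args([])
assert args.disable_shared_experts_fusion is False

# With the flag: E falls back to the plain routed-expert count.
args = parser.parse_args(["--disable-shared-experts-fusion"])
n_routed_experts = 256
E = n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
assert E == 256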