Set num_fused_shared_experts as num_shared_experts when shared_experts fusion is not disabled (#6736)

This commit is contained in:
Cheng Wan
2025-06-04 15:53:22 -07:00
committed by GitHub
parent f0f84975f4
commit 81964328b7
22 changed files with 381 additions and 45 deletions

View File

@@ -400,7 +400,7 @@ def main(args: argparse.Namespace):
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = (
config.n_routed_experts + 1
config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
@@ -408,7 +408,9 @@ def main(args: argparse.Namespace):
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts + 1
E = config.text_config.num_local_experts + (
0 if args.disable_shared_experts_fusion else 1
)
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
@@ -558,7 +560,7 @@ if __name__ == "__main__":
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", "-tp", type=int, default=2)
parser.add_argument("--tp-size", "--tp", type=int, default=2)
parser.add_argument(
"--dtype",
type=str,
@@ -568,6 +570,7 @@ if __name__ == "__main__":
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument("--disable-shared-experts-fusion", action="store_true")
args = parser.parse_args()
main(args)