Add DeepSeek V3/R1 shared experts fusion (#4918)
This commit is contained in:
@@ -399,7 +399,12 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
-        E = config.n_routed_experts
+        n_share_fusion_experts = args.n_share_experts_fusion
+        E = (
+            config.n_routed_experts + n_share_fusion_experts
+            if config.architectures[0] in ["DeepseekV3ForCausalLM"]
+            else config.n_routed_experts
+        )
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
@@ -559,6 +564,12 @@ if __name__ == "__main__":
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--batch-size", type=int, required=False)
     parser.add_argument("--tune", action="store_true")
+    parser.add_argument(
+        "--n-share-experts-fusion",
+        type=int,
+        default=0,
+        help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1",
+    )
     args = parser.parse_args()

     main(args)
|
||||
Reference in New Issue
Block a user