h100 tuning fused_moe_triton for qwen2 moe (#2560)

This commit is contained in:
Xiaoyu Zhang
2024-12-26 19:13:31 +08:00
committed by GitHub
parent 635a042623
commit 9a23c48456
10 changed files with 812 additions and 67 deletions

View File

@@ -307,7 +307,7 @@ def save_configs(
def main(args: argparse.Namespace):
print(args)
config = AutoConfig.from_pretrained(args.model)
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
@@ -323,6 +323,11 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "DeepseekV2ForCausalLM":
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral
E = config.num_local_experts