Fused MoE Triton tuning script: add Qwen3 support (#5842)
This commit is contained in:
@@ -20,6 +20,13 @@ python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
|||||||
--dtype fp8_w8a8 \
|
--dtype fp8_w8a8 \
|
||||||
--tune
|
--tune
|
||||||
|
|
||||||
|
# Tune Qwen3-235B-A22B-FP8 with FP8 and TP=4
|
||||||
|
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||||
|
--model Qwen/Qwen3-235B-A22B-FP8 \
|
||||||
|
--tp-size 4 \
|
||||||
|
--dtype fp8_w8a8 \
|
||||||
|
--tune
|
||||||
|
|
||||||
# Tune DeepSeek-V3 with FP8, TP=8 and n_share_experts_fusion=8
|
# Tune DeepSeek-V3 with FP8, TP=8 and n_share_experts_fusion=8
|
||||||
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
|
||||||
--model deepseek-ai/DeepSeek-V3-0324 \
|
--model deepseek-ai/DeepSeek-V3-0324 \
|
||||||
|
|||||||
@@ -30,10 +30,15 @@ def get_model_config(model_name: str, tp_size: int):
|
|||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
elif config.architectures[0] == "Qwen3MoeForCausalLM":
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.moe_intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
||||||
E = config.n_routed_experts
|
E = config.n_routed_experts
|
||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
elif config.architectures[0] in [
|
elif config.architectures[0] in [
|
||||||
"Grok1ForCausalLM",
|
"Grok1ForCausalLM",
|
||||||
|
|||||||
@@ -30,6 +30,11 @@ def get_model_config(model_name: str, tp_size: int):
|
|||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
elif config.architectures[0] == "Qwen3MoeForCausalLM":
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.moe_intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
||||||
E = config.n_routed_experts
|
E = config.n_routed_experts
|
||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
|
|||||||
Reference in New Issue
Block a user