From ef8ec07b2ce4c70c2a33ec5acda4ce529bc3cda4 Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Tue, 13 May 2025 06:47:01 +0800
Subject: [PATCH] Support tuning moe for llama 4 model (#6042)

---
 .../kernels/fused_moe_triton/tuning_fused_moe_triton.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
index be349e456..30ffd1d40 100644
--- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
+++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
@@ -408,6 +408,12 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+    elif config.architectures[0] == "Llama4ForConditionalGeneration":
+        n_share_fusion_experts = args.n_share_experts_fusion
+        E = config.text_config.num_local_experts + n_share_fusion_experts
+        topk = config.text_config.num_experts_per_tok
+        intermediate_size = config.text_config.intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in [
         "Grok1ForCausalLM",
         "Grok1ImgGen",
@@ -424,7 +430,7 @@ def main(args: argparse.Namespace):
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
 
-    hidden_size = config.hidden_size
+    hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
     dtype = config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a8 = args.dtype == "int8_w8a8"
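For context, a minimal sketch of what the added branch computes, outside the tuning script. Llama 4 (`Llama4ForConditionalGeneration`) nests its MoE parameters under `config.text_config` rather than at the top level, which is why the patch adds a dedicated branch and a fallback for `hidden_size`. The model id and `tp_size` below are illustrative assumptions, not part of the patch; the `text_config` attribute names are taken directly from the diff above.

```python
# Sketch only: mirrors the shape derivation in the new Llama 4 branch.
from transformers import AutoConfig

# Assumed checkpoint id and tensor-parallel degree for illustration.
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
tp_size = 8
n_share_fusion_experts = 0  # mirrors args.n_share_experts_fusion in the script

config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

# MoE parameters live under config.text_config for Llama 4.
E = config.text_config.num_local_experts + n_share_fusion_experts
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size

# hidden_size is also nested, hence the getattr fallback added by the patch.
hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size

print(E, topk, shard_intermediate_size, hidden_size)
```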