Support tuning MoE for the Llama 4 model (#6042)
This commit is contained in:
@@ -408,6 +408,12 @@ def main(args: argparse.Namespace):
|
|||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
|
elif config.architectures[0] == "Llama4ForConditionalGeneration":
|
||||||
|
n_share_fusion_experts = args.n_share_experts_fusion
|
||||||
|
E = config.text_config.num_local_experts + n_share_fusion_experts
|
||||||
|
topk = config.text_config.num_experts_per_tok
|
||||||
|
intermediate_size = config.text_config.intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
elif config.architectures[0] in [
|
elif config.architectures[0] in [
|
||||||
"Grok1ForCausalLM",
|
"Grok1ForCausalLM",
|
||||||
"Grok1ImgGen",
|
"Grok1ImgGen",
|
||||||
@@ -424,7 +430,7 @@ def main(args: argparse.Namespace):
|
|||||||
intermediate_size = config.intermediate_size
|
intermediate_size = config.intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
|
|
||||||
hidden_size = config.hidden_size
|
hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
|
||||||
dtype = config.torch_dtype
|
dtype = config.torch_dtype
|
||||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||||
use_int8_w8a8 = args.dtype == "int8_w8a8"
|
use_int8_w8a8 = args.dtype == "int8_w8a8"
|
||||||
|
|||||||
Reference in New Issue
Block a user