[5/N] MoE Refactor: Update MoE parallelism arguments (#8658)

Cheng Wan
2025-08-01 01:20:03 -07:00
committed by GitHub
parent c8d3a402c1
commit 6c88f6c8d9
38 changed files with 342 additions and 299 deletions


@@ -29,6 +29,7 @@ from torch import nn
 from transformers import PretrainedConfig
 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
@@ -117,7 +118,7 @@ class Grok1MoE(nn.Module):
         )
         kwargs = {}
-        if global_server_args_dict["enable_ep_moe"]:
+        if get_moe_expert_parallel_world_size() > 1:
             MoEImpl = EPMoE
         else:
             MoEImpl = FusedMoE
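
This hunk drops the global `enable_ep_moe` flag and instead derives the choice of MoE implementation from the distributed state. A minimal sketch of the resulting pattern is below; the import paths for EPMoE and FusedMoE and the helper name are assumptions for illustration, not the repository's exact layout.

    # Sketch only: choose the MoE implementation from the expert-parallel
    # world size instead of the removed `enable_ep_moe` server argument.
    from sglang.srt.distributed import get_moe_expert_parallel_world_size
    from sglang.srt.layers.moe.ep_moe.layer import EPMoE                # assumed path
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE   # assumed path

    def select_moe_impl():
        # Expert parallelism is active whenever the EP group spans more than one rank.
        if get_moe_expert_parallel_world_size() > 1:
            return EPMoE
        return FusedMoE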
@@ -616,8 +617,7 @@ class Grok1ForCausalLM(nn.Module):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
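
With EPMoE now sharing FusedMoE's checkpoint layout, the expert mapping is built from FusedMoE unconditionally. The sketch below shows how such a mapping of (param_name, weight_name, expert_id, shard_id) tuples is typically consumed during weight loading; the surrounding loop, `params_dict`, and the `num_experts` argument are assumptions for illustration, not this file's actual code.

    # Illustrative consumption of the expert params mapping (assumed names).
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="w1",
        ckpt_down_proj_name="w2",
        ckpt_up_proj_name="w3",
        num_experts=config.num_local_experts,  # assumed config field
    )

    for name, loaded_weight in weights:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            mapped = name.replace(weight_name, param_name)
            param = params_dict[mapped]
            # Route each expert tensor to its shard via the mapping tuple.
            param.weight_loader(
                param, loaded_weight, mapped, shard_id=shard_id, expert_id=expert_id
            )
            break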