[5/N] MoE Refactor: Update MoE parallelism arguments (#8658)
@@ -29,6 +29,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
@@ -117,7 +118,7 @@ class Grok1MoE(nn.Module):
         )
 
         kwargs = {}
-        if global_server_args_dict["enable_ep_moe"]:
+        if get_moe_expert_parallel_world_size() > 1:
             MoEImpl = EPMoE
         else:
             MoEImpl = FusedMoE
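
A minimal sketch of what this hunk changes, assuming the sglang import paths for EPMoE and FusedMoE (the select_moe_impl helper name is illustrative, not part of the commit): the MoE implementation is now picked from the expert-parallel world size reported by the distributed state instead of the enable_ep_moe entry in global_server_args_dict.

from sglang.srt.distributed import get_moe_expert_parallel_world_size
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE


def select_moe_impl():
    # Sketch only: mirrors the selection in the hunk above. Expert
    # parallelism is considered active whenever the MoE expert-parallel
    # group spans more than one rank; no separate server flag is consulted.
    if get_moe_expert_parallel_world_size() > 1:
        return EPMoE
    return FusedMoE
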
@@ -616,8 +617,7 @@ class Grok1ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
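
After this change the expert weight mapping is always built by FusedMoE.make_expert_params_mapping, regardless of which MoE implementation is instantiated. A rough sketch of how such a mapping is typically consumed during weight loading follows; the load_expert_weights helper, the num_experts argument, and the weight_loader call are assumptions based on the FusedMoE layer API, not code from this diff.

from sglang.srt.layers.moe.fused_moe_triton import FusedMoE


def load_expert_weights(params_dict, weights, num_experts):
    # Sketch only: same mapping call as in the hunk above, plus the expert
    # count that make_expert_params_mapping also expects.
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="w1",
        ckpt_down_proj_name="w2",
        ckpt_up_proj_name="w3",
        num_experts=num_experts,
    )
    for name, loaded_weight in weights:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            mapped_name = name.replace(weight_name, param_name)
            param = params_dict[mapped_name]
            # The FusedMoE-style weight loader receives the expert and shard
            # ids so each rank keeps only the expert slices it owns.
            param.weight_loader(
                param,
                loaded_weight,
                mapped_name,
                shard_id=shard_id,
                expert_id=expert_id,
            )
            break
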