diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 046788250..bba83a95f 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -1458,43 +1458,49 @@ def initialize_model_parallel(
         _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

     moe_ep_size = expert_model_parallel_size
-
     moe_tp_size = tensor_model_parallel_size // moe_ep_size
+
     global _MOE_EP
     assert _MOE_EP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_tp_size):
-            st = i * tensor_model_parallel_size + j
-            en = (i + 1) * tensor_model_parallel_size + j
-            ranks = list(range(st, en, moe_tp_size))
-            group_ranks.append(ranks)
-    _MOE_EP = init_model_parallel_group(
-        group_ranks,
-        get_world_group().local_rank,
-        backend,
-        use_custom_allreduce=False,
-        group_name="moe_ep",
-    )
+    if moe_ep_size == tensor_model_parallel_size:
+        _MOE_EP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_tp_size):
+                st = i * tensor_model_parallel_size + j
+                en = (i + 1) * tensor_model_parallel_size + j
+                ranks = list(range(st, en, moe_tp_size))
+                group_ranks.append(ranks)
+        _MOE_EP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_ep",
+        )

     global _MOE_TP
     assert _MOE_TP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_ep_size):
-            st = i * tensor_model_parallel_size + j * moe_tp_size
-            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
-            ranks = list(range(st, en))
-            group_ranks.append(ranks)
-    _MOE_TP = init_model_parallel_group(
-        group_ranks,
-        get_world_group().local_rank,
-        backend,
-        use_custom_allreduce=False,
-        group_name="moe_tp",
-    )
+    if moe_tp_size == tensor_model_parallel_size:
+        _MOE_TP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_ep_size):
+                st = i * tensor_model_parallel_size + j * moe_tp_size
+                en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+                ranks = list(range(st, en))
+                group_ranks.append(ranks)
+        _MOE_TP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_tp",
+        )

     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
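
Note (not part of the patch): the change reuses the existing _TP group for _MOE_EP when moe_ep_size == tensor_model_parallel_size, and for _MOE_TP when moe_tp_size == tensor_model_parallel_size, so no redundant communicator is created in those degenerate cases; it also drops use_custom_allreduce=False from both init_model_parallel_group calls. The rank-group construction in the else branches is unchanged from the removed code. Below is a minimal standalone sketch of that construction, assuming a hypothetical 8-rank run with tensor_model_parallel_size=4 and expert_model_parallel_size=2 (the configuration values are illustrative, not from the patch):

# Standalone sketch: reproduces the rank-group arithmetic from the hunk above.
world_size = 8
tensor_model_parallel_size = 4
expert_model_parallel_size = 2

num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
moe_ep_size = expert_model_parallel_size
moe_tp_size = tensor_model_parallel_size // moe_ep_size

# moe_ep groups: ranks strided by moe_tp_size within each TP group.
moe_ep_groups = []
for i in range(num_tensor_model_parallel_groups):
    for j in range(moe_tp_size):
        st = i * tensor_model_parallel_size + j
        en = (i + 1) * tensor_model_parallel_size + j
        moe_ep_groups.append(list(range(st, en, moe_tp_size)))

# moe_tp groups: contiguous slices of moe_tp_size ranks within each TP group.
moe_tp_groups = []
for i in range(num_tensor_model_parallel_groups):
    for j in range(moe_ep_size):
        st = i * tensor_model_parallel_size + j * moe_tp_size
        en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
        moe_tp_groups.append(list(range(st, en)))

print(moe_ep_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(moe_tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]

Each EP group strides across its TP group with step moe_tp_size, while each MoE-TP group is a contiguous slice of moe_tp_size ranks, so the two layouts tile every TP group exactly; when moe_ep_size equals tensor_model_parallel_size the stride is 1 and the single EP group per TP group coincides with the TP group itself, which is why the patch can alias _MOE_EP to _TP.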