Save memory for expert model parallel (#9957)

This commit is contained in:
Cheng Wan
2025-09-04 13:31:47 -07:00
committed by GitHub
parent d07304870b
commit 453511acc7

View File

@@ -1458,10 +1458,15 @@ def initialize_model_parallel(
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
moe_ep_size = expert_model_parallel_size moe_ep_size = expert_model_parallel_size
moe_tp_size = tensor_model_parallel_size // moe_ep_size moe_tp_size = tensor_model_parallel_size // moe_ep_size
global _MOE_EP global _MOE_EP
assert _MOE_EP is None, "expert model parallel group is already initialized" assert _MOE_EP is None, "expert model parallel group is already initialized"
if moe_ep_size == tensor_model_parallel_size:
_MOE_EP = _TP
else:
# TODO(ch-wan): use split_group to save memory
group_ranks = [] group_ranks = []
for i in range(num_tensor_model_parallel_groups): for i in range(num_tensor_model_parallel_groups):
for j in range(moe_tp_size): for j in range(moe_tp_size):
@@ -1469,17 +1474,20 @@ def initialize_model_parallel(
en = (i + 1) * tensor_model_parallel_size + j en = (i + 1) * tensor_model_parallel_size + j
ranks = list(range(st, en, moe_tp_size)) ranks = list(range(st, en, moe_tp_size))
group_ranks.append(ranks) group_ranks.append(ranks)
_MOE_EP = init_model_parallel_group( _MOE_EP = init_model_parallel_group(
group_ranks, group_ranks,
get_world_group().local_rank, get_world_group().local_rank,
backend, backend,
use_custom_allreduce=False,
group_name="moe_ep", group_name="moe_ep",
) )
global _MOE_TP global _MOE_TP
assert _MOE_TP is None, "expert model parallel group is already initialized" assert _MOE_TP is None, "expert model parallel group is already initialized"
if moe_tp_size == tensor_model_parallel_size:
_MOE_TP = _TP
else:
# TODO(ch-wan): use split_group to save memory
group_ranks = [] group_ranks = []
for i in range(num_tensor_model_parallel_groups): for i in range(num_tensor_model_parallel_groups):
for j in range(moe_ep_size): for j in range(moe_ep_size):
@@ -1487,12 +1495,10 @@ def initialize_model_parallel(
en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
ranks = list(range(st, en)) ranks = list(range(st, en))
group_ranks.append(ranks) group_ranks.append(ranks)
_MOE_TP = init_model_parallel_group( _MOE_TP = init_model_parallel_group(
group_ranks, group_ranks,
get_world_group().local_rank, get_world_group().local_rank,
backend, backend,
use_custom_allreduce=False,
group_name="moe_tp", group_name="moe_tp",
) )