Save memory for expert model parallel (#9957)
@@ -1458,43 +1458,49 @@ def initialize_model_parallel(
     _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
 
     moe_ep_size = expert_model_parallel_size
 
     moe_tp_size = tensor_model_parallel_size // moe_ep_size
 
     global _MOE_EP
     assert _MOE_EP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_tp_size):
-            st = i * tensor_model_parallel_size + j
-            en = (i + 1) * tensor_model_parallel_size + j
-            ranks = list(range(st, en, moe_tp_size))
-            group_ranks.append(ranks)
-    _MOE_EP = init_model_parallel_group(
-        group_ranks,
-        get_world_group().local_rank,
-        backend,
-        use_custom_allreduce=False,
-        group_name="moe_ep",
-    )
+    if moe_ep_size == tensor_model_parallel_size:
+        _MOE_EP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_tp_size):
+                st = i * tensor_model_parallel_size + j
+                en = (i + 1) * tensor_model_parallel_size + j
+                ranks = list(range(st, en, moe_tp_size))
+                group_ranks.append(ranks)
+        _MOE_EP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_ep",
+        )
 
     global _MOE_TP
     assert _MOE_TP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_ep_size):
-            st = i * tensor_model_parallel_size + j * moe_tp_size
-            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
-            ranks = list(range(st, en))
-            group_ranks.append(ranks)
-    _MOE_TP = init_model_parallel_group(
-        group_ranks,
-        get_world_group().local_rank,
-        backend,
-        use_custom_allreduce=False,
-        group_name="moe_tp",
-    )
+    if moe_tp_size == tensor_model_parallel_size:
+        _MOE_TP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_ep_size):
+                st = i * tensor_model_parallel_size + j * moe_tp_size
+                en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+                ranks = list(range(st, en))
+                group_ranks.append(ranks)
+        _MOE_TP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_tp",
+        )
 
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
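Why the aliasing added in this hunk is safe: when moe_ep_size equals tensor_model_parallel_size, moe_tp_size becomes 1, so the strided EP rank loop degenerates to the full TP group, and building a second communicator would only duplicate _TP; reusing it is what saves the memory named in the commit title. Below is a minimal, self-contained sketch (not part of the commit) that reproduces the rank-grouping arithmetic from the hunk; the sizes world_size=8, tensor_model_parallel_size=4, expert_model_parallel_size=4 are assumed purely for illustration.

    # Sketch only: replays the loop from the "moe_ep" branch above with assumed sizes.
    world_size = 8
    tensor_model_parallel_size = 4
    expert_model_parallel_size = 4

    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    moe_ep_size = expert_model_parallel_size
    moe_tp_size = tensor_model_parallel_size // moe_ep_size  # 1 in this case

    # The existing tensor-parallel groups: contiguous blocks of ranks.
    tp_groups = [
        list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size))
        for i in range(num_tensor_model_parallel_groups)
    ]

    # Same loop as the diff: expert-parallel peers are strided by moe_tp_size.
    moe_ep_groups = []
    for i in range(num_tensor_model_parallel_groups):
        for j in range(moe_tp_size):
            st = i * tensor_model_parallel_size + j
            en = (i + 1) * tensor_model_parallel_size + j
            moe_ep_groups.append(list(range(st, en, moe_tp_size)))

    print(tp_groups)      # [[0, 1, 2, 3], [4, 5, 6, 7]]
    print(moe_ep_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]] -- identical, so _MOE_EP can alias _TP
    assert moe_ep_groups == tp_groups  # holds whenever moe_ep_size == tensor_model_parallel_size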
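In the non-degenerate case the two loops partition each TP group differently, which is why separate moe_ep and moe_tp groups are still built in the else branches. A small sketch of that layout, again illustrative only, with assumed sizes world_size=8, tensor_model_parallel_size=4, expert_model_parallel_size=2:

    # Sketch only: both rank loops from the hunk, with assumed example sizes.
    world_size = 8
    tensor_model_parallel_size = 4
    expert_model_parallel_size = 2

    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    moe_ep_size = expert_model_parallel_size
    moe_tp_size = tensor_model_parallel_size // moe_ep_size  # 2

    moe_ep_groups, moe_tp_groups = [], []
    for i in range(num_tensor_model_parallel_groups):
        # "moe_ep" branch: peers holding different experts, strided by moe_tp_size.
        for j in range(moe_tp_size):
            st = i * tensor_model_parallel_size + j
            en = (i + 1) * tensor_model_parallel_size + j
            moe_ep_groups.append(list(range(st, en, moe_tp_size)))
        # "moe_tp" branch: contiguous sub-slices of each tensor-parallel group.
        for j in range(moe_ep_size):
            st = i * tensor_model_parallel_size + j * moe_tp_size
            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
            moe_tp_groups.append(list(range(st, en)))

    print(moe_ep_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
    print(moe_tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]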