[Feature] Hybrid EP and TP (#8590)
This commit is contained in:
@@ -354,6 +354,13 @@ class GroupCoordinator:
|
||||
self.cpu_group, 1 << 22, 6
|
||||
)
|
||||
|
||||
def __repr__(self):
    """Return a one-line debug string for this group.

    The string lists, in order: ranks, rank, local_rank, use_pynccl,
    device_group, cpu_group, unique_name, world_size and rank_in_group,
    each rendered as ``name=value`` and separated by single spaces.
    """
    rank_part = f"ranks={self.ranks} rank={self.rank} local_rank={self.local_rank} use_pynccl={self.use_pynccl} "
    group_part = f"device_group={self.device_group} cpu_group={self.cpu_group} unique_name={self.unique_name} "
    size_part = f"world_size={self.world_size} rank_in_group={self.rank_in_group}"
    return rank_part + group_part + size_part
|
||||
|
||||
@property
|
||||
def first_rank(self):
|
||||
"""Return the global rank of the first process in the group"""
|
||||
@@ -1141,6 +1148,20 @@ def get_tp_group() -> GroupCoordinator:
|
||||
return _TP
|
||||
|
||||
|
||||
# MoE expert-parallel group; stays None until initialize_model_parallel() assigns it.
_MOE_EP: Optional[GroupCoordinator] = None
|
||||
# MoE tensor-parallel group; stays None until initialize_model_parallel() assigns it.
_MOE_TP: Optional[GroupCoordinator] = None
|
||||
|
||||
|
||||
def get_moe_ep_group() -> GroupCoordinator:
    """Return the module-level MoE expert-parallel group.

    Returns:
        The ``GroupCoordinator`` stored in ``_MOE_EP``.

    Raises:
        AssertionError: If ``_MOE_EP`` is still ``None``.
    """
    group = _MOE_EP
    assert group is not None, "expert model parallel group is not initialized"
    return group
|
||||
|
||||
|
||||
def get_moe_tp_group() -> GroupCoordinator:
    """Return the module-level MoE tensor-parallel group.

    Returns:
        The ``GroupCoordinator`` stored in ``_MOE_TP``.

    Raises:
        AssertionError: If ``_MOE_TP`` is still ``None``.
    """
    # The message names the tensor-parallel group explicitly so a failure here
    # is not mistaken for the expert-parallel (`_MOE_EP`) check, which
    # previously shared the exact same wording.
    assert _MOE_TP is not None, "moe tensor parallel group is not initialized"
    return _MOE_TP
|
||||
|
||||
|
||||
# kept for backward compatibility
|
||||
get_tensor_model_parallel_group = get_tp_group
|
||||
|
||||
@@ -1250,6 +1271,7 @@ def init_distributed_environment(
|
||||
|
||||
def initialize_model_parallel(
|
||||
tensor_model_parallel_size: int = 1,
|
||||
expert_model_parallel_size: int = 1,
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
backend: Optional[str] = None,
|
||||
duplicate_tp_group: bool = False,
|
||||
@@ -1327,6 +1349,45 @@ def initialize_model_parallel(
|
||||
_TP.pynccl_comm.disabled = False
|
||||
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
|
||||
|
||||
moe_ep_size = expert_model_parallel_size
|
||||
|
||||
moe_tp_size = tensor_model_parallel_size // moe_ep_size
|
||||
global _MOE_EP
|
||||
assert _MOE_EP is None, "expert model parallel group is already initialized"
|
||||
group_ranks = []
|
||||
for i in range(num_tensor_model_parallel_groups):
|
||||
for j in range(moe_tp_size):
|
||||
st = i * tensor_model_parallel_size + j
|
||||
en = (i + 1) * tensor_model_parallel_size + j
|
||||
ranks = list(range(st, en, moe_tp_size))
|
||||
group_ranks.append(ranks)
|
||||
|
||||
_MOE_EP = init_model_parallel_group(
|
||||
group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
use_custom_allreduce=False,
|
||||
group_name="moe_ep",
|
||||
)
|
||||
|
||||
global _MOE_TP
|
||||
assert _MOE_TP is None, "expert model parallel group is already initialized"
|
||||
group_ranks = []
|
||||
for i in range(num_tensor_model_parallel_groups):
|
||||
for j in range(moe_ep_size):
|
||||
st = i * tensor_model_parallel_size + j * moe_tp_size
|
||||
en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
|
||||
ranks = list(range(st, en))
|
||||
group_ranks.append(ranks)
|
||||
|
||||
_MOE_TP = init_model_parallel_group(
|
||||
group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
use_custom_allreduce=False,
|
||||
group_name="moe_tp",
|
||||
)
|
||||
|
||||
# Build the pipeline model-parallel groups.
|
||||
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
|
||||
global _PP
|
||||
@@ -1347,6 +1408,7 @@ def initialize_model_parallel(
|
||||
|
||||
def ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size: int,
|
||||
expert_model_parallel_size: int,
|
||||
pipeline_model_parallel_size: int,
|
||||
backend: Optional[str] = None,
|
||||
) -> None:
|
||||
@@ -1357,7 +1419,10 @@ def ensure_model_parallel_initialized(
|
||||
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
|
||||
if not model_parallel_is_initialized():
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size, pipeline_model_parallel_size, backend
|
||||
tensor_model_parallel_size,
|
||||
expert_model_parallel_size,
|
||||
pipeline_model_parallel_size,
|
||||
backend,
|
||||
)
|
||||
return
|
||||
|
||||
@@ -1417,6 +1482,26 @@ def get_tensor_model_parallel_rank():
|
||||
return get_tp_group().rank_in_group
|
||||
|
||||
|
||||
def get_moe_expert_parallel_world_size():
    """Return the ``world_size`` of the MoE expert-parallel group."""
    ep_group = get_moe_ep_group()
    return ep_group.world_size
|
||||
|
||||
|
||||
def get_moe_expert_parallel_rank():
    """Return this process's ``rank_in_group`` within the MoE expert-parallel group."""
    ep_group = get_moe_ep_group()
    return ep_group.rank_in_group
|
||||
|
||||
|
||||
def get_moe_tensor_parallel_world_size():
    """Return the ``world_size`` of the MoE tensor-parallel group."""
    tp_group = get_moe_tp_group()
    return tp_group.world_size
|
||||
|
||||
|
||||
def get_moe_tensor_parallel_rank():
    """Return this process's ``rank_in_group`` within the MoE tensor-parallel group."""
    tp_group = get_moe_tp_group()
    return tp_group.rank_in_group
|
||||
|
||||
|
||||
def destroy_model_parallel():
|
||||
"""Set the groups to none and destroy them."""
|
||||
global _TP
|
||||
|
||||
Reference in New Issue
Block a user