[Feature]TP Group Switching for PD-Multiplexing (#7653)
This commit is contained in:
@@ -1065,8 +1065,23 @@ def init_model_parallel_group(
|
||||
|
||||
# Tensor model parallel group; None until it has been initialized.
_TP: Optional[GroupCoordinator] = None

# Second tensor-parallel GroupCoordinator, used for prefill in PD-Multiplexing.
# It stays None unless initialize_model_parallel() is asked to duplicate the
# TP group (duplicate_tp_group=True).
_PDMUX_PREFILL_TP_GROUP: Optional[GroupCoordinator] = None

# When True, get_tp_group() hands out _PDMUX_PREFILL_TP_GROUP instead of _TP.
# Toggled at runtime through set_pdmux_status().
_ENABLE_PDMUX_P_TP: bool = False
|
||||
|
||||
|
||||
def set_pdmux_status(enable_prefill_multiplexing: bool):
    """Select which tensor-parallel group get_tp_group() returns.

    Args:
        enable_prefill_multiplexing: If true, subsequent get_tp_group() calls
            return the PD-Multiplexing prefill TP group; otherwise they
            return the regular TP group.
    """
    # The flag lives at module level so that every caller of get_tp_group()
    # observes the switch without having to pass it around.
    global _ENABLE_PDMUX_P_TP
    _ENABLE_PDMUX_P_TP = enable_prefill_multiplexing
|
||||
|
||||
|
||||
def get_tp_group() -> GroupCoordinator:
    """Return the tensor-parallel group that is currently selected.

    While the PD-Multiplexing prefill flag (_ENABLE_PDMUX_P_TP, toggled via
    set_pdmux_status()) is set, the dedicated prefill group
    _PDMUX_PREFILL_TP_GROUP is returned; otherwise the regular _TP group is.

    Returns:
        The selected GroupCoordinator.

    Raises:
        AssertionError: If the selected group has not been initialized.
    """
    if _ENABLE_PDMUX_P_TP:
        # Prefill multiplexing is active: the duplicated group must exist.
        assert (
            _PDMUX_PREFILL_TP_GROUP is not None
        ), "tensor model parallel group for PD-Multiplexing Prefill is not initialized"
        return _PDMUX_PREFILL_TP_GROUP
    assert _TP is not None, "tensor model parallel group is not initialized"
    return _TP
|
||||
|
||||
@@ -1182,6 +1197,7 @@ def initialize_model_parallel(
|
||||
tensor_model_parallel_size: int = 1,
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
backend: Optional[str] = None,
|
||||
duplicate_tp_group: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize model parallel groups.
|
||||
@@ -1239,6 +1255,23 @@ def initialize_model_parallel(
|
||||
group_name="tp",
|
||||
)
|
||||
|
||||
if duplicate_tp_group:
|
||||
global _PDMUX_PREFILL_TP_GROUP
|
||||
assert (
|
||||
_PDMUX_PREFILL_TP_GROUP is None
|
||||
), "tensor model parallel group for PD-Multiplexing Prefill is already initialized"
|
||||
_PDMUX_PREFILL_TP_GROUP = init_model_parallel_group(
|
||||
group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
use_message_queue_broadcaster=get_bool_env_var(
|
||||
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
|
||||
),
|
||||
group_name="pdmux_prefill_tp",
|
||||
)
|
||||
_TP.pynccl_comm.disabled = False
|
||||
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
|
||||
|
||||
# Build the pipeline model-parallel groups.
|
||||
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
|
||||
global _PP
|
||||
|
||||
Reference in New Issue
Block a user