[Feature] TP Group Switching for PD-Multiplexing (#7653)
This commit is contained in:
@@ -1065,8 +1065,23 @@ def init_model_parallel_group(
|
|||||||
|
|
||||||
# Tensor-parallel group coordinator; stays None until the model parallel
# groups are initialized.
_TP: Optional[GroupCoordinator] = None

# Duplicate GroupCoordinator for prefill in PD-Multiplexing. It is only
# created when initialize_model_parallel() is called with
# duplicate_tp_group=True, and is built from the same group ranks as _TP.
_PDMUX_PREFILL_TP_GROUP: Optional[GroupCoordinator] = None

# When True, get_tp_group() returns _PDMUX_PREFILL_TP_GROUP instead of _TP.
# Toggled through set_pdmux_status().
_ENABLE_PDMUX_P_TP: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def set_pdmux_status(enable_prefill_multiplexing: bool):
    """Select which tensor-parallel group ``get_tp_group()`` hands out.

    Args:
        enable_prefill_multiplexing: When True, subsequent calls to
            ``get_tp_group()`` return the duplicated PD-Multiplexing prefill
            group; when False they return the regular TP group.
    """
    # Module-level switch read by get_tp_group() on every call.
    global _ENABLE_PDMUX_P_TP
    _ENABLE_PDMUX_P_TP = enable_prefill_multiplexing
|
||||||
|
|
||||||
|
|
||||||
def get_tp_group() -> GroupCoordinator:
    """Return the tensor-parallel group coordinator currently in effect.

    While the PD-Multiplexing prefill switch (``_ENABLE_PDMUX_P_TP``) is on,
    the duplicated prefill group is returned; otherwise the regular TP group
    is returned.

    Raises:
        AssertionError: If the group selected by the switch has not been
            initialized yet.
    """
    if _ENABLE_PDMUX_P_TP:
        # Prefill switching is active: hand out the duplicated group.
        assert (
            _PDMUX_PREFILL_TP_GROUP is not None
        ), "tensor model parallel group for PD-Multiplexing Prefill is not initialized"
        return _PDMUX_PREFILL_TP_GROUP
    assert _TP is not None, "tensor model parallel group is not initialized"
    return _TP
|
||||||
|
|
||||||
@@ -1182,6 +1197,7 @@ def initialize_model_parallel(
|
|||||||
tensor_model_parallel_size: int = 1,
|
tensor_model_parallel_size: int = 1,
|
||||||
pipeline_model_parallel_size: int = 1,
|
pipeline_model_parallel_size: int = 1,
|
||||||
backend: Optional[str] = None,
|
backend: Optional[str] = None,
|
||||||
|
duplicate_tp_group: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize model parallel groups.
|
Initialize model parallel groups.
|
||||||
@@ -1239,6 +1255,23 @@ def initialize_model_parallel(
|
|||||||
group_name="tp",
|
group_name="tp",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if duplicate_tp_group:
|
||||||
|
global _PDMUX_PREFILL_TP_GROUP
|
||||||
|
assert (
|
||||||
|
_PDMUX_PREFILL_TP_GROUP is None
|
||||||
|
), "tensor model parallel group for PD-Multiplexing Prefill is already initialized"
|
||||||
|
_PDMUX_PREFILL_TP_GROUP = init_model_parallel_group(
|
||||||
|
group_ranks,
|
||||||
|
get_world_group().local_rank,
|
||||||
|
backend,
|
||||||
|
use_message_queue_broadcaster=get_bool_env_var(
|
||||||
|
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
|
||||||
|
),
|
||||||
|
group_name="pdmux_prefill_tp",
|
||||||
|
)
|
||||||
|
_TP.pynccl_comm.disabled = False
|
||||||
|
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
|
||||||
|
|
||||||
# Build the pipeline model-parallel groups.
|
# Build the pipeline model-parallel groups.
|
||||||
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
|
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
|
||||||
global _PP
|
global _PP
|
||||||
|
|||||||
@@ -539,6 +539,7 @@ class ModelRunner:
|
|||||||
initialize_model_parallel(
|
initialize_model_parallel(
|
||||||
tensor_model_parallel_size=self.tp_size,
|
tensor_model_parallel_size=self.tp_size,
|
||||||
pipeline_model_parallel_size=self.pp_size,
|
pipeline_model_parallel_size=self.pp_size,
|
||||||
|
duplicate_tp_group=self.server_args.enable_pdmux,
|
||||||
)
|
)
|
||||||
initialize_dp_attention(
|
initialize_dp_attention(
|
||||||
enable_dp_attention=self.server_args.enable_dp_attention,
|
enable_dp_attention=self.server_args.enable_dp_attention,
|
||||||
|
|||||||
@@ -251,6 +251,10 @@ class ServerArgs:
|
|||||||
custom_weight_loader: Optional[List[str]] = None
|
custom_weight_loader: Optional[List[str]] = None
|
||||||
weight_loader_disable_mmap: bool = False
|
weight_loader_disable_mmap: bool = False
|
||||||
|
|
||||||
|
# For PD-Multiplexing
|
||||||
|
enable_pdmux: bool = False
|
||||||
|
sm_group_num: int = 3
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
if self.enable_ep_moe:
|
if self.enable_ep_moe:
|
||||||
@@ -1721,6 +1725,17 @@ class ServerArgs:
|
|||||||
default=None,
|
default=None,
|
||||||
help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
|
help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-pdmux",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable PD-Multiplexing, PD running on greenctx stream.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sm-group-num",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.sm_group_num,
|
||||||
|
help="Number of sm partition groups.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--weight-loader-disable-mmap",
|
"--weight-loader-disable-mmap",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user