[PP] Add pipeline parallelism (#5724)

This commit is contained in:
Ying Sheng
2025-04-30 18:18:07 -07:00
committed by GitHub
parent e97e57e699
commit 11383cec3c
25 changed files with 1150 additions and 308 deletions

View File

@@ -78,6 +78,8 @@ class ServerArgs:
# Other runtime options
tp_size: int = 1
pp_size: int = 1
max_micro_batch_size: Optional[int] = None
stream_interval: int = 1
stream_output: bool = False
random_seed: Optional[int] = None
@@ -222,14 +224,18 @@ class ServerArgs:
# Set mem fraction static, which depends on the tensor parallelism size
if self.mem_fraction_static is None:
if self.tp_size >= 16:
self.mem_fraction_static = 0.79
elif self.tp_size >= 8:
self.mem_fraction_static = 0.81
elif self.tp_size >= 4:
self.mem_fraction_static = 0.85
elif self.tp_size >= 2:
self.mem_fraction_static = 0.87
parallel_size = self.tp_size * self.pp_size
if gpu_mem <= 81920:
if parallel_size >= 16:
self.mem_fraction_static = 0.79
elif parallel_size >= 8:
self.mem_fraction_static = 0.81
elif parallel_size >= 4:
self.mem_fraction_static = 0.85
elif parallel_size >= 2:
self.mem_fraction_static = 0.87
else:
self.mem_fraction_static = 0.88
else:
self.mem_fraction_static = 0.88
if gpu_mem > 96 * 1024:
@@ -244,6 +250,8 @@ class ServerArgs:
if self.chunked_prefill_size is None:
if gpu_mem is not None and gpu_mem < 25_000:
self.chunked_prefill_size = 2048
elif self.disaggregation_mode != "null":
self.chunked_prefill_size = 16384
else:
self.chunked_prefill_size = 8192
assert self.chunked_prefill_size % self.page_size == 0
@@ -643,6 +651,19 @@ class ServerArgs:
default=ServerArgs.tp_size,
help="The tensor parallelism size.",
)
parser.add_argument(
"--pipeline-parallel-size",
"--pp-size",
type=int,
default=ServerArgs.pp_size,
help="The pipeline parallelism size.",
)
parser.add_argument(
"--max-micro-batch-size",
type=int,
default=ServerArgs.max_micro_batch_size,
help="The maximum micro batch size in pipeline parallelism.",
)
parser.add_argument(
"--stream-interval",
type=int,
@@ -1232,6 +1253,7 @@ class ServerArgs:
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
args.pp_size = args.pipeline_parallel_size
args.dp_size = args.data_parallel_size
args.ep_size = args.expert_parallel_size
attrs = [attr.name for attr in dataclasses.fields(cls)]
@@ -1245,8 +1267,19 @@ class ServerArgs:
def check_server_args(self):
assert (
self.tp_size % self.nnodes == 0
), "tp_size must be divisible by number of nodes"
self.tp_size * self.pp_size
) % self.nnodes == 0, "tp_size multiplied by pp_size must be divisible by number of nodes"
# FIXME pp constraints
if self.pp_size > 1:
logger.warning("Turn off overlap schedule for pipeline parallelism.")
self.disable_overlap_schedule = True
assert (
self.disable_overlap_schedule
and self.speculative_algorithm is None
and not self.enable_mixed_chunk
), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."
assert not (
self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
), "multi-node data parallel is not supported unless dp attention!"