[PP] Add pipeline parallelism (#5724)
This commit is contained in:
@@ -78,6 +78,8 @@ class ServerArgs:
|
||||
|
||||
# Other runtime options
|
||||
tp_size: int = 1
|
||||
pp_size: int = 1
|
||||
max_micro_batch_size: Optional[int] = None
|
||||
stream_interval: int = 1
|
||||
stream_output: bool = False
|
||||
random_seed: Optional[int] = None
|
||||
@@ -222,14 +224,18 @@ class ServerArgs:
|
||||
|
||||
# Set mem fraction static, which depends on the combined parallelism size (tp_size * pp_size) and the GPU memory
|
||||
if self.mem_fraction_static is None:
|
||||
if self.tp_size >= 16:
|
||||
self.mem_fraction_static = 0.79
|
||||
elif self.tp_size >= 8:
|
||||
self.mem_fraction_static = 0.81
|
||||
elif self.tp_size >= 4:
|
||||
self.mem_fraction_static = 0.85
|
||||
elif self.tp_size >= 2:
|
||||
self.mem_fraction_static = 0.87
|
||||
parallel_size = self.tp_size * self.pp_size
|
||||
if gpu_mem <= 81920:
|
||||
if parallel_size >= 16:
|
||||
self.mem_fraction_static = 0.79
|
||||
elif parallel_size >= 8:
|
||||
self.mem_fraction_static = 0.81
|
||||
elif parallel_size >= 4:
|
||||
self.mem_fraction_static = 0.85
|
||||
elif parallel_size >= 2:
|
||||
self.mem_fraction_static = 0.87
|
||||
else:
|
||||
self.mem_fraction_static = 0.88
|
||||
else:
|
||||
self.mem_fraction_static = 0.88
|
||||
if gpu_mem > 96 * 1024:
|
||||
@@ -244,6 +250,8 @@ class ServerArgs:
|
||||
if self.chunked_prefill_size is None:
|
||||
if gpu_mem is not None and gpu_mem < 25_000:
|
||||
self.chunked_prefill_size = 2048
|
||||
elif self.disaggregation_mode != "null":
|
||||
self.chunked_prefill_size = 16384
|
||||
else:
|
||||
self.chunked_prefill_size = 8192
|
||||
assert self.chunked_prefill_size % self.page_size == 0
|
||||
@@ -643,6 +651,19 @@ class ServerArgs:
|
||||
default=ServerArgs.tp_size,
|
||||
help="The tensor parallelism size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pipeline-parallel-size",
|
||||
"--pp-size",
|
||||
type=int,
|
||||
default=ServerArgs.pp_size,
|
||||
help="The pipeline parallelism size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-micro-batch-size",
|
||||
type=int,
|
||||
default=ServerArgs.max_micro_batch_size,
|
||||
help="The maximum micro batch size in pipeline parallelism.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stream-interval",
|
||||
type=int,
|
||||
@@ -1232,6 +1253,7 @@ class ServerArgs:
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
args.tp_size = args.tensor_parallel_size
|
||||
args.pp_size = args.pipeline_parallel_size
|
||||
args.dp_size = args.data_parallel_size
|
||||
args.ep_size = args.expert_parallel_size
|
||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||
@@ -1245,8 +1267,19 @@ class ServerArgs:
|
||||
|
||||
def check_server_args(self):
|
||||
assert (
|
||||
self.tp_size % self.nnodes == 0
|
||||
), "tp_size must be divisible by number of nodes"
|
||||
self.tp_size * self.pp_size
|
||||
) % self.nnodes == 0, "tp_size must be divisible by number of nodes"
|
||||
|
||||
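# FIXME: pipeline parallelism (pp) constraints — overlap schedule is forced off; speculative decoding and mixed chunked prefill are rejected.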
|
||||
if self.pp_size > 1:
|
||||
logger.warning(f"Turn off overlap scheule for pipeline parallelism.")
|
||||
self.disable_overlap_schedule = True
|
||||
assert (
|
||||
self.disable_overlap_schedule
|
||||
and self.speculative_algorithm is None
|
||||
and not self.enable_mixed_chunk
|
||||
), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."
|
||||
|
||||
assert not (
|
||||
self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
|
||||
), "multi-node data parallel is not supported unless dp attention!"
|
||||
|
||||
Reference in New Issue
Block a user