[PD] Release initial code (#4654)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Ying1123 <sqy1415@gmail.com>
Co-authored-by: merrymercy <lianminzheng@gmail.com>
Co-authored-by: makro
Co-authored-by: dhou-xai
This commit is contained in:
Byron Hsu
2025-03-21 14:47:47 -07:00
committed by GitHub
parent 417fc72f6f
commit c7c7dbebbe
10 changed files with 1410 additions and 9 deletions

View File

@@ -185,6 +185,10 @@ class ServerArgs:
debug_tensor_dump_input_file: Optional[str] = None
debug_tensor_dump_inject: bool = False
# For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
disaggregation_mode: str = "null"
# Bootstrap server port on the prefill server (exposed on the CLI as --disaggregation-bootstrap-port)
disaggregation_bootstrap_port: int = 8998
def __post_init__(self):
# Set missing default values
if self.tokenizer_path is None:
@@ -325,6 +329,18 @@ class ServerArgs:
if is_hip():
self.triton_attention_num_kv_splits = 16
# PD disaggregation: each role forces the flags it needs and logs one warning
# per overridden flag. Every warning must name the flag that was just changed
# and the role of *this* server, so operators can trust the log output.
if self.disaggregation_mode == "prefill":
    # Prefill-only server: CUDA graph and the overlap scheduler are turned off.
    self.disable_cuda_graph = True
    logger.warning("Cuda graph is disabled for prefill server")
    self.disable_overlap_schedule = True
    logger.warning("Overlap scheduler is disabled for prefill server")
elif self.disaggregation_mode == "decode":
    # Decode-only server: the radix cache and the overlap scheduler are turned off.
    self.disable_radix_cache = True
    logger.warning("KV cache is forced as chunk cache for decode server")
    self.disable_overlap_schedule = True
    logger.warning("Overlap scheduler is disabled for decode server")
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and port args
@@ -1063,6 +1079,21 @@ class ServerArgs:
help="Inject the outputs from jax as the input of every layer.",
)
# Disaggregation
# Defaults are read from the ServerArgs fields (and echoed into the help text
# through argparse's %(default)s) so the CLI cannot drift from the dataclass.
parser.add_argument(
    "--disaggregation-mode",
    type=str,
    default=ServerArgs.disaggregation_mode,
    choices=["null", "prefill", "decode"],
    help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
)
parser.add_argument(
    "--disaggregation-bootstrap-port",
    type=int,
    default=ServerArgs.disaggregation_bootstrap_port,
    help="Bootstrap server port on the prefill server. Default is %(default)s.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size