[PD] Release initial code (#4654)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Ying1123 <sqy1415@gmail.com>
Co-authored-by: merrymercy <lianminzheng@gmail.com>
Co-authored-by: makro
Co-authored-by: dhou-xai
This commit is contained in:
@@ -185,6 +185,10 @@ class ServerArgs:
|
||||
debug_tensor_dump_input_file: Optional[str] = None
|
||||
debug_tensor_dump_inject: bool = False
|
||||
|
||||
# For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
|
||||
disaggregation_mode: str = "null"
|
||||
disaggregation_bootstrap_port: int = 8998
|
||||
|
||||
def __post_init__(self):
|
||||
# Set missing default values
|
||||
if self.tokenizer_path is None:
|
||||
@@ -325,6 +329,18 @@ class ServerArgs:
|
||||
if is_hip():
|
||||
self.triton_attention_num_kv_splits = 16
|
||||
|
||||
# PD disaggregation
|
||||
# Each disaggregated role forces a pair of settings and logs one warning per
# forced setting; every warning names the setting and the role it applies to.
if self.disaggregation_mode == "prefill":
    # Prefill-only server: CUDA graph and the overlap scheduler are turned off.
    self.disable_cuda_graph = True
    logger.warning("Cuda graph is disabled for prefill server")
    self.disable_overlap_schedule = True
    logger.warning("Overlap scheduler is disabled for prefill server")
elif self.disaggregation_mode == "decode":
    # Decode-only server: the radix cache and the overlap scheduler are
    # turned off.
    self.disable_radix_cache = True
    logger.warning("KV cache is forced as chunk cache for decode server")
    self.disable_overlap_schedule = True
    logger.warning("Overlap scheduler is disabled for decode server")
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
# Model and port args
|
||||
@@ -1063,6 +1079,21 @@ class ServerArgs:
|
||||
help="Inject the outputs from jax as the input of every layer.",
|
||||
)
|
||||
|
||||
# Disaggregation
|
||||
# Both defaults are read from the ServerArgs fields so the CLI and the
# dataclass share a single source of truth.
parser.add_argument(
    "--disaggregation-mode",
    type=str,
    default=ServerArgs.disaggregation_mode,
    choices=["null", "prefill", "decode"],
    help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
)
parser.add_argument(
    "--disaggregation-bootstrap-port",
    type=int,
    default=ServerArgs.disaggregation_bootstrap_port,
    help="Bootstrap server port on the prefill server. Default is 8998.",
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
args.tp_size = args.tensor_parallel_size
|
||||
|
||||
Reference in New Issue
Block a user