[PD] Update prefill.py (#7190)

This commit is contained in:
Byron Hsu
2025-06-14 15:59:54 -07:00
committed by GitHub
parent ab1a4fa5cb
commit 7d316991b2
11 changed files with 458 additions and 245 deletions

View File

@@ -227,6 +227,9 @@ class ServerArgs:
disaggregation_mode: str = "null"
disaggregation_transfer_backend: str = "mooncake"
disaggregation_bootstrap_port: int = 8998
disaggregation_decode_tp: Optional[int] = None
disaggregation_decode_dp: Optional[int] = None
disaggregation_prefill_pp: Optional[int] = 1
disaggregation_ib_device: Optional[str] = None
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
pdlb_url: Optional[str] = None
@@ -505,12 +508,27 @@ class ServerArgs:
self.triton_attention_num_kv_splits = 16
# PD disaggregation
if self.disaggregation_mode == "prefill":
self.disable_cuda_graph = True
logger.warning("Cuda graph is disabled for prefill server")
elif self.disaggregation_mode == "decode":
if self.disaggregation_mode == "decode":
assert (
self.disaggregation_decode_tp is None
), "Cannot set --disaggregation-decode-tp for the decode engine."
assert (
self.disaggregation_decode_dp is None
), "Cannot set --disaggregation-decode-dp for the decode engine."
self.disable_radix_cache = True
logger.warning("KV cache is forced as chunk cache for decode server")
elif self.disaggregation_mode == "prefill":
if self.disaggregation_decode_tp is None:
self.disaggregation_decode_tp = self.tp_size
if self.disaggregation_decode_dp is None:
self.disaggregation_decode_dp = self.dp_size
self.disaggregation_prefill_pp = self.pp_size
self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp)
self.disable_cuda_graph = True
logger.warning("Cuda graph is disabled for prefill server")
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
"1" if self.enable_torch_compile else "0"
@@ -520,6 +538,14 @@ class ServerArgs:
"1" if self.disable_outlines_disk_cache else "0"
)
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
    """Check that the prefill and decode tp sizes are compatible.

    The two sizes may differ, but only when the larger one is an exact
    multiple of the smaller one.

    Args:
        prefill_tp: tp size of the prefill side.
        decode_tp: tp size of the decode side.

    Raises:
        AssertionError: if neither size evenly divides the other.
    """
    # Order the pair once so the divisibility check is independent of
    # which side has the larger tp size.
    low_tp, high_tp = sorted((prefill_tp, decode_tp))
    assert high_tp % low_tp == 0, (
        "Different tp size is supported only when one tp is multiple of the other. "
        f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
    )
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and port args
@@ -1512,6 +1538,24 @@ class ServerArgs:
default=ServerArgs.disaggregation_bootstrap_port,
help="Bootstrap server port on the prefill server. Default is 8998.",
)
parser.add_argument(
"--disaggregation-decode-tp",
type=int,
default=ServerArgs.disaggregation_decode_tp,
help="Decode tp size. If not set, it matches the tp size of the current engine. This is only set on the prefill server.",
)
parser.add_argument(
"--disaggregation-decode-dp",
type=int,
default=ServerArgs.disaggregation_decode_dp,
help="Decode dp size. If not set, it matches the dp size of the current engine. This is only set on the prefill server.",
)
parser.add_argument(
"--disaggregation-prefill-pp",
type=int,
default=ServerArgs.disaggregation_prefill_pp,
help="Prefill pp size. If not set, it is default to 1. This is only set on the decode server.",
)
parser.add_argument(
"--disaggregation-ib-device",
type=str,