[PD] Update prefill.py (#7190)
This commit is contained in:
@@ -227,6 +227,9 @@ class ServerArgs:
|
||||
disaggregation_mode: str = "null"
|
||||
disaggregation_transfer_backend: str = "mooncake"
|
||||
disaggregation_bootstrap_port: int = 8998
|
||||
disaggregation_decode_tp: Optional[int] = None
|
||||
disaggregation_decode_dp: Optional[int] = None
|
||||
disaggregation_prefill_pp: Optional[int] = 1
|
||||
disaggregation_ib_device: Optional[str] = None
|
||||
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
||||
pdlb_url: Optional[str] = None
|
||||
@@ -505,12 +508,27 @@ class ServerArgs:
|
||||
self.triton_attention_num_kv_splits = 16
|
||||
|
||||
# PD disaggregation
|
||||
if self.disaggregation_mode == "prefill":
|
||||
self.disable_cuda_graph = True
|
||||
logger.warning("Cuda graph is disabled for prefill server")
|
||||
elif self.disaggregation_mode == "decode":
|
||||
if self.disaggregation_mode == "decode":
|
||||
assert (
|
||||
self.disaggregation_decode_tp is None
|
||||
), "Cannot set --disaggregation-decode-tp for the decode engine."
|
||||
assert (
|
||||
self.disaggregation_decode_dp is None
|
||||
), "Cannot set --disaggregation-decode-dp for the decode engine."
|
||||
|
||||
self.disable_radix_cache = True
|
||||
logger.warning("KV cache is forced as chunk cache for decode server")
|
||||
elif self.disaggregation_mode == "prefill":
|
||||
if self.disaggregation_decode_tp is None:
|
||||
self.disaggregation_decode_tp = self.tp_size
|
||||
if self.disaggregation_decode_dp is None:
|
||||
self.disaggregation_decode_dp = self.dp_size
|
||||
|
||||
self.disaggregation_prefill_pp = self.pp_size
|
||||
self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp)
|
||||
|
||||
self.disable_cuda_graph = True
|
||||
logger.warning("Cuda graph is disabled for prefill server")
|
||||
|
||||
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
|
||||
"1" if self.enable_torch_compile else "0"
|
||||
@@ -520,6 +538,14 @@ class ServerArgs:
|
||||
"1" if self.disable_outlines_disk_cache else "0"
|
||||
)
|
||||
|
||||
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
    """Check that the prefill and decode tp sizes are compatible.

    The two sizes may differ only when the larger one is an exact multiple
    of the smaller one.

    Args:
        prefill_tp: Tp size of the prefill engine.
        decode_tp: Tp size of the decode engine.

    Raises:
        AssertionError: If either size is not positive, or if neither size
            is a multiple of the other.
    """
    # Reject zero/negative sizes up front: zero would otherwise surface as a
    # bare ZeroDivisionError from the modulo below, and a negative size could
    # slip through the divisibility check (e.g. 4 % -2 == 0).
    assert prefill_tp > 0 and decode_tp > 0, (
        "Tp sizes must be positive. "
        f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
    )
    larger_tp = max(decode_tp, prefill_tp)
    smaller_tp = min(decode_tp, prefill_tp)
    assert larger_tp % smaller_tp == 0, (
        "Different tp size is supported only when one tp is multiple of the other. "
        f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
    )
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
# Model and port args
|
||||
@@ -1512,6 +1538,24 @@ class ServerArgs:
|
||||
default=ServerArgs.disaggregation_bootstrap_port,
|
||||
help="Bootstrap server port on the prefill server. Default is 8998.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-decode-tp",
|
||||
type=int,
|
||||
default=ServerArgs.disaggregation_decode_tp,
|
||||
help="Decode tp size. If not set, it matches the tp size of the current engine. This is only set on the prefill server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-decode-dp",
|
||||
type=int,
|
||||
default=ServerArgs.disaggregation_decode_dp,
|
||||
help="Decode dp size. If not set, it matches the dp size of the current engine. This is only set on the prefill server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-prefill-pp",
|
||||
type=int,
|
||||
default=ServerArgs.disaggregation_prefill_pp,
|
||||
help="Prefill pp size. If not set, it is default to 1. This is only set on the decode server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-ib-device",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user