[1/N]DP refactor: Improve dp rank scheduling in PD disaggregation mode. (#10169)

This commit is contained in:
Liangsheng Yin
2025-09-09 12:27:55 +08:00
committed by GitHub
parent 2fe17735a6
commit 83d55ac51f
8 changed files with 61 additions and 36 deletions

View File

@@ -44,6 +44,7 @@ from sglang.srt.utils import (
is_valid_ipv6_address,
nullable_str,
)
from sglang.utils import is_in_ci
logger = logging.getLogger(__name__)
@@ -223,6 +224,8 @@ class ServerArgs:
# Data parallelism
dp_size: int = 1
load_balance_method: str = "round_robin"
# FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
prefill_round_robin_balance: bool = False
# Multi-node distributed serving
dist_init_addr: Optional[str] = None
@@ -623,12 +626,12 @@ class ServerArgs:
if self.grammar_backend is None:
self.grammar_backend = "xgrammar"
if self.dp_size == 1:
self.enable_dp_attention = False
# Data parallelism attention
if self.enable_dp_attention:
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
assert (
self.dp_size > 1
), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
assert self.tp_size % self.dp_size == 0
self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
logger.warning(
@@ -807,6 +810,13 @@ class ServerArgs:
self.disable_radix_cache = True
logger.warning("KV cache is forced as chunk cache for decode server")
if self.dp_size > 1 and not is_in_ci():
assert self.prefill_round_robin_balance, (
"Prefill round robin balance is required when dp size > 1. "
"Please make sure that the prefill instance is launched with `--load-balance-method round_robin`"
" and `--prefill-round-robin-balance` is set for decode server."
)
elif self.disaggregation_mode == "prefill":
if self.disaggregation_decode_tp is None:
self.disaggregation_decode_tp = self.tp_size
@@ -1384,6 +1394,12 @@ class ServerArgs:
"minimum_tokens",
],
)
parser.add_argument(
"--prefill-round-robin-balance",
default=ServerArgs.prefill_round_robin_balance,
action="store_true",
help="Prefill is round robin balanced. This is used to promise decode server can get the correct dp rank.",
)
# Multi-node distributed serving
parser.add_argument(