[1/N] DP refactor: Improve dp rank scheduling in PD disaggregation mode. (#10169)
This commit is contained in:
@@ -44,6 +44,7 @@ from sglang.srt.utils import (
|
||||
is_valid_ipv6_address,
|
||||
nullable_str,
|
||||
)
|
||||
from sglang.utils import is_in_ci
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -223,6 +224,8 @@ class ServerArgs:
|
||||
# Data parallelism
|
||||
dp_size: int = 1
|
||||
load_balance_method: str = "round_robin"
|
||||
# FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
|
||||
prefill_round_robin_balance: bool = False
|
||||
|
||||
# Multi-node distributed serving
|
||||
dist_init_addr: Optional[str] = None
|
||||
@@ -623,12 +626,12 @@ class ServerArgs:
|
||||
if self.grammar_backend is None:
|
||||
self.grammar_backend = "xgrammar"
|
||||
|
||||
if self.dp_size == 1:
|
||||
self.enable_dp_attention = False
|
||||
|
||||
# Data parallelism attention
|
||||
if self.enable_dp_attention:
|
||||
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
|
||||
assert (
|
||||
self.dp_size > 1
|
||||
), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
|
||||
assert self.tp_size % self.dp_size == 0
|
||||
self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
|
||||
logger.warning(
|
||||
@@ -807,6 +810,13 @@ class ServerArgs:
|
||||
|
||||
self.disable_radix_cache = True
|
||||
logger.warning("KV cache is forced as chunk cache for decode server")
|
||||
|
||||
if self.dp_size > 1 and not is_in_ci():
|
||||
assert self.prefill_round_robin_balance, (
|
||||
"Prefill round robin balance is required when dp size > 1. "
|
||||
"Please make sure that the prefill instance is launched with `--load-balance-method round_robin`"
|
||||
" and `--prefill-round-robin-balance` is set for decode server."
|
||||
)
|
||||
elif self.disaggregation_mode == "prefill":
|
||||
if self.disaggregation_decode_tp is None:
|
||||
self.disaggregation_decode_tp = self.tp_size
|
||||
@@ -1384,6 +1394,12 @@ class ServerArgs:
|
||||
"minimum_tokens",
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefill-round-robin-balance",
|
||||
default=ServerArgs.prefill_round_robin_balance,
|
||||
action="store_true",
|
||||
help="Prefill is round robin balanced. This is used to promise decode server can get the correct dp rank.",
|
||||
)
|
||||
|
||||
# Multi-node distributed serving
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user