[1/N]DP refactor: Improve dp rank scheduling in PD disaggregation mode. (#10169)
This commit is contained in:
@@ -128,12 +128,11 @@ class CommonKVReceiver(BaseKVReceiver):
|
||||
mgr: BaseKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
prefill_dp_rank: Optional[int] = None,
|
||||
):
|
||||
self.bootstrap_room = bootstrap_room
|
||||
self.bootstrap_addr = bootstrap_addr
|
||||
self.kv_mgr = mgr
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
|
||||
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
||||
self.prefill_tp_size, self.prefill_dp_size = (
|
||||
@@ -201,11 +200,14 @@ class CommonKVReceiver(BaseKVReceiver):
|
||||
self.target_tp_rank = self.target_tp_ranks[0]
|
||||
self.required_dst_info_num = 1
|
||||
|
||||
if self.data_parallel_rank is not None:
|
||||
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
||||
self.target_dp_group = self.data_parallel_rank
|
||||
if prefill_dp_rank is not None:
|
||||
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
|
||||
self.prefill_dp_rank = prefill_dp_rank
|
||||
else:
|
||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||
self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
|
||||
|
||||
# FIXME: alias here: target_dp_group -> prefill_dp_rank
|
||||
self.target_dp_group = self.prefill_dp_rank
|
||||
|
||||
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
||||
bootstrap_key = (
|
||||
|
||||
@@ -250,7 +250,7 @@ class DecodePreallocQueue:
|
||||
mgr=self.kv_manager,
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
||||
bootstrap_room=req.bootstrap_room,
|
||||
data_parallel_rank=req.data_parallel_rank,
|
||||
prefill_dp_rank=req.data_parallel_rank,
|
||||
)
|
||||
|
||||
self.queue.append(
|
||||
|
||||
@@ -62,7 +62,7 @@ class FakeKVReceiver(BaseKVReceiver):
|
||||
mgr: BaseKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
prefill_dp_rank: Optional[int] = None,
|
||||
):
|
||||
self.has_init = False
|
||||
|
||||
|
||||
@@ -1212,7 +1212,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
mgr: MooncakeKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
prefill_dp_rank: Optional[int] = None,
|
||||
):
|
||||
self.bootstrap_room = bootstrap_room
|
||||
self.bootstrap_addr = bootstrap_addr
|
||||
@@ -1221,7 +1221,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
||||
self.conclude_state = None
|
||||
self.init_time = None
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
|
||||
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
||||
(
|
||||
@@ -1320,11 +1319,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
|
||||
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
|
||||
|
||||
if self.data_parallel_rank is not None:
|
||||
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
||||
self.target_dp_group = self.data_parallel_rank
|
||||
if prefill_dp_rank is not None:
|
||||
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
|
||||
self.prefill_dp_rank = prefill_dp_rank
|
||||
else:
|
||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||
self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
|
||||
|
||||
# FIXME: alias here: target_dp_group -> prefill_dp_rank
|
||||
self.target_dp_group = self.prefill_dp_rank
|
||||
|
||||
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
|
||||
self.required_prefill_response_num
|
||||
|
||||
@@ -454,11 +454,11 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
mgr: NixlKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
prefill_dp_rank: Optional[int] = None,
|
||||
):
|
||||
self.started_transfer = False
|
||||
self.conclude_state = None
|
||||
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
|
||||
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
|
||||
|
||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||
for bootstrap_info in self.bootstrap_infos:
|
||||
|
||||
@@ -106,7 +106,7 @@ class DataParallelController:
|
||||
|
||||
# Launch data parallel workers
|
||||
self.scheduler_procs = []
|
||||
self.workers = [None] * server_args.dp_size
|
||||
self.workers: List[zmq.Socket] = [None] * server_args.dp_size
|
||||
|
||||
if server_args.enable_dp_attention:
|
||||
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
|
||||
@@ -272,27 +272,34 @@ class DataParallelController:
|
||||
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
||||
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
|
||||
|
||||
def maybe_external_dp_rank_routing(self, req: Req):
|
||||
if req.data_parallel_rank is not None:
|
||||
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
|
||||
self.workers[req.data_parallel_rank].send_pyobj(req)
|
||||
return True
|
||||
return False
|
||||
|
||||
def round_robin_scheduler(self, req: Req):
|
||||
if self.maybe_external_dp_rank_routing(req):
|
||||
return
|
||||
|
||||
if self.server_args.disaggregation_mode == "null":
|
||||
if req.data_parallel_rank is not None:
|
||||
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
|
||||
self.workers[req.data_parallel_rank].send_pyobj(req)
|
||||
else:
|
||||
self.workers[self.round_robin_counter].send_pyobj(req)
|
||||
self.round_robin_counter = (self.round_robin_counter + 1) % len(
|
||||
self.workers
|
||||
)
|
||||
self.workers[self.round_robin_counter].send_pyobj(req)
|
||||
self.round_robin_counter = (self.round_robin_counter + 1) % len(
|
||||
self.workers
|
||||
)
|
||||
else:
|
||||
if req.data_parallel_rank is not None:
|
||||
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
|
||||
self.workers[req.data_parallel_rank].send_pyobj(req)
|
||||
else:
|
||||
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
|
||||
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
|
||||
|
||||
def shortest_queue_scheduler(self, input_requests):
|
||||
if self.maybe_external_dp_rank_routing(req):
|
||||
return
|
||||
raise NotImplementedError()
|
||||
|
||||
def minimum_tokens_scheduler(self, req):
|
||||
if self.maybe_external_dp_rank_routing(req):
|
||||
return
|
||||
|
||||
# This variable corresponds to the balance_id in TokenizedGenerateReqInput.
|
||||
# We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received).
|
||||
def get_next_global_balance_id() -> int:
|
||||
|
||||
@@ -450,9 +450,7 @@ class MultiTokenizerManager(TokenizerManager):
|
||||
server_args: ServerArgs,
|
||||
port_args: PortArgs,
|
||||
):
|
||||
setproctitle.setproctitle(
|
||||
f"sglang::http_server/multi_tokenizer_manager:{os.getpid()}"
|
||||
)
|
||||
setproctitle.setproctitle(f"sglang::tokenizer_worker:{os.getpid()}")
|
||||
# prevent init prefill bootstrapserver again
|
||||
disaggregation_mode = server_args.disaggregation_mode
|
||||
server_args.disaggregation_mode = "null"
|
||||
|
||||
@@ -44,6 +44,7 @@ from sglang.srt.utils import (
|
||||
is_valid_ipv6_address,
|
||||
nullable_str,
|
||||
)
|
||||
from sglang.utils import is_in_ci
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -223,6 +224,8 @@ class ServerArgs:
|
||||
# Data parallelism
|
||||
dp_size: int = 1
|
||||
load_balance_method: str = "round_robin"
|
||||
# FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
|
||||
prefill_round_robin_balance: bool = False
|
||||
|
||||
# Multi-node distributed serving
|
||||
dist_init_addr: Optional[str] = None
|
||||
@@ -623,12 +626,12 @@ class ServerArgs:
|
||||
if self.grammar_backend is None:
|
||||
self.grammar_backend = "xgrammar"
|
||||
|
||||
if self.dp_size == 1:
|
||||
self.enable_dp_attention = False
|
||||
|
||||
# Data parallelism attention
|
||||
if self.enable_dp_attention:
|
||||
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
|
||||
assert (
|
||||
self.dp_size > 1
|
||||
), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
|
||||
assert self.tp_size % self.dp_size == 0
|
||||
self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
|
||||
logger.warning(
|
||||
@@ -807,6 +810,13 @@ class ServerArgs:
|
||||
|
||||
self.disable_radix_cache = True
|
||||
logger.warning("KV cache is forced as chunk cache for decode server")
|
||||
|
||||
if self.dp_size > 1 and not is_in_ci():
|
||||
assert self.prefill_round_robin_balance, (
|
||||
"Prefill round robin balance is required when dp size > 1. "
|
||||
"Please make sure that the prefill instance is launched with `--load-balance-method round_robin`"
|
||||
" and `--prefill-round-robin-balance` is set for decode server."
|
||||
)
|
||||
elif self.disaggregation_mode == "prefill":
|
||||
if self.disaggregation_decode_tp is None:
|
||||
self.disaggregation_decode_tp = self.tp_size
|
||||
@@ -1384,6 +1394,12 @@ class ServerArgs:
|
||||
"minimum_tokens",
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefill-round-robin-balance",
|
||||
default=ServerArgs.prefill_round_robin_balance,
|
||||
action="store_true",
|
||||
help="Prefill is round robin balanced. This is used to promise decode server can get the correct dp rank.",
|
||||
)
|
||||
|
||||
# Multi-node distributed serving
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user