[1/N]DP refactor: Improve dp rank scheduling in PD disaggregation mode. (#10169)
This commit is contained in:
@@ -128,12 +128,11 @@ class CommonKVReceiver(BaseKVReceiver):
|
|||||||
mgr: BaseKVManager,
|
mgr: BaseKVManager,
|
||||||
bootstrap_addr: str,
|
bootstrap_addr: str,
|
||||||
bootstrap_room: Optional[int] = None,
|
bootstrap_room: Optional[int] = None,
|
||||||
data_parallel_rank: Optional[int] = None,
|
prefill_dp_rank: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.bootstrap_room = bootstrap_room
|
self.bootstrap_room = bootstrap_room
|
||||||
self.bootstrap_addr = bootstrap_addr
|
self.bootstrap_addr = bootstrap_addr
|
||||||
self.kv_mgr = mgr
|
self.kv_mgr = mgr
|
||||||
self.data_parallel_rank = data_parallel_rank
|
|
||||||
|
|
||||||
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
||||||
self.prefill_tp_size, self.prefill_dp_size = (
|
self.prefill_tp_size, self.prefill_dp_size = (
|
||||||
@@ -201,11 +200,14 @@ class CommonKVReceiver(BaseKVReceiver):
|
|||||||
self.target_tp_rank = self.target_tp_ranks[0]
|
self.target_tp_rank = self.target_tp_ranks[0]
|
||||||
self.required_dst_info_num = 1
|
self.required_dst_info_num = 1
|
||||||
|
|
||||||
if self.data_parallel_rank is not None:
|
if prefill_dp_rank is not None:
|
||||||
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
|
||||||
self.target_dp_group = self.data_parallel_rank
|
self.prefill_dp_rank = prefill_dp_rank
|
||||||
else:
|
else:
|
||||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
|
||||||
|
|
||||||
|
# FIXME: alias here: target_dp_group -> prefill_dp_rank
|
||||||
|
self.target_dp_group = self.prefill_dp_rank
|
||||||
|
|
||||||
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
||||||
bootstrap_key = (
|
bootstrap_key = (
|
||||||
|
|||||||
@@ -250,7 +250,7 @@ class DecodePreallocQueue:
|
|||||||
mgr=self.kv_manager,
|
mgr=self.kv_manager,
|
||||||
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
||||||
bootstrap_room=req.bootstrap_room,
|
bootstrap_room=req.bootstrap_room,
|
||||||
data_parallel_rank=req.data_parallel_rank,
|
prefill_dp_rank=req.data_parallel_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.queue.append(
|
self.queue.append(
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class FakeKVReceiver(BaseKVReceiver):
|
|||||||
mgr: BaseKVManager,
|
mgr: BaseKVManager,
|
||||||
bootstrap_addr: str,
|
bootstrap_addr: str,
|
||||||
bootstrap_room: Optional[int] = None,
|
bootstrap_room: Optional[int] = None,
|
||||||
data_parallel_rank: Optional[int] = None,
|
prefill_dp_rank: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.has_init = False
|
self.has_init = False
|
||||||
|
|
||||||
|
|||||||
@@ -1212,7 +1212,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
mgr: MooncakeKVManager,
|
mgr: MooncakeKVManager,
|
||||||
bootstrap_addr: str,
|
bootstrap_addr: str,
|
||||||
bootstrap_room: Optional[int] = None,
|
bootstrap_room: Optional[int] = None,
|
||||||
data_parallel_rank: Optional[int] = None,
|
prefill_dp_rank: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.bootstrap_room = bootstrap_room
|
self.bootstrap_room = bootstrap_room
|
||||||
self.bootstrap_addr = bootstrap_addr
|
self.bootstrap_addr = bootstrap_addr
|
||||||
@@ -1221,7 +1221,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
||||||
self.conclude_state = None
|
self.conclude_state = None
|
||||||
self.init_time = None
|
self.init_time = None
|
||||||
self.data_parallel_rank = data_parallel_rank
|
|
||||||
|
|
||||||
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
||||||
(
|
(
|
||||||
@@ -1320,11 +1319,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
|
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
|
||||||
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
|
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
|
||||||
|
|
||||||
if self.data_parallel_rank is not None:
|
if prefill_dp_rank is not None:
|
||||||
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
|
||||||
self.target_dp_group = self.data_parallel_rank
|
self.prefill_dp_rank = prefill_dp_rank
|
||||||
else:
|
else:
|
||||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
|
||||||
|
|
||||||
|
# FIXME: alias here: target_dp_group -> prefill_dp_rank
|
||||||
|
self.target_dp_group = self.prefill_dp_rank
|
||||||
|
|
||||||
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
|
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
|
||||||
self.required_prefill_response_num
|
self.required_prefill_response_num
|
||||||
|
|||||||
@@ -454,11 +454,11 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
mgr: NixlKVManager,
|
mgr: NixlKVManager,
|
||||||
bootstrap_addr: str,
|
bootstrap_addr: str,
|
||||||
bootstrap_room: Optional[int] = None,
|
bootstrap_room: Optional[int] = None,
|
||||||
data_parallel_rank: Optional[int] = None,
|
prefill_dp_rank: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.started_transfer = False
|
self.started_transfer = False
|
||||||
self.conclude_state = None
|
self.conclude_state = None
|
||||||
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
|
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
|
||||||
|
|
||||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||||
for bootstrap_info in self.bootstrap_infos:
|
for bootstrap_info in self.bootstrap_infos:
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ class DataParallelController:
|
|||||||
|
|
||||||
# Launch data parallel workers
|
# Launch data parallel workers
|
||||||
self.scheduler_procs = []
|
self.scheduler_procs = []
|
||||||
self.workers = [None] * server_args.dp_size
|
self.workers: List[zmq.Socket] = [None] * server_args.dp_size
|
||||||
|
|
||||||
if server_args.enable_dp_attention:
|
if server_args.enable_dp_attention:
|
||||||
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
|
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
|
||||||
@@ -272,27 +272,34 @@ class DataParallelController:
|
|||||||
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
||||||
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
|
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
|
||||||
|
|
||||||
|
def maybe_external_dp_rank_routing(self, req: Req):
|
||||||
|
if req.data_parallel_rank is not None:
|
||||||
|
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
|
||||||
|
self.workers[req.data_parallel_rank].send_pyobj(req)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def round_robin_scheduler(self, req: Req):
|
def round_robin_scheduler(self, req: Req):
|
||||||
|
if self.maybe_external_dp_rank_routing(req):
|
||||||
|
return
|
||||||
|
|
||||||
if self.server_args.disaggregation_mode == "null":
|
if self.server_args.disaggregation_mode == "null":
|
||||||
if req.data_parallel_rank is not None:
|
self.workers[self.round_robin_counter].send_pyobj(req)
|
||||||
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
|
self.round_robin_counter = (self.round_robin_counter + 1) % len(
|
||||||
self.workers[req.data_parallel_rank].send_pyobj(req)
|
self.workers
|
||||||
else:
|
)
|
||||||
self.workers[self.round_robin_counter].send_pyobj(req)
|
|
||||||
self.round_robin_counter = (self.round_robin_counter + 1) % len(
|
|
||||||
self.workers
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if req.data_parallel_rank is not None:
|
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
|
||||||
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
|
|
||||||
self.workers[req.data_parallel_rank].send_pyobj(req)
|
|
||||||
else:
|
|
||||||
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
|
|
||||||
|
|
||||||
def shortest_queue_scheduler(self, input_requests):
|
def shortest_queue_scheduler(self, input_requests):
|
||||||
|
if self.maybe_external_dp_rank_routing(req):
|
||||||
|
return
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def minimum_tokens_scheduler(self, req):
|
def minimum_tokens_scheduler(self, req):
|
||||||
|
if self.maybe_external_dp_rank_routing(req):
|
||||||
|
return
|
||||||
|
|
||||||
# This variable corresponds to the balance_id in TokenizedGenerateReqInput.
|
# This variable corresponds to the balance_id in TokenizedGenerateReqInput.
|
||||||
# We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received).
|
# We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received).
|
||||||
def get_next_global_balance_id() -> int:
|
def get_next_global_balance_id() -> int:
|
||||||
|
|||||||
@@ -450,9 +450,7 @@ class MultiTokenizerManager(TokenizerManager):
|
|||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
):
|
):
|
||||||
setproctitle.setproctitle(
|
setproctitle.setproctitle(f"sglang::tokenizer_worker:{os.getpid()}")
|
||||||
f"sglang::http_server/multi_tokenizer_manager:{os.getpid()}"
|
|
||||||
)
|
|
||||||
# prevent init prefill bootstrapserver again
|
# prevent init prefill bootstrapserver again
|
||||||
disaggregation_mode = server_args.disaggregation_mode
|
disaggregation_mode = server_args.disaggregation_mode
|
||||||
server_args.disaggregation_mode = "null"
|
server_args.disaggregation_mode = "null"
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ from sglang.srt.utils import (
|
|||||||
is_valid_ipv6_address,
|
is_valid_ipv6_address,
|
||||||
nullable_str,
|
nullable_str,
|
||||||
)
|
)
|
||||||
|
from sglang.utils import is_in_ci
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -223,6 +224,8 @@ class ServerArgs:
|
|||||||
# Data parallelism
|
# Data parallelism
|
||||||
dp_size: int = 1
|
dp_size: int = 1
|
||||||
load_balance_method: str = "round_robin"
|
load_balance_method: str = "round_robin"
|
||||||
|
# FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
|
||||||
|
prefill_round_robin_balance: bool = False
|
||||||
|
|
||||||
# Multi-node distributed serving
|
# Multi-node distributed serving
|
||||||
dist_init_addr: Optional[str] = None
|
dist_init_addr: Optional[str] = None
|
||||||
@@ -623,12 +626,12 @@ class ServerArgs:
|
|||||||
if self.grammar_backend is None:
|
if self.grammar_backend is None:
|
||||||
self.grammar_backend = "xgrammar"
|
self.grammar_backend = "xgrammar"
|
||||||
|
|
||||||
|
if self.dp_size == 1:
|
||||||
|
self.enable_dp_attention = False
|
||||||
|
|
||||||
# Data parallelism attention
|
# Data parallelism attention
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
|
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
|
||||||
assert (
|
|
||||||
self.dp_size > 1
|
|
||||||
), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
|
|
||||||
assert self.tp_size % self.dp_size == 0
|
assert self.tp_size % self.dp_size == 0
|
||||||
self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
|
self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -807,6 +810,13 @@ class ServerArgs:
|
|||||||
|
|
||||||
self.disable_radix_cache = True
|
self.disable_radix_cache = True
|
||||||
logger.warning("KV cache is forced as chunk cache for decode server")
|
logger.warning("KV cache is forced as chunk cache for decode server")
|
||||||
|
|
||||||
|
if self.dp_size > 1 and not is_in_ci():
|
||||||
|
assert self.prefill_round_robin_balance, (
|
||||||
|
"Prefill round robin balance is required when dp size > 1. "
|
||||||
|
"Please make sure that the prefill instance is launched with `--load-balance-method round_robin`"
|
||||||
|
" and `--prefill-round-robin-balance` is set for decode server."
|
||||||
|
)
|
||||||
elif self.disaggregation_mode == "prefill":
|
elif self.disaggregation_mode == "prefill":
|
||||||
if self.disaggregation_decode_tp is None:
|
if self.disaggregation_decode_tp is None:
|
||||||
self.disaggregation_decode_tp = self.tp_size
|
self.disaggregation_decode_tp = self.tp_size
|
||||||
@@ -1384,6 +1394,12 @@ class ServerArgs:
|
|||||||
"minimum_tokens",
|
"minimum_tokens",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefill-round-robin-balance",
|
||||||
|
default=ServerArgs.prefill_round_robin_balance,
|
||||||
|
action="store_true",
|
||||||
|
help="Prefill is round robin balanced. This is used to promise decode server can get the correct dp rank.",
|
||||||
|
)
|
||||||
|
|
||||||
# Multi-node distributed serving
|
# Multi-node distributed serving
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user