feat: add direct routing strategy to DP worker (#6884)
This commit is contained in:
@@ -109,10 +109,12 @@ class CommonKVReceiver(BaseKVReceiver):
|
||||
mgr: BaseKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
):
|
||||
self.bootstrap_room = bootstrap_room
|
||||
self.bootstrap_addr = bootstrap_addr
|
||||
self.kv_mgr = mgr
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
|
||||
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
||||
self.prefill_tp_size, self.prefill_dp_size = (
|
||||
@@ -180,7 +182,11 @@ class CommonKVReceiver(BaseKVReceiver):
|
||||
self.target_tp_rank = self.target_tp_ranks[0]
|
||||
self.required_dst_info_num = 1
|
||||
|
||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||
if self.data_parallel_rank is not None:
|
||||
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
||||
self.target_dp_group = self.data_parallel_rank
|
||||
else:
|
||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||
|
||||
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
||||
bootstrap_key = (
|
||||
|
||||
@@ -156,6 +156,7 @@ class DecodePreallocQueue:
|
||||
mgr=self.kv_manager,
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
||||
bootstrap_room=req.bootstrap_room,
|
||||
data_parallel_rank=req.data_parallel_rank,
|
||||
)
|
||||
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@ class FakeKVReceiver(BaseKVReceiver):
|
||||
mgr: BaseKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
):
|
||||
self.has_init = False
|
||||
|
||||
|
||||
@@ -765,6 +765,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
mgr: MooncakeKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
):
|
||||
self.bootstrap_room = bootstrap_room
|
||||
self.bootstrap_addr = bootstrap_addr
|
||||
@@ -772,6 +773,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
self.session_id = self.kv_mgr.get_session_id()
|
||||
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
||||
self.conclude_state = None
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
|
||||
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
||||
self.prefill_tp_size, self.prefill_dp_size = (
|
||||
@@ -845,7 +847,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
self.target_tp_rank = self.target_tp_ranks[0]
|
||||
self.required_dst_info_num = 1
|
||||
|
||||
self.target_dp_group = self.bootstrap_room % self.prefill_dp_size
|
||||
if self.data_parallel_rank is not None:
|
||||
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
||||
self.target_dp_group = self.data_parallel_rank
|
||||
else:
|
||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||
|
||||
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
||||
bootstrap_key = (
|
||||
|
||||
@@ -407,9 +407,10 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
mgr: NixlKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
):
|
||||
self.started_transfer = False
|
||||
super().__init__(mgr, bootstrap_addr, bootstrap_room)
|
||||
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
|
||||
|
||||
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
||||
for bootstrap_info in self.bootstrap_infos:
|
||||
|
||||
Reference in New Issue
Block a user