feat: add direct routing strategy to DP worker (#6884)

This commit is contained in:
ishandhanani
2025-06-09 11:44:05 -07:00
committed by GitHub
parent 3465d7ae78
commit f1569876d5
12 changed files with 78 additions and 8 deletions

View File

@@ -109,10 +109,12 @@ class CommonKVReceiver(BaseKVReceiver):
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.data_parallel_rank = data_parallel_rank
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
self.prefill_tp_size, self.prefill_dp_size = (
@@ -180,7 +182,11 @@ class CommonKVReceiver(BaseKVReceiver):
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
self.target_dp_group = bootstrap_room % self.prefill_dp_size
if self.data_parallel_rank is not None:
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
self.target_dp_group = self.data_parallel_rank
else:
self.target_dp_group = bootstrap_room % self.prefill_dp_size
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = (

View File

@@ -156,6 +156,7 @@ class DecodePreallocQueue:
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
data_parallel_rank=req.data_parallel_rank,
)
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))

View File

@@ -56,6 +56,7 @@ class FakeKVReceiver(BaseKVReceiver):
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
):
self.has_init = False

View File

@@ -765,6 +765,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
mgr: MooncakeKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
@@ -772,6 +773,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
self.conclude_state = None
self.data_parallel_rank = data_parallel_rank
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
self.prefill_tp_size, self.prefill_dp_size = (
@@ -845,7 +847,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
self.target_dp_group = self.bootstrap_room % self.prefill_dp_size
if self.data_parallel_rank is not None:
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
self.target_dp_group = self.data_parallel_rank
else:
self.target_dp_group = bootstrap_room % self.prefill_dp_size
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = (

View File

@@ -407,9 +407,10 @@ class NixlKVReceiver(CommonKVReceiver):
mgr: NixlKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
):
self.started_transfer = False
super().__init__(mgr, bootstrap_addr, bootstrap_room)
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos: