[PD] use int32 for kv indices & get num_reserved_decode_tokens from server_args (#7214)

This commit is contained in:
Byron Hsu
2025-06-15 11:51:03 -07:00
committed by GitHub
parent fff10809bf
commit 88f9c347b2
8 changed files with 24 additions and 26 deletions

View File

@@ -44,7 +44,7 @@ class TransferInfo:
agent_metadata: bytes
agent_name: str
dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int64]
dst_kv_indices: npt.NDArray[np.int32]
dst_aux_ptrs: list[int]
dst_aux_index: int
dst_gpu_id: int
@@ -62,7 +62,7 @@ class TransferInfo:
agent_metadata=msg[3],
agent_name=msg[4].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
dst_aux_index=int(msg[8].decode("ascii")),
dst_gpu_id=int(msg[9].decode("ascii")),
@@ -162,9 +162,9 @@ class NixlKVManager(CommonKVManager):
def send_kvcache(
self,
peer_name: str,
prefill_kv_indices: npt.NDArray[np.int64],
prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64],
dst_kv_indices: npt.NDArray[np.int32],
dst_gpu_id: int,
notif: str,
):
@@ -246,7 +246,7 @@ class NixlKVManager(CommonKVManager):
def add_transfer_request(
self,
bootstrap_room: int,
kv_indices: npt.NDArray[np.int64],
kv_indices: npt.NDArray[np.int32],
index_slice: slice,
is_last: bool,
chunk_id: int,
@@ -373,7 +373,7 @@ class NixlKVSender(BaseKVSender):
def send(
self,
kv_indices: npt.NDArray[np.int64],
kv_indices: npt.NDArray[np.int32],
):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices)
@@ -417,7 +417,7 @@ class NixlKVReceiver(CommonKVReceiver):
self.started_transfer = False
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"