[PD] use int32 for kv indices & get num_reserved_decode_tokens from server_args (#7214)
This commit is contained in:
@@ -44,7 +44,7 @@ class TransferInfo:
|
||||
agent_metadata: bytes
|
||||
agent_name: str
|
||||
dst_kv_ptrs: list[int]
|
||||
- dst_kv_indices: npt.NDArray[np.int64]
|
||||
+ dst_kv_indices: npt.NDArray[np.int32]
|
||||
dst_aux_ptrs: list[int]
|
||||
dst_aux_index: int
|
||||
dst_gpu_id: int
|
||||
@@ -62,7 +62,7 @@ class TransferInfo:
|
||||
agent_metadata=msg[3],
|
||||
agent_name=msg[4].decode("ascii"),
|
||||
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
||||
- dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
|
||||
+ dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
|
||||
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
|
||||
dst_aux_index=int(msg[8].decode("ascii")),
|
||||
dst_gpu_id=int(msg[9].decode("ascii")),
|
||||
@@ -162,9 +162,9 @@ class NixlKVManager(CommonKVManager):
|
||||
def send_kvcache(
|
||||
self,
|
||||
peer_name: str,
|
||||
- prefill_kv_indices: npt.NDArray[np.int64],
|
||||
+ prefill_kv_indices: npt.NDArray[np.int32],
|
||||
dst_kv_ptrs: list[int],
|
||||
- dst_kv_indices: npt.NDArray[np.int64],
|
||||
+ dst_kv_indices: npt.NDArray[np.int32],
|
||||
dst_gpu_id: int,
|
||||
notif: str,
|
||||
):
|
||||
@@ -246,7 +246,7 @@ class NixlKVManager(CommonKVManager):
|
||||
def add_transfer_request(
|
||||
self,
|
||||
bootstrap_room: int,
|
||||
- kv_indices: npt.NDArray[np.int64],
|
||||
+ kv_indices: npt.NDArray[np.int32],
|
||||
index_slice: slice,
|
||||
is_last: bool,
|
||||
chunk_id: int,
|
||||
@@ -373,7 +373,7 @@ class NixlKVSender(BaseKVSender):
|
||||
|
||||
def send(
|
||||
self,
|
||||
- kv_indices: npt.NDArray[np.int64],
|
||||
+ kv_indices: npt.NDArray[np.int32],
|
||||
):
|
||||
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
||||
self.curr_idx += len(kv_indices)
|
||||
@@ -417,7 +417,7 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
self.started_transfer = False
|
||||
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
|
||||
|
||||
- def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
||||
+ def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||
for bootstrap_info in self.bootstrap_infos:
|
||||
self.prefill_server_url = (
|
||||
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
||||
|
||||
Reference in New Issue
Block a user