[PD][NIXL] Set is_sorted=False to fix NIXL_ERR_NOT_FOUND (#7330)
This commit is contained in:
@@ -159,7 +159,7 @@ class NixlKVManager(CommonKVManager):
|
|||||||
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
|
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
|
||||||
):
|
):
|
||||||
kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
|
kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
|
||||||
self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
|
self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
|
||||||
logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
|
logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
|
||||||
if not self.kv_descs:
|
if not self.kv_descs:
|
||||||
raise Exception("NIXL memory registration failed for kv tensors")
|
raise Exception("NIXL memory registration failed for kv tensors")
|
||||||
@@ -168,7 +168,7 @@ class NixlKVManager(CommonKVManager):
|
|||||||
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
||||||
):
|
):
|
||||||
aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
|
aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
|
||||||
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
|
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
|
||||||
logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
|
logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
|
||||||
if not self.aux_descs:
|
if not self.aux_descs:
|
||||||
raise Exception("NIXL memory registration failed for aux tensors")
|
raise Exception("NIXL memory registration failed for aux tensors")
|
||||||
@@ -215,8 +215,8 @@ class NixlKVManager(CommonKVManager):
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
|
f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
|
||||||
)
|
)
|
||||||
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
|
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
|
||||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
|
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
|
||||||
# Transfer data
|
# Transfer data
|
||||||
xfer_handle = self.agent.initialize_xfer(
|
xfer_handle = self.agent.initialize_xfer(
|
||||||
"WRITE",
|
"WRITE",
|
||||||
@@ -248,8 +248,8 @@ class NixlKVManager(CommonKVManager):
|
|||||||
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
|
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
|
||||||
src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
|
src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
|
||||||
dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
|
dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
|
||||||
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True)
|
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
|
||||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True)
|
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
|
||||||
# Transfer data
|
# Transfer data
|
||||||
xfer_handle = self.agent.initialize_xfer(
|
xfer_handle = self.agent.initialize_xfer(
|
||||||
"WRITE",
|
"WRITE",
|
||||||
|
|||||||
Reference in New Issue
Block a user