[PD][NIXL] Set is_sorted=False to fix NIXL_ERR_NOT_FOUND (#7330)
This commit is contained in:
@@ -159,7 +159,7 @@ class NixlKVManager(CommonKVManager):
|
||||
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
|
||||
):
|
||||
kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
|
||||
self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
|
||||
self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
|
||||
logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
|
||||
if not self.kv_descs:
|
||||
raise Exception("NIXL memory registration failed for kv tensors")
|
||||
@@ -168,7 +168,7 @@ class NixlKVManager(CommonKVManager):
|
||||
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
||||
):
|
||||
aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
|
||||
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
|
||||
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
|
||||
logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
|
||||
if not self.aux_descs:
|
||||
raise Exception("NIXL memory registration failed for aux tensors")
|
||||
@@ -215,8 +215,8 @@ class NixlKVManager(CommonKVManager):
|
||||
logger.debug(
|
||||
f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
|
||||
)
|
||||
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
|
||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
|
||||
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
|
||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
|
||||
# Transfer data
|
||||
xfer_handle = self.agent.initialize_xfer(
|
||||
"WRITE",
|
||||
@@ -248,8 +248,8 @@ class NixlKVManager(CommonKVManager):
|
||||
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
|
||||
src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
|
||||
dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
|
||||
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True)
|
||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True)
|
||||
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
|
||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
|
||||
# Transfer data
|
||||
xfer_handle = self.agent.initialize_xfer(
|
||||
"WRITE",
|
||||
|
||||
Reference in New Issue
Block a user