diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 3dd8975c5..73f32c0a6 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -159,7 +159,7 @@ class NixlKVManager(CommonKVManager): self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens ): kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, "")) - self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True) + self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False) logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}") if not self.kv_descs: raise Exception("NIXL memory registration failed for kv tensors") @@ -168,7 +168,7 @@ class NixlKVManager(CommonKVManager): self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens ): aux_addrs.append((aux_data_ptr, aux_data_len, 0, "")) - self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True) + self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False) logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}") if not self.aux_descs: raise Exception("NIXL memory registration failed for aux tensors") @@ -215,8 +215,8 @@ class NixlKVManager(CommonKVManager): logger.debug( f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}" ) - src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True) - dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True) + src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False) + dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False) # Transfer data xfer_handle = self.agent.initialize_xfer( "WRITE", @@ -248,8 +248,8 @@ class NixlKVManager(CommonKVManager): decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len src_addrs = [(prefill_aux_addr, aux_item_len, 0)] dst_addrs = [(decode_aux_addr, aux_item_len, 0)] - src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True) - dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True) + src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False) + dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False) # Transfer data xfer_handle = self.agent.initialize_xfer( "WRITE",