diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 35874bb18..57b0a47c4 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -433,7 +433,9 @@ class HiCacheController: if self.io_backend == "kernel": return host_indices.to(self.mem_pool_device.device), device_indices elif self.io_backend == "direct": - return host_indices, device_indices.cpu() + device_indices = device_indices.cpu() + host_indices, idx = host_indices.sort() + return host_indices, device_indices.index_select(0, idx) else: raise ValueError(f"Unsupported io backend") diff --git a/sgl-kernel/csrc/kvcacheio/transfer.cu b/sgl-kernel/csrc/kvcacheio/transfer.cu index cc6942e67..b79e9eb35 100644 --- a/sgl-kernel/csrc/kvcacheio/transfer.cu +++ b/sgl-kernel/csrc/kvcacheio/transfer.cu @@ -451,15 +451,33 @@ void transfer_kv_direct( auto src_indices_cpu = src_indices.cpu(); auto dst_indices_cpu = dst_indices.cpu(); - const int64_t num_pages = src_indices_cpu.size(0) / page_size; + const auto num_indices = src_indices_cpu.numel(); const int64_t num_layers = src_layers.size(); + int64_t* src_indices_ptr = src_indices_cpu.data_ptr(); + int64_t* dst_indices_ptr = dst_indices_cpu.data_ptr(); - for (int64_t i = 0; i < num_pages; ++i) { - auto src_index = src_indices_cpu[i * page_size].item(); - auto dst_index = dst_indices_cpu[i * page_size].item(); + int64_t start_index = 0; + int64_t end_index = 0; + + for (int64_t i = 0; i < num_indices; ++i) { + if (i < num_indices - 1) { + auto src_diff = src_indices_ptr[i + 1] - src_indices_ptr[i]; + auto dst_diff = dst_indices_ptr[i + 1] - dst_indices_ptr[i]; + + if (src_diff == 1 && dst_diff == 1) { + continue; + } + end_index = i + 1; + } else { // last batch + end_index = num_indices; + } + auto src_index = src_indices_ptr[start_index]; + auto dst_index = dst_indices_ptr[start_index]; + auto num_tokens = end_index - start_index; for (int64_t j = 0; j < num_layers; ++j) { - transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, page_size); + transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, num_tokens); } + start_index = end_index; } }