diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index 02c64c8b3..83b19375c 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -358,6 +358,7 @@ class MHATokenToKVPoolHost(HostKVCache): dst_v=device_pool.v_buffer[layer_id], src_indices=host_indices, dst_indices=device_indices, + layer_id=layer_id, item_size=self.token_stride_size, src_layout_dim=self.layout_dim, ) @@ -585,6 +586,7 @@ class MLATokenToKVPoolHost(HostKVCache): dst=device_pool.kv_buffer[layer_id], src_indices=host_indices, dst_indices=device_indices, + layer_id=layer_id, item_size=self.token_stride_size, src_layout_dim=self.layout_dim, ) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 989ae14eb..2c3b9b767 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -250,7 +250,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer); m.def( "transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor " - "dst_indices, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"); + "dst_indices, int layer_id, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"); m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf); m.def( "transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, " @@ -267,8 +267,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "block_quota, int num_warps_per_block) -> ()"); m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla); m.def( - "transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, " - "int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"); + "transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int layer_id, " + "int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"); m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf); m.def( "transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int " diff --git a/sgl-kernel/csrc/kvcacheio/transfer.cu b/sgl-kernel/csrc/kvcacheio/transfer.cu index b79e9eb35..cbf5feeea 100644 --- a/sgl-kernel/csrc/kvcacheio/transfer.cu +++ b/sgl-kernel/csrc/kvcacheio/transfer.cu @@ -210,6 +210,7 @@ void transfer_kv_per_layer_pf_lf( at::Tensor dst_v, const at::Tensor src_indices, const at::Tensor dst_indices, + int64_t layer_id, int64_t item_size, int64_t src_layout_dim, int64_t block_quota, @@ -222,7 +223,7 @@ void transfer_kv_per_layer_pf_lf( dst_v, src_indices, dst_indices, - 0, + layer_id, 1, item_size, src_layout_dim, @@ -336,6 +337,7 @@ void transfer_kv_per_layer_mla_pf_lf( at::Tensor dst, const at::Tensor src_indices, const at::Tensor dst_indices, + int64_t layer_id, int64_t item_size, int64_t src_layout_dim, int64_t block_quota, @@ -348,7 +350,7 @@ void transfer_kv_per_layer_mla_pf_lf( empty, src_indices, dst_indices, - 0, + layer_id, 1, item_size, src_layout_dim, diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 88720dfea..15b3e2db7 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -419,6 +419,7 @@ void transfer_kv_per_layer_pf_lf( at::Tensor dst_v, const at::Tensor src_indices, const at::Tensor dst_indices, + int64_t layer_id, int64_t item_size, int64_t src_layout_dim, int64_t block_quota, @@ -463,6 +464,7 @@ void transfer_kv_per_layer_mla_pf_lf( at::Tensor dst, const at::Tensor src_indices, const at::Tensor dst_indices, + int64_t layer_id, int64_t item_size, int64_t src_layout_dim, int64_t block_quota, diff --git a/sgl-kernel/python/sgl_kernel/kvcacheio.py b/sgl-kernel/python/sgl_kernel/kvcacheio.py index 83a611dd5..fd05e8466 100644 --- a/sgl-kernel/python/sgl_kernel/kvcacheio.py +++ b/sgl-kernel/python/sgl_kernel/kvcacheio.py @@ -34,6 +34,7 @@ def transfer_kv_per_layer_pf_lf( dst_v: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, + layer_id: int, item_size: int, src_layout_dim: int, block_quota: int = 2, @@ -46,6 +47,7 @@ def transfer_kv_per_layer_pf_lf( dst_v, src_indices, dst_indices, + layer_id, item_size, src_layout_dim, block_quota, @@ -144,6 +146,7 @@ def transfer_kv_per_layer_mla_pf_lf( dst: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, + layer_id: int, item_size: int, src_layout_dim: int, block_quota: int = 2, @@ -154,6 +157,7 @@ def transfer_kv_per_layer_mla_pf_lf( dst, src_indices, dst_indices, + layer_id, item_size, src_layout_dim, block_quota,