From d40846d456ecc930c04538778ed11f67cc793c23 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Thu, 24 Jul 2025 17:33:17 -0700 Subject: [PATCH] breakdown kernel update (#8334) --- sgl-kernel/python/sgl_kernel/kvcacheio.py | 114 ++++++++-------------- sgl-kernel/tests/test_kvcacheio.py | 10 +- 2 files changed, 44 insertions(+), 80 deletions(-) diff --git a/sgl-kernel/python/sgl_kernel/kvcacheio.py b/sgl-kernel/python/sgl_kernel/kvcacheio.py index 1440c2ca3..83a611dd5 100644 --- a/sgl-kernel/python/sgl_kernel/kvcacheio.py +++ b/sgl-kernel/python/sgl_kernel/kvcacheio.py @@ -10,30 +10,21 @@ def transfer_kv_per_layer( dst_v: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, - io_backend: str, - page_size: int, item_size: int, block_quota: int = 2, num_warps_per_block: int = 32, ): - if io_backend == "kernel": - torch.ops.sgl_kernel.transfer_kv_per_layer( - src_k, - dst_k, - src_v, - dst_v, - src_indices, - dst_indices, - item_size * src_k.element_size(), # todo, hot fix for compatibility - block_quota, - num_warps_per_block, - ) - elif io_backend == "direct": - torch.ops.sgl_kernel.transfer_kv_direct( - [src_k, src_v], [dst_k, dst_v], src_indices, dst_indices, page_size - ) - else: - raise ValueError(f"Unsupported io backend") + torch.ops.sgl_kernel.transfer_kv_per_layer( + src_k, + dst_k, + src_v, + dst_v, + src_indices, + dst_indices, + item_size, + block_quota, + num_warps_per_block, + ) def transfer_kv_per_layer_pf_lf( @@ -69,29 +60,23 @@ def transfer_kv_all_layer( dst_v_layers: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, - io_backend: str, item_size: int, num_layers: int, block_quota: int = 2, num_warps_per_block: int = 32, ): - if io_backend == "kernel": - torch.ops.sgl_kernel.transfer_kv_all_layer( - src_k_layers, - dst_k_layers, - src_v_layers, - dst_v_layers, - src_indices, - dst_indices, - item_size, - num_layers, - block_quota, - num_warps_per_block, - ) - elif io_backend == "direct": - raise NotImplementedError("Deprecated interface") - else: - raise ValueError(f"Unsupported io backend") + torch.ops.sgl_kernel.transfer_kv_all_layer( + src_k_layers, + dst_k_layers, + src_v_layers, + dst_v_layers, + src_indices, + dst_indices, + item_size, + num_layers, + block_quota, + num_warps_per_block, + ) def transfer_kv_all_layer_lf_pf( @@ -139,28 +124,19 @@ def transfer_kv_per_layer_mla( dst: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, - io_backend: str, - page_size: int, item_size: int, block_quota: int = 2, num_warps_per_block: int = 32, ): - if io_backend == "kernel": - torch.ops.sgl_kernel.transfer_kv_per_layer_mla( - src, - dst, - src_indices, - dst_indices, - item_size * src.element_size(), # todo, hot fix for compatibility - block_quota, - num_warps_per_block, - ) - elif io_backend == "direct": - torch.ops.sgl_kernel.transfer_kv_direct( - [src], [dst], src_indices, dst_indices, page_size - ) - else: - raise ValueError(f"Unsupported io backend") + torch.ops.sgl_kernel.transfer_kv_per_layer_mla( + src, + dst, + src_indices, + dst_indices, + item_size, + block_quota, + num_warps_per_block, + ) def transfer_kv_per_layer_mla_pf_lf( @@ -190,27 +166,21 @@ def transfer_kv_all_layer_mla( dst_layers: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, - io_backend: str, item_size: int, num_layers: int, block_quota: int = 2, num_warps_per_block: int = 32, ): - if io_backend == "kernel": - torch.ops.sgl_kernel.transfer_kv_all_layer_mla( - src_layers, - dst_layers, - src_indices, - dst_indices, - item_size, - num_layers, - block_quota, - num_warps_per_block, - ) - elif io_backend == "direct": - raise NotImplementedError("Deprecated interface") - else: - raise ValueError(f"Unsupported io backend") + torch.ops.sgl_kernel.transfer_kv_all_layer_mla( + src_layers, + dst_layers, + src_indices, + dst_indices, + item_size, + num_layers, + block_quota, + num_warps_per_block, + ) def transfer_kv_all_layer_mla_lf_pf( diff --git a/sgl-kernel/tests/test_kvcacheio.py b/sgl-kernel/tests/test_kvcacheio.py index 171fc4ca4..d2b5be111 100644 --- a/sgl-kernel/tests/test_kvcacheio.py +++ b/sgl-kernel/tests/test_kvcacheio.py @@ -101,9 +101,7 @@ def test_transfer_kv( dst_pool_kernel[layer_idx_to_test], src_indices_device, dst_indices_device, - io_backend="kernel", - page_size=page_size, - item_size=item_size, + item_size=item_size * dtype.itemsize, ) transfer_kv_direct( [src_pool_host[layer_idx_to_test]], @@ -138,7 +136,6 @@ def test_transfer_kv( dst_layers_device, src_indices_device, dst_indices_device, - io_backend="kernel", item_size=item_size * dtype.itemsize, num_layers=num_layers, ) @@ -173,9 +170,7 @@ def test_transfer_kv( dst_v_pool_kernel[layer_idx_to_test], src_indices_device, dst_indices_device, - io_backend="kernel", - page_size=page_size, - item_size=item_size, + item_size=item_size * dtype.itemsize, ) transfer_kv_direct( [src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]], @@ -235,7 +230,6 @@ def test_transfer_kv( dst_v_layers_device, src_indices_device, dst_indices_device, - io_backend="kernel", item_size=item_size * dtype.itemsize, num_layers=num_layers, )