Hicache IO kernel refactoring (#8264)

This commit is contained in:
Zhiqiang Xie
2025-07-23 01:49:03 -07:00
committed by GitHub
parent 8abd3e77fe
commit b43263307f
5 changed files with 545 additions and 280 deletions

View File

@@ -3,6 +3,7 @@ import torch
from sgl_kernel.kvcacheio import (
transfer_kv_all_layer,
transfer_kv_all_layer_mla,
transfer_kv_direct,
transfer_kv_per_layer,
transfer_kv_per_layer_mla,
)
@@ -104,14 +105,12 @@ def test_transfer_kv(
page_size=page_size,
item_size=item_size,
)
transfer_kv_per_layer_mla(
src_pool_host[layer_idx_to_test],
dst_pool_direct[layer_idx_to_test],
transfer_kv_direct(
[src_pool_host[layer_idx_to_test]],
[dst_pool_direct[layer_idx_to_test]],
src_indices_host,
dst_indices_device,
io_backend="direct",
page_size=page_size,
item_size=item_size,
)
else:
for layer_id in range(num_layers):
@@ -121,29 +120,34 @@ def test_transfer_kv(
src_indices_host,
dst_indices_device,
)
src_layers_device = torch.tensor(
[src_pool_host[layer_id].data_ptr() for layer_id in range(num_layers)],
dtype=torch.uint64,
device=device,
)
dst_layers_device = torch.tensor(
[
dst_pool_kernel[layer_id].data_ptr()
for layer_id in range(num_layers)
],
dtype=torch.uint64,
device=device,
)
transfer_kv_all_layer_mla(
src_pool_host,
dst_pool_kernel,
src_layers_device,
dst_layers_device,
src_indices_device,
dst_indices_device,
io_backend="kernel",
page_size=page_size,
item_size=item_size,
item_size=item_size * dtype.itemsize,
num_layers=num_layers,
src_layer_offset=total_items_in_pool * item_size,
dst_layer_offset=total_items_in_pool * item_size,
)
transfer_kv_all_layer_mla(
src_pool_host,
dst_pool_direct,
transfer_kv_direct(
[src_pool_host[layer_id] for layer_id in range(num_layers)],
[dst_pool_direct[layer_id] for layer_id in range(num_layers)],
src_indices_host,
dst_indices_device,
io_backend="direct",
page_size=page_size,
item_size=item_size,
num_layers=num_layers,
src_layer_offset=total_items_in_pool * item_size,
dst_layer_offset=total_items_in_pool * item_size,
)
torch.cuda.synchronize()
torch.testing.assert_close(dst_pool_kernel, dst_pool_ref)
@@ -173,16 +177,15 @@ def test_transfer_kv(
page_size=page_size,
item_size=item_size,
)
transfer_kv_per_layer(
src_k_pool[layer_idx_to_test],
dst_k_pool_direct[layer_idx_to_test],
src_v_pool[layer_idx_to_test],
dst_v_pool_direct[layer_idx_to_test],
transfer_kv_direct(
[src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]],
[
dst_k_pool_direct[layer_idx_to_test],
dst_v_pool_direct[layer_idx_to_test],
],
src_indices_host,
dst_indices_device,
io_backend="direct",
page_size=page_size,
item_size=item_size,
)
else:
for layer_id in range(num_layers):
@@ -198,33 +201,52 @@ def test_transfer_kv(
src_indices_host,
dst_indices_device,
)
src_k_layers_device = torch.tensor(
[src_k_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
dtype=torch.uint64,
device=device,
)
src_v_layers_device = torch.tensor(
[src_v_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
dtype=torch.uint64,
device=device,
)
dst_k_layers_device = torch.tensor(
[
dst_k_pool_kernel[layer_id].data_ptr()
for layer_id in range(num_layers)
],
dtype=torch.uint64,
device=device,
)
dst_v_layers_device = torch.tensor(
[
dst_v_pool_kernel[layer_id].data_ptr()
for layer_id in range(num_layers)
],
dtype=torch.uint64,
device=device,
)
transfer_kv_all_layer(
src_k_pool,
dst_k_pool_kernel,
src_v_pool,
dst_v_pool_kernel,
src_k_layers_device,
dst_k_layers_device,
src_v_layers_device,
dst_v_layers_device,
src_indices_device,
dst_indices_device,
io_backend="kernel",
page_size=page_size,
item_size=item_size,
item_size=item_size * dtype.itemsize,
num_layers=num_layers,
src_layer_offset=total_items_in_pool * item_size,
dst_layer_offset=total_items_in_pool * item_size,
)
transfer_kv_all_layer(
src_k_pool,
dst_k_pool_direct,
src_v_pool,
dst_v_pool_direct,
transfer_kv_direct(
[src_k_pool[layer_id] for layer_id in range(num_layers)]
+ [src_v_pool[layer_id] for layer_id in range(num_layers)],
[dst_k_pool_direct[layer_id] for layer_id in range(num_layers)]
+ [dst_v_pool_direct[layer_id] for layer_id in range(num_layers)],
src_indices_host,
dst_indices_device,
io_backend="direct",
page_size=page_size,
item_size=item_size,
num_layers=num_layers,
src_layer_offset=total_items_in_pool * item_size,
dst_layer_offset=total_items_in_pool * item_size,
)
torch.cuda.synchronize()
torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)