Hicache IO kernel refactoring (#8264)
@@ -3,6 +3,7 @@ import torch
from sgl_kernel.kvcacheio import (
    transfer_kv_all_layer,
    transfer_kv_all_layer_mla,
    transfer_kv_direct,
    transfer_kv_per_layer,
    transfer_kv_per_layer_mla,
)
@@ -104,14 +105,12 @@ def test_transfer_kv(
            page_size=page_size,
            item_size=item_size,
        )
        transfer_kv_per_layer_mla(
            src_pool_host[layer_idx_to_test],
            dst_pool_direct[layer_idx_to_test],
        transfer_kv_direct(
            [src_pool_host[layer_idx_to_test]],
            [dst_pool_direct[layer_idx_to_test]],
            src_indices_host,
            dst_indices_device,
            io_backend="direct",
            page_size=page_size,
            item_size=item_size,
        )
    else:
        for layer_id in range(num_layers):
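For reference, the "direct" backend exercised by the new transfer_kv_direct call is essentially an indexed tensor copy over the listed per-layer pools, with no custom kernel involved. The sketch below is a conceptual reference only, not the sgl_kernel implementation; the helper name, the contiguous-page assumption, and anything beyond the list-of-layer-tensors call shape used in the hunk above are assumptions.

```python
from typing import List
import torch


def direct_copy_reference(
    src_layers: List[torch.Tensor],   # per-layer source pools, shape [num_items, ...]
    dst_layers: List[torch.Tensor],   # per-layer destination pools, same item layout
    src_indices: torch.Tensor,        # item indices to read on the source side
    dst_indices: torch.Tensor,        # item indices to write on the destination side
    page_size: int,
) -> None:
    # Assumes indices are grouped so that each run of `page_size` consecutive
    # indices addresses one contiguous page in the pool.
    for src_layer, dst_layer in zip(src_layers, dst_layers):
        for start in range(0, src_indices.numel(), page_size):
            s = int(src_indices[start].item())
            d = int(dst_indices[start].item())
            # Move one page for this layer: a plain tensor copy, no custom kernel.
            dst_layer[d : d + page_size].copy_(
                src_layer[s : s + page_size], non_blocking=True
            )
```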
@@ -121,29 +120,34 @@ def test_transfer_kv(
                src_indices_host,
                dst_indices_device,
            )
        src_layers_device = torch.tensor(
            [src_pool_host[layer_id].data_ptr() for layer_id in range(num_layers)],
            dtype=torch.uint64,
            device=device,
        )
        dst_layers_device = torch.tensor(
            [
                dst_pool_kernel[layer_id].data_ptr()
                for layer_id in range(num_layers)
            ],
            dtype=torch.uint64,
            device=device,
        )
        transfer_kv_all_layer_mla(
            src_pool_host,
            dst_pool_kernel,
            src_layers_device,
            dst_layers_device,
            src_indices_device,
            dst_indices_device,
            io_backend="kernel",
            page_size=page_size,
            item_size=item_size,
            item_size=item_size * dtype.itemsize,
            num_layers=num_layers,
            src_layer_offset=total_items_in_pool * item_size,
            dst_layer_offset=total_items_in_pool * item_size,
        )
        transfer_kv_all_layer_mla(
            src_pool_host,
            dst_pool_direct,
        transfer_kv_direct(
            [src_pool_host[layer_id] for layer_id in range(num_layers)],
            [dst_pool_direct[layer_id] for layer_id in range(num_layers)],
            src_indices_host,
            dst_indices_device,
            io_backend="direct",
            page_size=page_size,
            item_size=item_size,
            num_layers=num_layers,
            src_layer_offset=total_items_in_pool * item_size,
            dst_layer_offset=total_items_in_pool * item_size,
        )
    torch.cuda.synchronize()
    torch.testing.assert_close(dst_pool_kernel, dst_pool_ref)
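The kernel path above builds uint64 tables of raw layer base addresses (data_ptr()) so that a single all-layer launch can index every layer's pool. Below is a small helper capturing that pattern; the helper name is an assumption, while the data_ptr()/uint64 packing is exactly what the test now does inline.

```python
from typing import List
import torch


def pack_layer_ptrs(layers: List[torch.Tensor], device: str = "cuda") -> torch.Tensor:
    # The transfer kernel dereferences these raw addresses, so the backing layer
    # tensors must stay alive and unmoved for as long as the pointer table is used.
    return torch.tensor(
        [layer.data_ptr() for layer in layers],
        dtype=torch.uint64,
        device=device,
    )


# e.g. src_layers_device = pack_layer_ptrs([src_pool_host[i] for i in range(num_layers)])
```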
@@ -173,16 +177,15 @@ def test_transfer_kv(
            page_size=page_size,
            item_size=item_size,
        )
        transfer_kv_per_layer(
            src_k_pool[layer_idx_to_test],
            dst_k_pool_direct[layer_idx_to_test],
            src_v_pool[layer_idx_to_test],
            dst_v_pool_direct[layer_idx_to_test],
        transfer_kv_direct(
            [src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]],
            [
                dst_k_pool_direct[layer_idx_to_test],
                dst_v_pool_direct[layer_idx_to_test],
            ],
            src_indices_host,
            dst_indices_device,
            io_backend="direct",
            page_size=page_size,
            item_size=item_size,
        )
    else:
        for layer_id in range(num_layers):
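In this non-MLA single-layer branch, the K and V tensors are flattened into one source list and one destination list and matched positionally, so entry 0 carries K and entry 1 carries V. A usage sketch of that pairing in terms of the hypothetical direct_copy_reference helper from the earlier note, reusing the test's variable names:

```python
# Positional pairing: entry 0 copies K, entry 1 copies V for the layer under test.
direct_copy_reference(
    src_layers=[src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]],
    dst_layers=[dst_k_pool_direct[layer_idx_to_test], dst_v_pool_direct[layer_idx_to_test]],
    src_indices=src_indices_host,
    dst_indices=dst_indices_device,
    page_size=page_size,
)
```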
@@ -198,33 +201,52 @@ def test_transfer_kv(
                src_indices_host,
                dst_indices_device,
            )

        src_k_layers_device = torch.tensor(
            [src_k_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
            dtype=torch.uint64,
            device=device,
        )
        src_v_layers_device = torch.tensor(
            [src_v_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
            dtype=torch.uint64,
            device=device,
        )
        dst_k_layers_device = torch.tensor(
            [
                dst_k_pool_kernel[layer_id].data_ptr()
                for layer_id in range(num_layers)
            ],
            dtype=torch.uint64,
            device=device,
        )
        dst_v_layers_device = torch.tensor(
            [
                dst_v_pool_kernel[layer_id].data_ptr()
                for layer_id in range(num_layers)
            ],
            dtype=torch.uint64,
            device=device,
        )
        transfer_kv_all_layer(
            src_k_pool,
            dst_k_pool_kernel,
            src_v_pool,
            dst_v_pool_kernel,
            src_k_layers_device,
            dst_k_layers_device,
            src_v_layers_device,
            dst_v_layers_device,
            src_indices_device,
            dst_indices_device,
            io_backend="kernel",
            page_size=page_size,
            item_size=item_size,
            item_size=item_size * dtype.itemsize,
            num_layers=num_layers,
            src_layer_offset=total_items_in_pool * item_size,
            dst_layer_offset=total_items_in_pool * item_size,
        )
        transfer_kv_all_layer(
            src_k_pool,
            dst_k_pool_direct,
            src_v_pool,
            dst_v_pool_direct,
        transfer_kv_direct(
            [src_k_pool[layer_id] for layer_id in range(num_layers)]
            + [src_v_pool[layer_id] for layer_id in range(num_layers)],
            [dst_k_pool_direct[layer_id] for layer_id in range(num_layers)]
            + [dst_v_pool_direct[layer_id] for layer_id in range(num_layers)],
            src_indices_host,
            dst_indices_device,
            io_backend="direct",
            page_size=page_size,
            item_size=item_size,
            num_layers=num_layers,
            src_layer_offset=total_items_in_pool * item_size,
            dst_layer_offset=total_items_in_pool * item_size,
        )
    torch.cuda.synchronize()
    torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
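One detail worth noting: the refactored kernel-path calls pass item_size * dtype.itemsize where the older calls passed item_size, which reads as a switch from an element count to a byte count now that layers are addressed through raw pointers. A quick sanity sketch of that conversion; the concrete values are illustrative only:

```python
import torch

dtype = torch.bfloat16                        # 2 bytes per element (illustrative choice)
item_size = 512                               # elements per KV item in this hypothetical config
item_size_bytes = item_size * dtype.itemsize  # what the refactored kernel call would receive
assert item_size_bytes == 1024
```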