kvcache io kernels and test case (#7382)

This commit is contained in:
Zhiqiang Xie
2025-06-23 11:58:59 -07:00
committed by GitHub
parent 76139bfba0
commit 34c3f9b2d3
7 changed files with 845 additions and 0 deletions

View File

@@ -47,6 +47,12 @@ from sgl_kernel.gemm import (
shuffle_rows,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.kvcacheio import (
transfer_kv_all_layer,
transfer_kv_all_layer_mla,
transfer_kv_per_layer,
transfer_kv_per_layer_mla,
)
from sgl_kernel.moe import (
apply_shuffle_mul_sum,
cutlass_fp4_group_mm,

View File

@@ -0,0 +1,137 @@
import torch
def transfer_kv_per_layer(
src_k: torch.Tensor,
dst_k: torch.Tensor,
src_v: torch.Tensor,
dst_v: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
io_backend: str,
page_size: int,
item_size: int,
block_quota: int = 2,
num_warps_per_block: int = 32,
):
if io_backend == "kernel":
torch.ops.sgl_kernel.transfer_kv_per_layer(
src_k,
dst_k,
src_v,
dst_v,
src_indices,
dst_indices,
item_size,
block_quota,
num_warps_per_block,
)
elif io_backend == "direct":
torch.ops.sgl_kernel.transfer_kv_per_layer_direct(
src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size
)
else:
raise ValueError(f"Unsupported io backend")
def transfer_kv_all_layer(
src_k: torch.Tensor,
dst_k: torch.Tensor,
src_v: torch.Tensor,
dst_v: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
io_backend: str,
page_size: int,
item_size: int,
num_layers: int,
src_layer_offset: int,
dst_layer_offset: int,
block_quota: int = 2,
num_warps_per_block: int = 32,
):
if io_backend == "kernel":
torch.ops.sgl_kernel.transfer_kv_all_layer(
src_k,
dst_k,
src_v,
dst_v,
src_indices,
dst_indices,
item_size,
num_layers,
src_layer_offset,
dst_layer_offset,
block_quota,
num_warps_per_block,
)
elif io_backend == "direct":
torch.ops.sgl_kernel.transfer_kv_all_layer_direct(
src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers
)
else:
raise ValueError(f"Unsupported io backend")
def transfer_kv_per_layer_mla(
src: torch.Tensor,
dst: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
io_backend: str,
page_size: int,
item_size: int,
block_quota: int = 2,
num_warps_per_block: int = 32,
):
if io_backend == "kernel":
torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
src,
dst,
src_indices,
dst_indices,
item_size,
block_quota,
num_warps_per_block,
)
elif io_backend == "direct":
torch.ops.sgl_kernel.transfer_kv_per_layer_mla_direct(
src, dst, src_indices, dst_indices, page_size
)
else:
raise ValueError(f"Unsupported io backend")
def transfer_kv_all_layer_mla(
src: torch.Tensor,
dst: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
io_backend: str,
page_size: int,
item_size: int,
num_layers: int,
src_layer_offset: int,
dst_layer_offset: int,
block_quota: int = 2,
num_warps_per_block: int = 32,
):
if io_backend == "kernel":
torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
src,
dst,
src_indices,
dst_indices,
item_size,
num_layers,
src_layer_offset,
dst_layer_offset,
block_quota,
num_warps_per_block,
)
elif io_backend == "direct":
torch.ops.sgl_kernel.transfer_kv_all_layer_mla_direct(
src, dst, src_indices, dst_indices, page_size, num_layers
)
else:
raise ValueError(f"Unsupported io backend")