breakdown kernel update (#8334)

This commit is contained in:
Zhiqiang Xie
2025-07-24 17:33:17 -07:00
committed by GitHub
parent 145482f422
commit d40846d456
2 changed files with 44 additions and 80 deletions

View File

@@ -10,30 +10,21 @@ def transfer_kv_per_layer(
dst_v: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
io_backend: str,
page_size: int,
item_size: int,
block_quota: int = 2,
num_warps_per_block: int = 32,
):
if io_backend == "kernel":
torch.ops.sgl_kernel.transfer_kv_per_layer(
src_k,
dst_k,
src_v,
dst_v,
src_indices,
dst_indices,
item_size * src_k.element_size(), # todo, hot fix for compatibility
block_quota,
num_warps_per_block,
)
elif io_backend == "direct":
torch.ops.sgl_kernel.transfer_kv_direct(
[src_k, src_v], [dst_k, dst_v], src_indices, dst_indices, page_size
)
else:
raise ValueError(f"Unsupported io backend")
torch.ops.sgl_kernel.transfer_kv_per_layer(
src_k,
dst_k,
src_v,
dst_v,
src_indices,
dst_indices,
item_size,
block_quota,
num_warps_per_block,
)
def transfer_kv_per_layer_pf_lf(
@@ -69,29 +60,23 @@ def transfer_kv_all_layer(
dst_v_layers: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
io_backend: str,
item_size: int,
num_layers: int,
block_quota: int = 2,
num_warps_per_block: int = 32,
):
if io_backend == "kernel":
torch.ops.sgl_kernel.transfer_kv_all_layer(
src_k_layers,
dst_k_layers,
src_v_layers,
dst_v_layers,
src_indices,
dst_indices,
item_size,
num_layers,
block_quota,
num_warps_per_block,
)
elif io_backend == "direct":
raise NotImplementedError("Deprecated interface")
else:
raise ValueError(f"Unsupported io backend")
torch.ops.sgl_kernel.transfer_kv_all_layer(
src_k_layers,
dst_k_layers,
src_v_layers,
dst_v_layers,
src_indices,
dst_indices,
item_size,
num_layers,
block_quota,
num_warps_per_block,
)
def transfer_kv_all_layer_lf_pf(
@@ -139,28 +124,19 @@ def transfer_kv_per_layer_mla(
dst: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
io_backend: str,
page_size: int,
item_size: int,
block_quota: int = 2,
num_warps_per_block: int = 32,
):
if io_backend == "kernel":
torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
src,
dst,
src_indices,
dst_indices,
item_size * src.element_size(), # todo, hot fix for compatibility
block_quota,
num_warps_per_block,
)
elif io_backend == "direct":
torch.ops.sgl_kernel.transfer_kv_direct(
[src], [dst], src_indices, dst_indices, page_size
)
else:
raise ValueError(f"Unsupported io backend")
torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
src,
dst,
src_indices,
dst_indices,
item_size,
block_quota,
num_warps_per_block,
)
def transfer_kv_per_layer_mla_pf_lf(
@@ -190,27 +166,21 @@ def transfer_kv_all_layer_mla(
dst_layers: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
io_backend: str,
item_size: int,
num_layers: int,
block_quota: int = 2,
num_warps_per_block: int = 32,
):
if io_backend == "kernel":
torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
src_layers,
dst_layers,
src_indices,
dst_indices,
item_size,
num_layers,
block_quota,
num_warps_per_block,
)
elif io_backend == "direct":
raise NotImplementedError("Deprecated interface")
else:
raise ValueError(f"Unsupported io backend")
torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
src_layers,
dst_layers,
src_indices,
dst_indices,
item_size,
num_layers,
block_quota,
num_warps_per_block,
)
def transfer_kv_all_layer_mla_lf_pf(