2025-07-23 01:49:03 -07:00
|
|
|
from typing import List
|
|
|
|
|
|
2025-06-23 11:58:59 -07:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
2025-08-28 15:27:07 -07:00
|
|
|
def is_hip() -> bool:
    """Report whether the installed PyTorch build targets HIP.

    ``torch.version.hip`` holds a version string on HIP builds and ``None``
    otherwise, so a non-``None`` value means HIP.
    """
    hip_version = torch.version.hip
    return not (hip_version is None)


# Evaluated once at import time; the kernel wrappers below use it to pick
# their default ``num_warps_per_block`` (16 on HIP, 32 otherwise).
_is_hip = is_hip()
|
|
|
|
|
|
|
|
|
|
|
2025-11-04 20:27:27 +08:00
|
|
|
def dcu_alloc_decode_kernel(
    seq_lens_ptr: torch.Tensor,
    last_loc_ptr: torch.Tensor,
    free_page_ptr: torch.Tensor,
    out_indices: torch.Tensor,
    bs: int,
    bs_upper: int,
    page_size: int,
):
    """Invoke the ``sgl_kernel.dcu_alloc_decode_kernel`` custom op.

    All seven arguments are passed through positionally, in declaration
    order. The op's result is not returned, so this wrapper returns None.
    Any output is presumably written into ``out_indices`` in place; confirm
    this against the kernel implementation.
    """
    kernel = torch.ops.sgl_kernel.dcu_alloc_decode_kernel
    kernel(
        seq_lens_ptr, last_loc_ptr, free_page_ptr, out_indices,
        bs, bs_upper, page_size,
    )
|
|
|
|
|
|
2025-06-23 11:58:59 -07:00
|
|
|
def transfer_kv_per_layer(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke the ``sgl_kernel.transfer_kv_per_layer`` custom op.

    Arguments are forwarded positionally and unchanged. The default
    ``num_warps_per_block`` is 16 on HIP builds and 32 otherwise, fixed at
    import time. The op's result is discarded, so this returns None.
    Presumably the op copies K/V entries selected by ``src_indices`` into
    ``dst_indices``; verify against the kernel.
    """
    kernel = torch.ops.sgl_kernel.transfer_kv_per_layer
    kernel(
        src_k, dst_k, src_v, dst_v,
        src_indices, dst_indices,
        item_size, block_quota, num_warps_per_block,
    )
|
2025-06-23 11:58:59 -07:00
|
|
|
|
|
|
|
|
|
2025-07-23 01:49:03 -07:00
|
|
|
def transfer_kv_per_layer_pf_lf(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke the ``sgl_kernel.transfer_kv_per_layer_pf_lf`` custom op.

    This is the per-layer K/V transfer variant that additionally takes
    ``layer_id`` and ``src_layout_dim``. All arguments are forwarded
    positionally, and the op's result is discarded (returns None).
    NOTE(review): "pf_lf" likely denotes a source -> destination layout
    conversion (page-first to layer-first?); confirm against the kernel.
    """
    kernel = torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf
    kernel(
        src_k, dst_k, src_v, dst_v,
        src_indices, dst_indices,
        layer_id, item_size, src_layout_dim,
        block_quota, num_warps_per_block,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def transfer_kv_all_layer(
    src_k_layers: torch.Tensor,
    dst_k_layers: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke the ``sgl_kernel.transfer_kv_all_layer`` custom op.

    This is the multi-layer counterpart of ``transfer_kv_per_layer``: it
    takes ``*_layers`` tensors plus ``num_layers``. Arguments are forwarded
    positionally, and the op's result is discarded (returns None). The
    layout of the ``*_layers`` tensors (e.g. stacked per layer, or holding
    per-layer pointers) is not visible here; TODO confirm.
    """
    kernel = torch.ops.sgl_kernel.transfer_kv_all_layer
    kernel(
        src_k_layers, dst_k_layers, src_v_layers, dst_v_layers,
        src_indices, dst_indices,
        item_size, num_layers,
        block_quota, num_warps_per_block,
    )
|
2025-06-23 11:58:59 -07:00
|
|
|
|
|
|
|
|
|
2025-07-23 01:49:03 -07:00
|
|
|
def transfer_kv_all_layer_lf_pf(
    src_k_layers: torch.Tensor,
    dst_k: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke the ``sgl_kernel.transfer_kv_all_layer_lf_pf`` custom op.

    The source is given per layer (``src_k_layers`` / ``src_v_layers``) and
    the destination as single tensors (``dst_k`` / ``dst_v``) described by
    ``dst_layout_dim``. Arguments are forwarded positionally, and the op's
    result is discarded (returns None).
    NOTE(review): "lf_pf" presumably means layer-first to page-first; confirm.
    """
    kernel = torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf
    kernel(
        src_k_layers, dst_k, src_v_layers, dst_v,
        src_indices, dst_indices,
        item_size, dst_layout_dim, num_layers,
        block_quota, num_warps_per_block,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def transfer_kv_direct(
    src_layers: List[torch.Tensor],
    dst_layers: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    page_size: int,
):
    """Invoke the ``sgl_kernel.transfer_kv_direct`` custom op.

    Unlike the block-quota/warp-tuned transfers, this variant takes lists of
    per-layer tensors and a ``page_size``, with no ``item_size`` or launch
    tuning knobs. Arguments are forwarded positionally, and the op's result
    is discarded (returns None).
    """
    kernel = torch.ops.sgl_kernel.transfer_kv_direct
    kernel(src_layers, dst_layers, src_indices, dst_indices, page_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def transfer_kv_per_layer_direct_pf_lf(
    src_ptrs: List[torch.Tensor],
    dst_ptrs: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    page_size: int,
):
    """Invoke the ``sgl_kernel.transfer_kv_per_layer_direct_pf_lf`` custom op.

    This is the direct (page-size based) per-layer transfer that also takes
    ``layer_id``. Arguments are forwarded positionally, and the op's result
    is discarded (returns None). Whether ``src_ptrs`` / ``dst_ptrs`` hold
    data tensors or pointer tensors is not visible here; confirm.
    """
    kernel = torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf
    kernel(src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def transfer_kv_all_layer_direct_lf_pf(
    src_ptrs: List[torch.Tensor],
    dst_ptrs: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    page_size: int,
):
    """Invoke the ``sgl_kernel.transfer_kv_all_layer_direct_lf_pf`` custom op.

    This is the direct (page-size based) multi-layer transfer. Arguments are
    forwarded positionally, and the op's result is discarded (returns None).
    """
    kernel = torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf
    kernel(src_ptrs, dst_ptrs, src_indices, dst_indices, page_size)
|
|
|
|
|
|
|
|
|
|
|
2025-06-23 11:58:59 -07:00
|
|
|
def transfer_kv_per_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke the ``sgl_kernel.transfer_kv_per_layer_mla`` custom op.

    The MLA variant takes a single ``src`` / ``dst`` pair instead of
    separate K and V tensors. Arguments are forwarded positionally, and the
    op's result is discarded (returns None). The default
    ``num_warps_per_block`` is 16 on HIP builds and 32 otherwise.
    """
    kernel = torch.ops.sgl_kernel.transfer_kv_per_layer_mla
    kernel(
        src, dst,
        src_indices, dst_indices,
        item_size, block_quota, num_warps_per_block,
    )
|
2025-06-23 11:58:59 -07:00
|
|
|
|
|
|
|
|
|
2025-07-23 01:49:03 -07:00
|
|
|
def transfer_kv_per_layer_mla_pf_lf(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke the ``sgl_kernel.transfer_kv_per_layer_mla_pf_lf`` custom op.

    This is the single-tensor (MLA) per-layer transfer that also takes
    ``layer_id`` and ``src_layout_dim``. Arguments are forwarded
    positionally, and the op's result is discarded (returns None).
    """
    kernel = torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf
    kernel(
        src, dst,
        src_indices, dst_indices,
        layer_id, item_size, src_layout_dim,
        block_quota, num_warps_per_block,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def transfer_kv_all_layer_mla(
    src_layers: torch.Tensor,
    dst_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke the ``sgl_kernel.transfer_kv_all_layer_mla`` custom op.

    This is the multi-layer, single-tensor (MLA) transfer. Arguments are
    forwarded positionally, and the op's result is discarded (returns None).
    """
    kernel = torch.ops.sgl_kernel.transfer_kv_all_layer_mla
    kernel(
        src_layers, dst_layers,
        src_indices, dst_indices,
        item_size, num_layers,
        block_quota, num_warps_per_block,
    )
|
2025-07-23 01:49:03 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def transfer_kv_all_layer_mla_lf_pf(
    src_layers: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke the ``sgl_kernel.transfer_kv_all_layer_mla_lf_pf`` custom op.

    The source is given per layer (``src_layers``) and the destination as a
    single ``dst`` tensor described by ``dst_layout_dim``. Arguments are
    forwarded positionally, and the op's result is discarded (returns None).
    """
    kernel = torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf
    kernel(
        src_layers, dst,
        src_indices, dst_indices,
        item_size, dst_layout_dim, num_layers,
        block_quota, num_warps_per_block,
    )
|