增加dcu_alloc_decode_kernel实现

This commit is contained in:
liucong
2025-11-04 20:27:27 +08:00
parent 46da95569f
commit c9bcffd2a5
5 changed files with 117 additions and 8 deletions

View File

@@ -10,6 +10,25 @@ def is_hip() -> bool:
# Cache the result of is_hip() in a module-level flag so later code can
# consult `_is_hip` instead of calling the function again.
_is_hip = is_hip()
def dcu_alloc_decode_kernel(
    seq_lens_ptr: torch.Tensor,
    last_loc_ptr: torch.Tensor,
    free_page_ptr: torch.Tensor,
    out_indices: torch.Tensor,
    bs: int,
    bs_upper: int,
    page_size: int,
):
    """Dispatch to the ``sgl_kernel`` custom op ``dcu_alloc_decode_kernel``.

    This is a thin Python wrapper: all seven arguments are forwarded
    positionally, in the same order, to
    ``torch.ops.sgl_kernel.dcu_alloc_decode_kernel``. Whatever the op
    returns is discarded, so this function always returns ``None``.

    NOTE(review): the names suggest the kernel reads ``seq_lens_ptr``,
    ``last_loc_ptr`` and ``free_page_ptr`` and writes its result into
    ``out_indices`` in place -- confirm against the kernel source; the
    expected shapes, dtypes and devices are not visible from here.

    Args:
        seq_lens_ptr: Tensor passed through as the op's first argument.
        last_loc_ptr: Tensor passed through as the op's second argument.
        free_page_ptr: Tensor passed through as the op's third argument.
        out_indices: Tensor passed through as the op's fourth argument.
        bs: Integer forwarded unchanged (presumably the batch size).
        bs_upper: Integer forwarded unchanged (presumably an upper bound
            on the batch size -- TODO confirm).
        page_size: Integer forwarded unchanged.
    """
    # Resolve the registered custom op once, then invoke it with the
    # arguments in their declared order.
    kernel_op = torch.ops.sgl_kernel.dcu_alloc_decode_kernel
    forwarded = (
        seq_lens_ptr,
        last_loc_ptr,
        free_page_ptr,
        out_indices,
        bs,
        bs_upper,
        page_size,
    )
    kernel_op(*forwarded)
def transfer_kv_per_layer(
src_k: torch.Tensor,
dst_k: torch.Tensor,