增加dcu_alloc_decode_kernel实现
This commit is contained in:
@@ -28,6 +28,7 @@ import triton.language as tl
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import SWAKVPool
|
||||
from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
|
||||
from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
@@ -430,6 +431,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
||||
self.num_pages = size // page_size
|
||||
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
||||
self.use_dcu_decode_kernel = get_bool_env_var("USE_DCU_DECODE_KERNEL")
|
||||
self.seen_max_num_extend_tokens_next_power_of_2 = 1
|
||||
self.clear()
|
||||
|
||||
@@ -525,14 +527,26 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
||||
alloc_decode_kernel[(bs,)](
|
||||
seq_lens,
|
||||
last_loc,
|
||||
self.free_pages,
|
||||
out_indices,
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
if self.use_dcu_decode_kernel:
|
||||
dcu_alloc_decode_kernel(
|
||||
seq_lens_ptr = seq_lens,
|
||||
last_loc_ptr = last_loc,
|
||||
free_page_ptr = self.free_pages,
|
||||
out_indices = out_indices,
|
||||
bs = bs,
|
||||
bs_upper = next_power_of_2(bs),
|
||||
page_size = self.page_size,
|
||||
)
|
||||
else:
|
||||
alloc_decode_kernel[(bs,)](
|
||||
seq_lens,
|
||||
last_loc,
|
||||
self.free_pages,
|
||||
out_indices,
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
if self.debug_mode:
|
||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||
|
||||
@@ -125,6 +125,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
/*
|
||||
* From csrc/kvcacheio
|
||||
*/
|
||||
m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size) -> ()");
|
||||
m.impl("dcu_alloc_decode_kernel", torch::kCUDA, &dcu_alloc_decode_kernel);
|
||||
m.def(
|
||||
"transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
|
||||
|
||||
@@ -571,3 +571,68 @@ void transfer_kv_all_layer_direct_lf_pf(
|
||||
int64_t page_size) {
|
||||
transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
|
||||
}
|
||||
|
||||
__device__ int64_t ceil_div(int64_t a, int64_t b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
__global__ void launch_alloc_decode_kernel(
|
||||
const int64_t* seq_lens_ptr,
|
||||
const int32_t* last_loc_ptr,
|
||||
const int64_t* free_page_ptr,
|
||||
int64_t* out_indices,
|
||||
int64_t bs_upper,
|
||||
int64_t page_size)
|
||||
{
|
||||
|
||||
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (pid >= bs_upper) return;
|
||||
|
||||
int64_t seq_len = seq_lens_ptr[pid];
|
||||
int64_t pre_len = seq_len - 1;
|
||||
|
||||
int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
|
||||
|
||||
int64_t sum_num_new_pages = 0;
|
||||
for (int64_t i = 0; i < pid; i++) {
|
||||
int64_t other_seq_len = seq_lens_ptr[i];
|
||||
int64_t other_pre_len = (i <= pid) ? (other_seq_len - 1) : other_seq_len;
|
||||
|
||||
int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
|
||||
int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
|
||||
int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
|
||||
|
||||
sum_num_new_pages += other_num_new_pages;
|
||||
}
|
||||
int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;
|
||||
|
||||
if (num_page_start_loc_self == 0) {
|
||||
int32_t last_loc = last_loc_ptr[pid];
|
||||
out_indices[pid] = last_loc + 1;
|
||||
} else {
|
||||
int64_t page = free_page_ptr[new_page_start_loc];
|
||||
out_indices[pid] = page * page_size;
|
||||
}
|
||||
}
|
||||
|
||||
void dcu_alloc_decode_kernel(
|
||||
const at::Tensor seq_lens_ptr,
|
||||
const at::Tensor last_loc_ptr,
|
||||
const at::Tensor free_page_ptr,
|
||||
at::Tensor out_indices,
|
||||
int64_t bs,
|
||||
int64_t bs_upper,
|
||||
int64_t page_size) {
|
||||
|
||||
const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
|
||||
const int32_t* last_loc_ptr1 = static_cast<const int32_t*>(last_loc_ptr.data_ptr());
|
||||
const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
|
||||
int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());
|
||||
|
||||
int64_t block_size = 64;
|
||||
int64_t grid_size = (bs + block_size - 1) / block_size;
|
||||
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
|
||||
launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs_upper, page_size);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
@@ -538,6 +538,15 @@ void segment_packbits(
|
||||
/*
|
||||
* From csrc/kvcacheio
|
||||
*/
|
||||
void dcu_alloc_decode_kernel(
|
||||
const at::Tensor seq_lens_ptr,
|
||||
const at::Tensor last_loc_ptr,
|
||||
const at::Tensor free_page_ptr,
|
||||
at::Tensor out_indices,
|
||||
int64_t bs,
|
||||
int64_t bs_upper,
|
||||
int64_t page_size);
|
||||
|
||||
void transfer_kv_per_layer(
|
||||
const at::Tensor src_k,
|
||||
at::Tensor dst_k,
|
||||
|
||||
@@ -10,6 +10,25 @@ def is_hip() -> bool:
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
def dcu_alloc_decode_kernel(
|
||||
seq_lens_ptr: torch.Tensor,
|
||||
last_loc_ptr: torch.Tensor,
|
||||
free_page_ptr: torch.Tensor ,
|
||||
out_indices: torch.Tensor ,
|
||||
bs: int,
|
||||
bs_upper: int,
|
||||
page_size: int,
|
||||
):
|
||||
torch.ops.sgl_kernel.dcu_alloc_decode_kernel(
|
||||
seq_lens_ptr,
|
||||
last_loc_ptr,
|
||||
free_page_ptr,
|
||||
out_indices,
|
||||
bs,
|
||||
bs_upper,
|
||||
page_size,
|
||||
)
|
||||
|
||||
def transfer_kv_per_layer(
|
||||
src_k: torch.Tensor,
|
||||
dst_k: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user