Compare commits
1 Commits
v0.5.4_dev
...
v0.5.4_dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c9bcffd2a5 |
@@ -28,6 +28,7 @@ import triton.language as tl
|
|||||||
|
|
||||||
from sglang.srt.mem_cache.memory_pool import SWAKVPool
|
from sglang.srt.mem_cache.memory_pool import SWAKVPool
|
||||||
from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
|
from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
|
||||||
|
from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||||
@@ -430,6 +431,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
||||||
self.num_pages = size // page_size
|
self.num_pages = size // page_size
|
||||||
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
||||||
|
self.use_dcu_decode_kernel = get_bool_env_var("USE_DCU_DECODE_KERNEL")
|
||||||
self.seen_max_num_extend_tokens_next_power_of_2 = 1
|
self.seen_max_num_extend_tokens_next_power_of_2 = 1
|
||||||
self.clear()
|
self.clear()
|
||||||
|
|
||||||
@@ -525,14 +527,26 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|||||||
self.merge_and_sort_free()
|
self.merge_and_sort_free()
|
||||||
|
|
||||||
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
||||||
alloc_decode_kernel[(bs,)](
|
|
||||||
seq_lens,
|
if self.use_dcu_decode_kernel:
|
||||||
last_loc,
|
dcu_alloc_decode_kernel(
|
||||||
self.free_pages,
|
seq_lens_ptr = seq_lens,
|
||||||
out_indices,
|
last_loc_ptr = last_loc,
|
||||||
next_power_of_2(bs),
|
free_page_ptr = self.free_pages,
|
||||||
self.page_size,
|
out_indices = out_indices,
|
||||||
)
|
bs = bs,
|
||||||
|
bs_upper = next_power_of_2(bs),
|
||||||
|
page_size = self.page_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
alloc_decode_kernel[(bs,)](
|
||||||
|
seq_lens,
|
||||||
|
last_loc,
|
||||||
|
self.free_pages,
|
||||||
|
out_indices,
|
||||||
|
next_power_of_2(bs),
|
||||||
|
self.page_size,
|
||||||
|
)
|
||||||
|
|
||||||
if self.debug_mode:
|
if self.debug_mode:
|
||||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||||
|
|||||||
@@ -125,6 +125,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
|||||||
/*
|
/*
|
||||||
* From csrc/kvcacheio
|
* From csrc/kvcacheio
|
||||||
*/
|
*/
|
||||||
|
m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size) -> ()");
|
||||||
|
m.impl("dcu_alloc_decode_kernel", torch::kCUDA, &dcu_alloc_decode_kernel);
|
||||||
m.def(
|
m.def(
|
||||||
"transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
"transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||||
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
|
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
|
||||||
|
|||||||
@@ -571,3 +571,68 @@ void transfer_kv_all_layer_direct_lf_pf(
|
|||||||
int64_t page_size) {
|
int64_t page_size) {
|
||||||
transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
|
transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__device__ int64_t ceil_div(int64_t a, int64_t b) {
|
||||||
|
return (a + b - 1) / b;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void launch_alloc_decode_kernel(
|
||||||
|
const int64_t* seq_lens_ptr,
|
||||||
|
const int32_t* last_loc_ptr,
|
||||||
|
const int64_t* free_page_ptr,
|
||||||
|
int64_t* out_indices,
|
||||||
|
int64_t bs_upper,
|
||||||
|
int64_t page_size)
|
||||||
|
{
|
||||||
|
|
||||||
|
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
if (pid >= bs_upper) return;
|
||||||
|
|
||||||
|
int64_t seq_len = seq_lens_ptr[pid];
|
||||||
|
int64_t pre_len = seq_len - 1;
|
||||||
|
|
||||||
|
int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
|
||||||
|
|
||||||
|
int64_t sum_num_new_pages = 0;
|
||||||
|
for (int64_t i = 0; i < pid; i++) {
|
||||||
|
int64_t other_seq_len = seq_lens_ptr[i];
|
||||||
|
int64_t other_pre_len = (i <= pid) ? (other_seq_len - 1) : other_seq_len;
|
||||||
|
|
||||||
|
int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
|
||||||
|
int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
|
||||||
|
int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
|
||||||
|
|
||||||
|
sum_num_new_pages += other_num_new_pages;
|
||||||
|
}
|
||||||
|
int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;
|
||||||
|
|
||||||
|
if (num_page_start_loc_self == 0) {
|
||||||
|
int32_t last_loc = last_loc_ptr[pid];
|
||||||
|
out_indices[pid] = last_loc + 1;
|
||||||
|
} else {
|
||||||
|
int64_t page = free_page_ptr[new_page_start_loc];
|
||||||
|
out_indices[pid] = page * page_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void dcu_alloc_decode_kernel(
|
||||||
|
const at::Tensor seq_lens_ptr,
|
||||||
|
const at::Tensor last_loc_ptr,
|
||||||
|
const at::Tensor free_page_ptr,
|
||||||
|
at::Tensor out_indices,
|
||||||
|
int64_t bs,
|
||||||
|
int64_t bs_upper,
|
||||||
|
int64_t page_size) {
|
||||||
|
|
||||||
|
const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
|
||||||
|
const int32_t* last_loc_ptr1 = static_cast<const int32_t*>(last_loc_ptr.data_ptr());
|
||||||
|
const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
|
||||||
|
int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());
|
||||||
|
|
||||||
|
int64_t block_size = 64;
|
||||||
|
int64_t grid_size = (bs + block_size - 1) / block_size;
|
||||||
|
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs_upper, page_size);
|
||||||
|
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||||
|
}
|
||||||
|
|||||||
@@ -538,6 +538,15 @@ void segment_packbits(
|
|||||||
/*
|
/*
|
||||||
* From csrc/kvcacheio
|
* From csrc/kvcacheio
|
||||||
*/
|
*/
|
||||||
|
void dcu_alloc_decode_kernel(
|
||||||
|
const at::Tensor seq_lens_ptr,
|
||||||
|
const at::Tensor last_loc_ptr,
|
||||||
|
const at::Tensor free_page_ptr,
|
||||||
|
at::Tensor out_indices,
|
||||||
|
int64_t bs,
|
||||||
|
int64_t bs_upper,
|
||||||
|
int64_t page_size);
|
||||||
|
|
||||||
void transfer_kv_per_layer(
|
void transfer_kv_per_layer(
|
||||||
const at::Tensor src_k,
|
const at::Tensor src_k,
|
||||||
at::Tensor dst_k,
|
at::Tensor dst_k,
|
||||||
|
|||||||
@@ -10,6 +10,25 @@ def is_hip() -> bool:
|
|||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
|
|
||||||
|
def dcu_alloc_decode_kernel(
|
||||||
|
seq_lens_ptr: torch.Tensor,
|
||||||
|
last_loc_ptr: torch.Tensor,
|
||||||
|
free_page_ptr: torch.Tensor ,
|
||||||
|
out_indices: torch.Tensor ,
|
||||||
|
bs: int,
|
||||||
|
bs_upper: int,
|
||||||
|
page_size: int,
|
||||||
|
):
|
||||||
|
torch.ops.sgl_kernel.dcu_alloc_decode_kernel(
|
||||||
|
seq_lens_ptr,
|
||||||
|
last_loc_ptr,
|
||||||
|
free_page_ptr,
|
||||||
|
out_indices,
|
||||||
|
bs,
|
||||||
|
bs_upper,
|
||||||
|
page_size,
|
||||||
|
)
|
||||||
|
|
||||||
def transfer_kv_per_layer(
|
def transfer_kv_per_layer(
|
||||||
src_k: torch.Tensor,
|
src_k: torch.Tensor,
|
||||||
dst_k: torch.Tensor,
|
dst_k: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user