增加dcu_alloc_decode_kernel实现
This commit is contained in:
@@ -28,6 +28,7 @@ import triton.language as tl
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import SWAKVPool
|
||||
from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
|
||||
from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
@@ -430,6 +431,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
||||
self.num_pages = size // page_size
|
||||
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
||||
self.use_dcu_decode_kernel = get_bool_env_var("USE_DCU_DECODE_KERNEL")
|
||||
self.seen_max_num_extend_tokens_next_power_of_2 = 1
|
||||
self.clear()
|
||||
|
||||
@@ -525,14 +527,26 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
||||
alloc_decode_kernel[(bs,)](
|
||||
seq_lens,
|
||||
last_loc,
|
||||
self.free_pages,
|
||||
out_indices,
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
if self.use_dcu_decode_kernel:
|
||||
dcu_alloc_decode_kernel(
|
||||
seq_lens_ptr = seq_lens,
|
||||
last_loc_ptr = last_loc,
|
||||
free_page_ptr = self.free_pages,
|
||||
out_indices = out_indices,
|
||||
bs = bs,
|
||||
bs_upper = next_power_of_2(bs),
|
||||
page_size = self.page_size,
|
||||
)
|
||||
else:
|
||||
alloc_decode_kernel[(bs,)](
|
||||
seq_lens,
|
||||
last_loc,
|
||||
self.free_pages,
|
||||
out_indices,
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
if self.debug_mode:
|
||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||
|
||||
Reference in New Issue
Block a user