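增加 dcu_alloc_decode_kernel 实现：在 PagedTokenToKVPoolAllocator 中通过环境变量 USE_DCU_DECODE_KERNEL 启用，未设置时仍使用原有的 Triton alloc_decode_kernel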

This commit is contained in:
liucong
2025-11-04 20:27:27 +08:00
parent 46da95569f
commit c9bcffd2a5
5 changed files with 117 additions and 8 deletions

View File

@@ -28,6 +28,7 @@ import triton.language as tl
from sglang.srt.mem_cache.memory_pool import SWAKVPool
from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache
@@ -430,6 +431,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
self.num_pages = size // page_size
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self.use_dcu_decode_kernel = get_bool_env_var("USE_DCU_DECODE_KERNEL")
self.seen_max_num_extend_tokens_next_power_of_2 = 1
self.clear()
@@ -525,14 +527,26 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
self.merge_and_sort_free()
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
alloc_decode_kernel[(bs,)](
seq_lens,
last_loc,
self.free_pages,
out_indices,
next_power_of_2(bs),
self.page_size,
)
if self.use_dcu_decode_kernel:
dcu_alloc_decode_kernel(
seq_lens_ptr = seq_lens,
last_loc_ptr = last_loc,
free_page_ptr = self.free_pages,
out_indices = out_indices,
bs = bs,
bs_upper = next_power_of_2(bs),
page_size = self.page_size,
)
else:
alloc_decode_kernel[(bs,)](
seq_lens,
last_loc,
self.free_pages,
out_indices,
next_power_of_2(bs),
self.page_size,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)