diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 01f422504..69d07d41b 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -415,6 +415,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        enable_kv_cache_copy: bool = False,
     ):
         super().__init__(
             size,
@@ -446,8 +447,57 @@ class MHATokenToKVPool(KVCache):
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
+
+        if enable_kv_cache_copy:
+            self._init_kv_copy_and_warmup()
+        else:
+            self._kv_copy_config = None
+
         self._finalize_allocation_log(size)
 
+    def _init_kv_copy_and_warmup(self):
+        # Heuristics for KV copy tiling
+        _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192
+        _KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096
+        _KV_COPY_TILE_SIZE_LARGE = 512
+        _KV_COPY_TILE_SIZE_MEDIUM = 256
+        _KV_COPY_TILE_SIZE_SMALL = 128
+        _KV_COPY_NUM_WARPS_LARGE_TILE = 8
+        _KV_COPY_NUM_WARPS_SMALL_TILE = 4
+
+        stride_bytes = int(self.data_strides[0].item())
+        if stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_LARGE:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_LARGE
+        elif stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_MEDIUM:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_MEDIUM
+        else:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL
+
+        self._kv_copy_config = {
+            "bytes_per_tile": bytes_per_tile,
+            "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
+            "num_warps": (
+                _KV_COPY_NUM_WARPS_SMALL_TILE
+                if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM
+                else _KV_COPY_NUM_WARPS_LARGE_TILE
+            ),
+        }
+
+        dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
+        grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
+            self.data_ptrs,
+            self.data_strides,
+            dummy_loc,
+            dummy_loc,
+            1,
+            1,
+            BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
+            num_warps=self._kv_copy_config["num_warps"],
+            num_stages=2,
+        )
+
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
@@ -642,13 +692,28 @@ class MHATokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
 
     def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
-        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
+        N = tgt_loc.numel()
+        if N == 0:
+            return
+
+        assert (
+            self._kv_copy_config is not None
+        ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"
+
+        cfg = self._kv_copy_config
+        N_upper = next_power_of_2(N)
+        grid = (self.data_ptrs.numel(), cfg["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
             self.data_ptrs,
             self.data_strides,
             tgt_loc,
             src_loc,
-            len(tgt_loc),
-            next_power_of_2(len(tgt_loc)),
+            N,
+            N_upper,
+            BYTES_PER_TILE=cfg["bytes_per_tile"],
+            num_warps=cfg["num_warps"],
+            num_stages=2,
         )
 
@@ -1588,38 +1653,36 @@ class DoubleSparseTokenToKVPool(KVCache):
 
 
 @triton.jit
-def copy_all_layer_kv_cache(
+def copy_all_layer_kv_cache_tiled(
     data_ptrs,
     strides,
     tgt_loc_ptr,
     src_loc_ptr,
     num_locs,
     num_locs_upper: tl.constexpr,
+    BYTES_PER_TILE: tl.constexpr,
 ):
-    BLOCK_SIZE: tl.constexpr = 128
-
+    """2D tiled kernel. Safe for in-place copy."""
     bid = tl.program_id(0)
+    tid = tl.program_id(1)
+
     stride = tl.load(strides + bid)
+    base_ptr = tl.load(data_ptrs + bid)
+    base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8))
-    data_ptr = tl.load(data_ptrs + bid)
-    data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
+    byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
+    mask_byte = byte_off < stride
+    tl.multiple_of(byte_off, 16)
-    num_locs_offset = tl.arange(0, num_locs_upper)
-    tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
-    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+    loc_idx = tl.arange(0, num_locs_upper)
+    mask_loc = loc_idx < num_locs
-    # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
-    # because this copy is an inplace operation.
+    src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0)
+    tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0)
-    num_loop = tl.cdiv(stride, BLOCK_SIZE)
-    for i in range(num_loop):
-        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :]
-        value = tl.load(
-            data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
-        )
-        tl.store(
-            data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
-            value,
-            mask=mask,
-        )
+    src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :]
+    tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :]
+
+    mask = mask_loc[:, None] & mask_byte[None, :]
+    vals = tl.load(src_ptr, mask=mask)
+    tl.store(tgt_ptr, vals, mask=mask)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index e4569ed20..e92fe4250 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1672,6 +1672,9 @@ class ModelRunner:
             enable_memory_saver=self.server_args.enable_memory_saver,
             start_layer=self.start_layer,
             end_layer=self.end_layer,
+            enable_kv_cache_copy=(
+                self.server_args.speculative_algorithm is not None
+            ),
         )
 
         # Initialize token_to_kv_pool_allocator
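
A sketch of what the new 2D launch grid computes, since the kernel replaces the old single-axis loop: grid axis 0 walks the per-layer K/V buffers in `data_ptrs`, axis 1 walks `BYTES_PER_TILE`-sized slices of each token row, and every program still handles all `num_locs` locations itself, loading before storing, which is what keeps the copy safe when `tgt_loc` and `src_loc` overlap. Below is a minimal CPU emulation under those assumptions; the helper names (`pick_tile`, `emulate_tiled_copy`) are illustrative, not part of the PR:

```python
import math

def pick_tile(stride_bytes: int) -> tuple[int, int]:
    # Mirrors the heuristic in _init_kv_copy_and_warmup: wider per-token
    # rows get larger tiles, and only 512-byte tiles get 8 warps.
    if stride_bytes >= 8192:
        return 512, 8
    if stride_bytes >= 4096:
        return 256, 4
    return 128, 4

# e.g. 8 KV heads * head_dim 128 * 2 bytes (fp16) = 2048-byte rows -> small tile
assert pick_tile(8 * 128 * 2) == (128, 4)

def emulate_tiled_copy(buf: bytearray, stride: int, tgt: list, src: list, tile: int) -> None:
    # Each outer iteration corresponds to one program along grid axis 1.
    # Like the kernel, it snapshots the tile for *all* source locations
    # before storing to any target, so overlapping tgt/src stay consistent.
    for t in range(math.ceil(stride / tile)):
        lo, hi = t * tile, min((t + 1) * tile, stride)
        vals = [bytes(buf[s * stride + lo : s * stride + hi]) for s in src]
        for d, v in zip(tgt, vals):
            buf[d * stride + lo : d * stride + hi] = v

# Three 4-byte "tokens"; shift tokens 0 and 1 right by one slot, in place.
buf = bytearray(b"AAAABBBBCCCC")
emulate_tiled_copy(buf, stride=4, tgt=[1, 2], src=[0, 1], tile=2)
assert bytes(buf) == b"AAAAAAAABBBB"  # slot 2 received the *old* token 1
```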
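On the warmup: the dummy single-location launch in `_init_kv_copy_and_warmup` presumably exists so Triton's JIT compiles the kernel for the exact `BYTES_PER_TILE`/`num_warps` configuration at pool construction time rather than on the first real `move_kv_cache` call, and gating construction on `speculative_algorithm is not None` in `model_runner.py` keeps that one-time cost off configurations that never move KV entries.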