Optimize copy_kv_cache for spec decoding (#11126)
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
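
Replace the 1D `copy_all_layer_kv_cache` kernel with a 2D tiled variant, `copy_all_layer_kv_cache_tiled`, that launches one program per (layer, byte-tile) instead of looping over fixed 128-byte blocks inside a single program per layer. Tile size and warp count are derived from the per-layer stride in bytes, and the kernel is launched once with dummy locations at init time so the Triton JIT compile stays off the decode path. The copy path is only initialized when a speculative algorithm is configured, since it is only exercised by speculative decoding.

Below is a minimal standalone sketch of the tile-selection heuristic added in `_init_kv_copy_and_warmup`; the function name `pick_kv_copy_config` and the example stride are illustrative, not part of this change:

```python
def pick_kv_copy_config(stride_bytes: int) -> tuple[int, int, int]:
    # Wider per-token rows get bigger tiles and more warps, so each
    # program still moves enough bytes to use memory bandwidth well.
    if stride_bytes >= 8192:
        bytes_per_tile = 512
    elif stride_bytes >= 4096:
        bytes_per_tile = 256
    else:
        bytes_per_tile = 128
    byte_tiles = (stride_bytes + bytes_per_tile - 1) // bytes_per_tile  # ceil-div
    num_warps = 4 if bytes_per_tile <= 256 else 8
    return bytes_per_tile, byte_tiles, num_warps

# e.g. a layer whose K (or V) row is 6144 bytes wide:
# 256-byte tiles, 24 tiles per row, 4 warps -> launch grid (num_layers, 24)
print(pick_kv_copy_config(6144))  # (256, 24, 4)
```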
@@ -415,6 +415,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        enable_kv_cache_copy: bool = False,
     ):
         super().__init__(
             size,
@@ -446,8 +447,57 @@ class MHATokenToKVPool(KVCache):
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None

+        if enable_kv_cache_copy:
+            self._init_kv_copy_and_warmup()
+        else:
+            self._kv_copy_config = None
+
         self._finalize_allocation_log(size)

+    def _init_kv_copy_and_warmup(self):
+        # Heuristics for KV copy tiling
+        _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192
+        _KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096
+        _KV_COPY_TILE_SIZE_LARGE = 512
+        _KV_COPY_TILE_SIZE_MEDIUM = 256
+        _KV_COPY_TILE_SIZE_SMALL = 128
+        _KV_COPY_NUM_WARPS_LARGE_TILE = 8
+        _KV_COPY_NUM_WARPS_SMALL_TILE = 4
+
+        stride_bytes = int(self.data_strides[0].item())
+        if stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_LARGE:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_LARGE
+        elif stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_MEDIUM:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_MEDIUM
+        else:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL
+
+        self._kv_copy_config = {
+            "bytes_per_tile": bytes_per_tile,
+            "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
+            "num_warps": (
+                _KV_COPY_NUM_WARPS_SMALL_TILE
+                if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM
+                else _KV_COPY_NUM_WARPS_LARGE_TILE
+            ),
+        }
+
+        # Warm up once with dummy locations so Triton JIT compiles at init, not during decode.
+        dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
+        grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
+            self.data_ptrs,
+            self.data_strides,
+            dummy_loc,
+            dummy_loc,
+            1,
+            1,
+            BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
+            num_warps=self._kv_copy_config["num_warps"],
+            num_stages=2,
+        )
+
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
@@ -642,13 +692,28 @@ class MHATokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v

     def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
-        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
+        N = tgt_loc.numel()
+        if N == 0:
+            return
+
+        assert (
+            self._kv_copy_config is not None
+        ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"
+
+        cfg = self._kv_copy_config
+        N_upper = next_power_of_2(N)
+        grid = (self.data_ptrs.numel(), cfg["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
             self.data_ptrs,
             self.data_strides,
             tgt_loc,
             src_loc,
-            len(tgt_loc),
-            next_power_of_2(len(tgt_loc)),
+            N,
+            N_upper,
+            BYTES_PER_TILE=cfg["bytes_per_tile"],
+            num_warps=cfg["num_warps"],
+            num_stages=2,
         )

@@ -1588,38 +1653,36 @@ class DoubleSparseTokenToKVPool(KVCache):

 @triton.jit
-def copy_all_layer_kv_cache(
+def copy_all_layer_kv_cache_tiled(
     data_ptrs,
     strides,
     tgt_loc_ptr,
     src_loc_ptr,
     num_locs,
     num_locs_upper: tl.constexpr,
+    BYTES_PER_TILE: tl.constexpr,
 ):
-    BLOCK_SIZE: tl.constexpr = 128
-
+    """2D tiled kernel. Safe for in-place copy."""
     bid = tl.program_id(0)
+    tid = tl.program_id(1)
+
     stride = tl.load(strides + bid)
-
-    data_ptr = tl.load(data_ptrs + bid)
-    data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
+    base_ptr = tl.load(data_ptrs + bid)
+    base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8))

-    num_locs_offset = tl.arange(0, num_locs_upper)
-    tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
-    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+    byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
+    mask_byte = byte_off < stride
+    tl.multiple_of(byte_off, 16)

-    # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
-    # because this copy is an inplace operation.
-    num_loop = tl.cdiv(stride, BLOCK_SIZE)
-    for i in range(num_loop):
-        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :]
-        value = tl.load(
-            data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
-        )
-        tl.store(
-            data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
-            value,
-            mask=mask,
-        )
+    loc_idx = tl.arange(0, num_locs_upper)
+    mask_loc = loc_idx < num_locs
+
+    src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0)
+    tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0)
+
+    src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :]
+    tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :]
+
+    mask = mask_loc[:, None] & mask_byte[None, :]
+    vals = tl.load(src_ptr, mask=mask)
+    tl.store(tgt_ptr, vals, mask=mask)

@@ -1672,6 +1672,9 @@ class ModelRunner:
             enable_memory_saver=self.server_args.enable_memory_saver,
             start_layer=self.start_layer,
             end_layer=self.end_layer,
+            enable_kv_cache_copy=(
+                self.server_args.speculative_algorithm is not None
+            ),
         )

         # Initialize token_to_kv_pool_allocator