Optimize copy_kv_cache for spec decoding (#11126)
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
This commit is contained in:
@@ -415,6 +415,7 @@ class MHATokenToKVPool(KVCache):
|
|||||||
enable_memory_saver: bool,
|
enable_memory_saver: bool,
|
||||||
start_layer: Optional[int] = None,
|
start_layer: Optional[int] = None,
|
||||||
end_layer: Optional[int] = None,
|
end_layer: Optional[int] = None,
|
||||||
|
enable_kv_cache_copy: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
size,
|
size,
|
||||||
@@ -446,8 +447,57 @@ class MHATokenToKVPool(KVCache):
|
|||||||
|
|
||||||
self.device_module = torch.get_device_module(self.device)
|
self.device_module = torch.get_device_module(self.device)
|
||||||
self.alt_stream = self.device_module.Stream() if _is_cuda else None
|
self.alt_stream = self.device_module.Stream() if _is_cuda else None
|
||||||
|
|
||||||
|
if enable_kv_cache_copy:
|
||||||
|
self._init_kv_copy_and_warmup()
|
||||||
|
else:
|
||||||
|
self._kv_copy_config = None
|
||||||
|
|
||||||
self._finalize_allocation_log(size)
|
self._finalize_allocation_log(size)
|
||||||
|
|
||||||
|
def _init_kv_copy_and_warmup(self):
|
||||||
|
# Heuristics for KV copy tiling
|
||||||
|
_KV_COPY_STRIDE_THRESHOLD_LARGE = 8192
|
||||||
|
_KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096
|
||||||
|
_KV_COPY_TILE_SIZE_LARGE = 512
|
||||||
|
_KV_COPY_TILE_SIZE_MEDIUM = 256
|
||||||
|
_KV_COPY_TILE_SIZE_SMALL = 128
|
||||||
|
_KV_COPY_NUM_WARPS_LARGE_TILE = 8
|
||||||
|
_KV_COPY_NUM_WARPS_SMALL_TILE = 4
|
||||||
|
|
||||||
|
stride_bytes = int(self.data_strides[0].item())
|
||||||
|
if stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_LARGE:
|
||||||
|
bytes_per_tile = _KV_COPY_TILE_SIZE_LARGE
|
||||||
|
elif stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_MEDIUM:
|
||||||
|
bytes_per_tile = _KV_COPY_TILE_SIZE_MEDIUM
|
||||||
|
else:
|
||||||
|
bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL
|
||||||
|
|
||||||
|
self._kv_copy_config = {
|
||||||
|
"bytes_per_tile": bytes_per_tile,
|
||||||
|
"byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
|
||||||
|
"num_warps": (
|
||||||
|
_KV_COPY_NUM_WARPS_SMALL_TILE
|
||||||
|
if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM
|
||||||
|
else _KV_COPY_NUM_WARPS_LARGE_TILE
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
|
||||||
|
grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])
|
||||||
|
|
||||||
|
copy_all_layer_kv_cache_tiled[grid](
|
||||||
|
self.data_ptrs,
|
||||||
|
self.data_strides,
|
||||||
|
dummy_loc,
|
||||||
|
dummy_loc,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
|
||||||
|
num_warps=self._kv_copy_config["num_warps"],
|
||||||
|
num_stages=2,
|
||||||
|
)
|
||||||
|
|
||||||
def _create_buffers(self):
|
def _create_buffers(self):
|
||||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||||
with (
|
with (
|
||||||
@@ -642,13 +692,28 @@ class MHATokenToKVPool(KVCache):
|
|||||||
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||||
|
|
||||||
def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
|
def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
|
||||||
copy_all_layer_kv_cache[(len(self.data_ptrs),)](
|
N = tgt_loc.numel()
|
||||||
|
if N == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self._kv_copy_config is not None
|
||||||
|
), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"
|
||||||
|
|
||||||
|
cfg = self._kv_copy_config
|
||||||
|
N_upper = next_power_of_2(N)
|
||||||
|
grid = (self.data_ptrs.numel(), cfg["byte_tiles"])
|
||||||
|
|
||||||
|
copy_all_layer_kv_cache_tiled[grid](
|
||||||
self.data_ptrs,
|
self.data_ptrs,
|
||||||
self.data_strides,
|
self.data_strides,
|
||||||
tgt_loc,
|
tgt_loc,
|
||||||
src_loc,
|
src_loc,
|
||||||
len(tgt_loc),
|
N,
|
||||||
next_power_of_2(len(tgt_loc)),
|
N_upper,
|
||||||
|
BYTES_PER_TILE=cfg["bytes_per_tile"],
|
||||||
|
num_warps=cfg["num_warps"],
|
||||||
|
num_stages=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1588,38 +1653,36 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def copy_all_layer_kv_cache(
|
def copy_all_layer_kv_cache_tiled(
|
||||||
data_ptrs,
|
data_ptrs,
|
||||||
strides,
|
strides,
|
||||||
tgt_loc_ptr,
|
tgt_loc_ptr,
|
||||||
src_loc_ptr,
|
src_loc_ptr,
|
||||||
num_locs,
|
num_locs,
|
||||||
num_locs_upper: tl.constexpr,
|
num_locs_upper: tl.constexpr,
|
||||||
|
BYTES_PER_TILE: tl.constexpr,
|
||||||
):
|
):
|
||||||
BLOCK_SIZE: tl.constexpr = 128
|
"""2D tiled kernel. Safe for in-place copy."""
|
||||||
|
|
||||||
bid = tl.program_id(0)
|
bid = tl.program_id(0)
|
||||||
|
tid = tl.program_id(1)
|
||||||
|
|
||||||
stride = tl.load(strides + bid)
|
stride = tl.load(strides + bid)
|
||||||
|
base_ptr = tl.load(data_ptrs + bid)
|
||||||
|
base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8))
|
||||||
|
|
||||||
data_ptr = tl.load(data_ptrs + bid)
|
byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
|
||||||
data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
|
mask_byte = byte_off < stride
|
||||||
|
tl.multiple_of(byte_off, 16)
|
||||||
|
|
||||||
num_locs_offset = tl.arange(0, num_locs_upper)
|
loc_idx = tl.arange(0, num_locs_upper)
|
||||||
tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
|
mask_loc = loc_idx < num_locs
|
||||||
src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
|
|
||||||
|
|
||||||
# NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
|
src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0)
|
||||||
# because this copy is an inplace operation.
|
tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0)
|
||||||
|
|
||||||
num_loop = tl.cdiv(stride, BLOCK_SIZE)
|
src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :]
|
||||||
for i in range(num_loop):
|
tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :]
|
||||||
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
|
||||||
mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :]
|
mask = mask_loc[:, None] & mask_byte[None, :]
|
||||||
value = tl.load(
|
vals = tl.load(src_ptr, mask=mask)
|
||||||
data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
|
tl.store(tgt_ptr, vals, mask=mask)
|
||||||
)
|
|
||||||
tl.store(
|
|
||||||
data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
|
|
||||||
value,
|
|
||||||
mask=mask,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1672,6 +1672,9 @@ class ModelRunner:
|
|||||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
start_layer=self.start_layer,
|
start_layer=self.start_layer,
|
||||||
end_layer=self.end_layer,
|
end_layer=self.end_layer,
|
||||||
|
enable_kv_cache_copy=(
|
||||||
|
self.server_args.speculative_algorithm is not None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize token_to_kv_pool_allocator
|
# Initialize token_to_kv_pool_allocator
|
||||||
|
|||||||
Reference in New Issue
Block a user