Kernels for efficient KV cache IO (#7313)
This commit is contained in:
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import concurrent.futures
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import threading
|
import threading
|
||||||
@@ -169,12 +168,23 @@ class HiCacheController:
|
|||||||
page_size: int,
|
page_size: int,
|
||||||
load_cache_event: threading.Event = None,
|
load_cache_event: threading.Event = None,
|
||||||
write_policy: str = "write_through_selective",
|
write_policy: str = "write_through_selective",
|
||||||
|
io_backend: str = "",
|
||||||
):
|
):
|
||||||
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
||||||
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
||||||
self.mem_pool_host = mem_pool_host
|
self.mem_pool_host = mem_pool_host
|
||||||
self.write_policy = write_policy
|
self.write_policy = write_policy
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
|
# Use the "kernel" backend for small-page KV cache transfers and the "direct" (DMA) backend for large pages.
|
||||||
|
if not io_backend:
|
||||||
|
IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
|
||||||
|
self.io_backend = (
|
||||||
|
"direct"
|
||||||
|
if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
|
||||||
|
else "kernel"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.io_backend = io_backend
|
||||||
|
|
||||||
self.load_cache_event = load_cache_event
|
self.load_cache_event = load_cache_event
|
||||||
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
||||||
@@ -203,12 +213,7 @@ class HiCacheController:
|
|||||||
self.load_stream = torch.cuda.Stream()
|
self.load_stream = torch.cuda.Stream()
|
||||||
|
|
||||||
self.write_thread = threading.Thread(
|
self.write_thread = threading.Thread(
|
||||||
target=(
|
target=self.write_thread_func_direct, daemon=True
|
||||||
self.write_thread_func_buffer
|
|
||||||
if self.page_size == 1
|
|
||||||
else self.write_thread_func_direct
|
|
||||||
),
|
|
||||||
daemon=True,
|
|
||||||
)
|
)
|
||||||
self.load_thread = threading.Thread(
|
self.load_thread = threading.Thread(
|
||||||
target=self.load_thread_func_layer_by_layer, daemon=True
|
target=self.load_thread_func_layer_by_layer, daemon=True
|
||||||
@@ -229,12 +234,7 @@ class HiCacheController:
|
|||||||
self.ack_load_queue.queue.clear()
|
self.ack_load_queue.queue.clear()
|
||||||
|
|
||||||
self.write_thread = threading.Thread(
|
self.write_thread = threading.Thread(
|
||||||
target=(
|
target=self.write_thread_func_direct, daemon=True
|
||||||
self.write_thread_func_buffer
|
|
||||||
if self.page_size == 1
|
|
||||||
else self.write_thread_func_direct
|
|
||||||
),
|
|
||||||
daemon=True,
|
|
||||||
)
|
)
|
||||||
self.load_thread = threading.Thread(
|
self.load_thread = threading.Thread(
|
||||||
target=self.load_thread_func_layer_by_layer, daemon=True
|
target=self.load_thread_func_layer_by_layer, daemon=True
|
||||||
@@ -281,6 +281,15 @@ class HiCacheController:
|
|||||||
)
|
)
|
||||||
return device_indices
|
return device_indices
|
||||||
|
|
||||||
|
def move_indices(self, host_indices, device_indices):
    """Place the index pair where the configured IO backend expects it.

    With the "kernel" backend the host indices are moved to the device of
    ``self.mem_pool_device``; with the "direct" backend the device indices
    are copied to the CPU. The other member of the pair is returned as-is.

    Args:
        host_indices: host-side indices; ``.to(device)`` is called on them
            for the "kernel" backend.
        device_indices: device-side indices; ``.cpu()`` is called on them
            for the "direct" backend.

    Returns:
        A ``(host_indices, device_indices)`` tuple.

    Raises:
        ValueError: if ``self.io_backend`` is neither "kernel" nor "direct".
    """
    if self.io_backend == "kernel":
        return host_indices.to(self.mem_pool_device.device), device_indices
    if self.io_backend == "direct":
        return host_indices, device_indices.cpu()
    # Name the offending value so a misconfigured backend can be diagnosed
    # from the error alone.
    raise ValueError(f"Unsupported io backend: {self.io_backend!r}")
|
||||||
|
|
||||||
def write_thread_func_direct(self):
|
def write_thread_func_direct(self):
|
||||||
"""
|
"""
|
||||||
Directly write through KV caches to host memory without buffering.
|
Directly write through KV caches to host memory without buffering.
|
||||||
@@ -289,10 +298,14 @@ class HiCacheController:
|
|||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
operation = self.write_queue.get(block=True, timeout=1)
|
operation = self.write_queue.get(block=True, timeout=1)
|
||||||
self.mem_pool_host.write_page_all_layers(
|
host_indices, device_indices = self.move_indices(
|
||||||
operation.host_indices,
|
operation.host_indices, operation.device_indices
|
||||||
operation.device_indices,
|
)
|
||||||
self.mem_pool_device,
|
self.mem_pool_device.backup_to_host_all_layer(
|
||||||
|
self.mem_pool_host,
|
||||||
|
host_indices,
|
||||||
|
device_indices,
|
||||||
|
self.io_backend,
|
||||||
)
|
)
|
||||||
self.write_stream.synchronize()
|
self.write_stream.synchronize()
|
||||||
self.mem_pool_host.complete_io(operation.host_indices)
|
self.mem_pool_host.complete_io(operation.host_indices)
|
||||||
@@ -304,27 +317,6 @@ class HiCacheController:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
|
|
||||||
def load_thread_func_direct(self):
|
|
||||||
"""
|
|
||||||
Directly load KV caches from host memory to device memory without buffering.
|
|
||||||
"""
|
|
||||||
torch.cuda.set_stream(self.load_stream)
|
|
||||||
while not self.stop_event.is_set():
|
|
||||||
try:
|
|
||||||
operation = self.load_queue.get(block=True, timeout=1)
|
|
||||||
operation.data = self.mem_pool_host.get_flat_data(
|
|
||||||
operation.host_indices
|
|
||||||
)
|
|
||||||
self.mem_pool_device.transfer(operation.device_indices, operation.data)
|
|
||||||
self.mem_pool_host.complete_io(operation.host_indices)
|
|
||||||
for node_id in operation.node_ids:
|
|
||||||
if node_id != 0:
|
|
||||||
self.ack_load_queue.put(node_id)
|
|
||||||
except Empty:
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
|
|
||||||
def load_thread_func_layer_by_layer(self):
|
def load_thread_func_layer_by_layer(self):
|
||||||
"""
|
"""
|
||||||
Load KV caches from host memory to device memory layer by layer.
|
Load KV caches from host memory to device memory layer by layer.
|
||||||
@@ -349,22 +341,18 @@ class HiCacheController:
|
|||||||
|
|
||||||
# start layer-wise KV cache transfer from CPU to GPU
|
# start layer-wise KV cache transfer from CPU to GPU
|
||||||
self.layer_done_counter.reset()
|
self.layer_done_counter.reset()
|
||||||
|
host_indices, device_indices = self.move_indices(
|
||||||
|
batch_operation.host_indices, batch_operation.device_indices
|
||||||
|
)
|
||||||
for i in range(self.mem_pool_host.layer_num):
|
for i in range(self.mem_pool_host.layer_num):
|
||||||
if self.page_size == 1:
|
self.mem_pool_device.load_from_host_per_layer(
|
||||||
flat_data = self.mem_pool_host.get_flat_data_by_layer(
|
self.mem_pool_host,
|
||||||
batch_operation.host_indices, i
|
host_indices,
|
||||||
)
|
device_indices,
|
||||||
self.mem_pool_device.transfer_per_layer(
|
i,
|
||||||
batch_operation.device_indices, flat_data, i
|
self.io_backend,
|
||||||
)
|
)
|
||||||
else:
|
self.load_stream.synchronize()
|
||||||
self.mem_pool_host.load_page_per_layer(
|
|
||||||
batch_operation.host_indices,
|
|
||||||
batch_operation.device_indices,
|
|
||||||
self.mem_pool_device,
|
|
||||||
i,
|
|
||||||
)
|
|
||||||
self.load_stream.synchronize()
|
|
||||||
self.layer_done_counter.increment()
|
self.layer_done_counter.increment()
|
||||||
|
|
||||||
self.mem_pool_host.complete_io(batch_operation.host_indices)
|
self.mem_pool_host.complete_io(batch_operation.host_indices)
|
||||||
@@ -372,148 +360,6 @@ class HiCacheController:
|
|||||||
if node_id != 0:
|
if node_id != 0:
|
||||||
self.ack_load_queue.put(node_id)
|
self.ack_load_queue.put(node_id)
|
||||||
|
|
||||||
def write_aux_func(self, no_wait=False):
|
|
||||||
"""
|
|
||||||
Auxiliary function to prepare the buffer for write operations.
|
|
||||||
"""
|
|
||||||
torch.cuda.set_stream(self.write_stream)
|
|
||||||
|
|
||||||
def _to_op(op_):
|
|
||||||
assert op_.device_indices.is_cuda, "Device indices should be on GPU"
|
|
||||||
op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
|
|
||||||
self.mem_pool_host.device
|
|
||||||
)
|
|
||||||
self.write_buffer.put(op_)
|
|
||||||
return op_
|
|
||||||
|
|
||||||
buffer = None
|
|
||||||
while not self.stop_event.is_set():
|
|
||||||
try:
|
|
||||||
operation = self.write_queue.get(block=True, timeout=1)
|
|
||||||
factor = (
|
|
||||||
len(operation.device_indices) // self.write_buffer.max_buffer_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if factor >= 1:
|
|
||||||
if buffer is not None:
|
|
||||||
_to_op(buffer)
|
|
||||||
buffer = None
|
|
||||||
|
|
||||||
if factor < 2:
|
|
||||||
_to_op(operation)
|
|
||||||
else:
|
|
||||||
split_ops = operation.split(factor)
|
|
||||||
for op_ in split_ops:
|
|
||||||
_to_op(op_)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if buffer is None:
|
|
||||||
buffer = operation
|
|
||||||
else:
|
|
||||||
buffer.merge(operation)
|
|
||||||
if (
|
|
||||||
no_wait
|
|
||||||
or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
|
|
||||||
or self.write_queue.empty()
|
|
||||||
or self.write_buffer.empty()
|
|
||||||
):
|
|
||||||
_to_op(buffer)
|
|
||||||
buffer = None
|
|
||||||
except Empty:
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
|
|
||||||
def load_aux_func(self):
|
|
||||||
"""
|
|
||||||
Auxiliary function to prepare the buffer for load operations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _pin_op(op_, put=True):
|
|
||||||
op_.data = (
|
|
||||||
self.mem_pool_host.get_flat_data(op_.host_indices)
|
|
||||||
.contiguous()
|
|
||||||
.pin_memory()
|
|
||||||
)
|
|
||||||
if put:
|
|
||||||
self.load_buffer.put(op_)
|
|
||||||
return op_
|
|
||||||
|
|
||||||
buffer = None
|
|
||||||
while not self.stop_event.is_set():
|
|
||||||
try:
|
|
||||||
operation = self.load_queue.get(block=True, timeout=1)
|
|
||||||
factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
|
|
||||||
|
|
||||||
if factor >= 1:
|
|
||||||
if buffer is not None:
|
|
||||||
_pin_op(buffer)
|
|
||||||
buffer = None
|
|
||||||
|
|
||||||
if factor < 2:
|
|
||||||
_pin_op(operation)
|
|
||||||
else:
|
|
||||||
split_ops = operation.split(factor)
|
|
||||||
split_args = [(op_, True) for op_ in split_ops[:-1]]
|
|
||||||
split_args.append((split_ops[-1], False))
|
|
||||||
# Spawn threads to pin each op concurrently
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
||||||
pinned_ops = list(
|
|
||||||
executor.map(
|
|
||||||
lambda x: _pin_op(x[0], put=x[1]), split_args
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# preserve the order of last op to ensure correct ack
|
|
||||||
self.load_buffer.put(pinned_ops[-1])
|
|
||||||
continue
|
|
||||||
|
|
||||||
if buffer is None:
|
|
||||||
buffer = operation
|
|
||||||
else:
|
|
||||||
buffer.merge(operation)
|
|
||||||
if (
|
|
||||||
len(buffer.host_indices) >= self.load_buffer.max_buffer_size
|
|
||||||
or self.load_queue.empty()
|
|
||||||
or self.load_buffer.empty()
|
|
||||||
):
|
|
||||||
_pin_op(buffer)
|
|
||||||
buffer = None
|
|
||||||
except Empty:
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
|
|
||||||
# todo (zhiqiang): double buffering to be deprecated
|
|
||||||
def write_thread_func_buffer(self):
|
|
||||||
aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
|
|
||||||
aux_thread.start()
|
|
||||||
|
|
||||||
while not self.stop_event.is_set():
|
|
||||||
operation = self.write_buffer.get()
|
|
||||||
if operation is None:
|
|
||||||
continue
|
|
||||||
self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
|
|
||||||
self.mem_pool_host.complete_io(operation.host_indices)
|
|
||||||
for node_id in operation.node_ids:
|
|
||||||
if node_id != 0:
|
|
||||||
self.ack_write_queue.put(node_id)
|
|
||||||
aux_thread.join()
|
|
||||||
|
|
||||||
def load_thread_func_buffer(self):
|
|
||||||
torch.cuda.set_stream(self.load_stream)
|
|
||||||
aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
|
|
||||||
aux_thread.start()
|
|
||||||
while not self.stop_event.is_set():
|
|
||||||
operation = self.load_buffer.get()
|
|
||||||
if operation is None:
|
|
||||||
continue
|
|
||||||
self.mem_pool_device.transfer(operation.device_indices, operation.data)
|
|
||||||
self.mem_pool_host.complete_io(operation.host_indices)
|
|
||||||
for node_id in operation.node_ids:
|
|
||||||
if node_id != 0:
|
|
||||||
self.ack_load_queue.put(node_id)
|
|
||||||
aux_thread.join()
|
|
||||||
|
|
||||||
def evict_device(
|
def evict_device(
|
||||||
self, device_indices: torch.Tensor, host_indices: torch.Tensor
|
self, device_indices: torch.Tensor, host_indices: torch.Tensor
|
||||||
) -> int:
|
) -> int:
|
||||||
|
|||||||
@@ -591,6 +591,12 @@ class Scheduler(
|
|||||||
hicache_ratio=server_args.hicache_ratio,
|
hicache_ratio=server_args.hicache_ratio,
|
||||||
hicache_size=server_args.hicache_size,
|
hicache_size=server_args.hicache_size,
|
||||||
hicache_write_policy=server_args.hicache_write_policy,
|
hicache_write_policy=server_args.hicache_write_policy,
|
||||||
|
hicache_io_backend=(
|
||||||
|
"direct"
|
||||||
|
if server_args.attention_backend
|
||||||
|
== "fa3" # hot fix for incompatibility
|
||||||
|
else server_args.hicache_io_backend
|
||||||
|
),
|
||||||
)
|
)
|
||||||
self.tp_worker.register_hicache_layer_transfer_counter(
|
self.tp_worker.register_hicache_layer_transfer_counter(
|
||||||
self.tree_cache.cache_controller.layer_done_counter
|
self.tree_cache.cache_controller.layer_done_counter
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ class HiRadixCache(RadixCache):
|
|||||||
hicache_ratio: float,
|
hicache_ratio: float,
|
||||||
hicache_size: int,
|
hicache_size: int,
|
||||||
hicache_write_policy: str,
|
hicache_write_policy: str,
|
||||||
|
hicache_io_backend: str,
|
||||||
):
|
):
|
||||||
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
||||||
if isinstance(self.kv_cache, MHATokenToKVPool):
|
if isinstance(self.kv_cache, MHATokenToKVPool):
|
||||||
@@ -56,6 +57,7 @@ class HiRadixCache(RadixCache):
|
|||||||
page_size,
|
page_size,
|
||||||
load_cache_event=self.load_cache_event,
|
load_cache_event=self.load_cache_event,
|
||||||
write_policy=hicache_write_policy,
|
write_policy=hicache_write_policy,
|
||||||
|
io_backend=hicache_io_backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
# record the nodes with ongoing write through
|
# record the nodes with ongoing write through
|
||||||
|
|||||||
@@ -34,10 +34,11 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
|
||||||
|
|
||||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
|
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -150,13 +151,16 @@ class KVCache(abc.ABC):
|
|||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_flat_data(self, indices):
|
@abc.abstractmethod
|
||||||
|
def load_from_host_per_layer(
|
||||||
|
self, host_pool, host_indices, device_indices, layer_id, io_backend
|
||||||
|
):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def transfer(self, indices, flat_data):
|
@abc.abstractmethod
|
||||||
raise NotImplementedError()
|
def backup_to_host_all_layer(
|
||||||
|
self, host_pool, host_indices, device_indices, io_backend
|
||||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def register_layer_transfer_counter(self, layer_transfer_counter):
|
def register_layer_transfer_counter(self, layer_transfer_counter):
|
||||||
@@ -247,7 +251,7 @@ class MHATokenToKVPool(KVCache):
|
|||||||
)
|
)
|
||||||
for _ in range(self.layer_num)
|
for _ in range(self.layer_num)
|
||||||
]
|
]
|
||||||
|
self.token_stride = self.head_num * self.head_dim
|
||||||
self.data_ptrs = torch.tensor(
|
self.data_ptrs = torch.tensor(
|
||||||
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
|
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
|
||||||
dtype=torch.uint64,
|
dtype=torch.uint64,
|
||||||
@@ -281,24 +285,24 @@ class MHATokenToKVPool(KVCache):
|
|||||||
# layer_num x [seq_len, head_num, head_dim]
|
# layer_num x [seq_len, head_num, head_dim]
|
||||||
# layer_num x [page_num, page_size, head_num, head_dim]
|
# layer_num x [page_num, page_size, head_num, head_dim]
|
||||||
kv_data_ptrs = [
|
kv_data_ptrs = [
|
||||||
self.get_key_buffer(i).data_ptr()
|
self._get_key_buffer(i).data_ptr()
|
||||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||||
] + [
|
] + [
|
||||||
self.get_value_buffer(i).data_ptr()
|
self._get_value_buffer(i).data_ptr()
|
||||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||||
]
|
]
|
||||||
kv_data_lens = [
|
kv_data_lens = [
|
||||||
self.get_key_buffer(i).nbytes
|
self._get_key_buffer(i).nbytes
|
||||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||||
] + [
|
] + [
|
||||||
self.get_value_buffer(i).nbytes
|
self._get_value_buffer(i).nbytes
|
||||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||||
]
|
]
|
||||||
kv_item_lens = [
|
kv_item_lens = [
|
||||||
self.get_key_buffer(i)[0].nbytes * self.page_size
|
self._get_key_buffer(i)[0].nbytes * self.page_size
|
||||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||||
] + [
|
] + [
|
||||||
self.get_value_buffer(i)[0].nbytes * self.page_size
|
self._get_value_buffer(i)[0].nbytes * self.page_size
|
||||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||||
]
|
]
|
||||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||||
@@ -341,49 +345,73 @@ class MHATokenToKVPool(KVCache):
|
|||||||
self.v_buffer[layer_id][chunk_indices] = v_chunk
|
self.v_buffer[layer_id][chunk_indices] = v_chunk
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
# Todo: different memory layout
|
def load_from_host_per_layer(
|
||||||
def get_flat_data(self, indices):
|
self,
|
||||||
# prepare a large chunk of contiguous data for efficient transfer
|
host_pool,
|
||||||
flatten = torch.stack(
|
host_indices,
|
||||||
[
|
device_indices,
|
||||||
torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
|
layer_id,
|
||||||
torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
|
io_backend,
|
||||||
]
|
):
|
||||||
|
transfer_kv_per_layer(
|
||||||
|
src_k=host_pool.k_buffer[layer_id],
|
||||||
|
dst_k=self.k_buffer[layer_id],
|
||||||
|
src_v=host_pool.v_buffer[layer_id],
|
||||||
|
dst_v=self.v_buffer[layer_id],
|
||||||
|
src_indices=host_indices,
|
||||||
|
dst_indices=device_indices,
|
||||||
|
io_backend=io_backend,
|
||||||
|
page_size=self.page_size,
|
||||||
|
item_size=self.token_stride,
|
||||||
)
|
)
|
||||||
return flatten
|
|
||||||
|
|
||||||
@debug_timing
|
def backup_to_host_all_layer(
|
||||||
def transfer(self, indices, flat_data):
|
self, host_pool, host_indices, device_indices, io_backend
|
||||||
# transfer prepared data from host to device
|
):
|
||||||
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
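# TODO: add specialized all-layer kernels for the layer-non-contiguous memory pool. NOTE(review): the bounds check below uses layer_id - self.start_layer, but the buffers are indexed with the absolute layer_id; confirm start_layer is always 0 on this path.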
|
||||||
k_data, v_data = flat_data[0], flat_data[1]
|
for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
|
||||||
for i in range(self.layer_num):
|
if layer_id - self.start_layer >= len(host_pool.k_buffer):
|
||||||
self.k_buffer[i][indices] = k_data[i]
|
raise ValueError(
|
||||||
self.v_buffer[i][indices] = v_data[i]
|
f"Layer ID {layer_id} exceeds the number of layers in host pool."
|
||||||
|
)
|
||||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
transfer_kv_per_layer(
|
||||||
# transfer prepared data from host to device
|
src_k=self.k_buffer[layer_id],
|
||||||
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
dst_k=host_pool.k_buffer[layer_id],
|
||||||
k_data, v_data = flat_data[0], flat_data[1]
|
src_v=self.v_buffer[layer_id],
|
||||||
self.k_buffer[layer_id - self.start_layer][indices] = k_data
|
dst_v=host_pool.v_buffer[layer_id],
|
||||||
self.v_buffer[layer_id - self.start_layer][indices] = v_data
|
src_indices=device_indices,
|
||||||
|
dst_indices=host_indices,
|
||||||
def get_key_buffer(self, layer_id: int):
|
io_backend=io_backend,
|
||||||
if self.layer_transfer_counter is not None:
|
page_size=self.page_size,
|
||||||
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
item_size=self.token_stride,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_key_buffer(self, layer_id: int):
|
||||||
|
# internal accessor for referencing the buffer; does not wait on the layer transfer counter
|
||||||
if self.store_dtype != self.dtype:
|
if self.store_dtype != self.dtype:
|
||||||
return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
|
return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
|
||||||
return self.k_buffer[layer_id - self.start_layer]
|
return self.k_buffer[layer_id - self.start_layer]
|
||||||
|
|
||||||
def get_value_buffer(self, layer_id: int):
|
def get_key_buffer(self, layer_id: int):
|
||||||
|
# note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading
|
||||||
|
# it should be called only by the attention backend, not merely to inspect the buffer
|
||||||
|
# same applies to get_value_buffer and get_kv_buffer
|
||||||
if self.layer_transfer_counter is not None:
|
if self.layer_transfer_counter is not None:
|
||||||
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||||
|
|
||||||
|
return self._get_key_buffer(layer_id)
|
||||||
|
|
||||||
|
def _get_value_buffer(self, layer_id: int):
|
||||||
|
# internal accessor for referencing the buffer; does not wait on the layer transfer counter
|
||||||
if self.store_dtype != self.dtype:
|
if self.store_dtype != self.dtype:
|
||||||
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
|
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
|
||||||
return self.v_buffer[layer_id - self.start_layer]
|
return self.v_buffer[layer_id - self.start_layer]
|
||||||
|
|
||||||
|
def get_value_buffer(self, layer_id: int):
|
||||||
|
if self.layer_transfer_counter is not None:
|
||||||
|
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
|
||||||
|
return self._get_value_buffer(layer_id)
|
||||||
|
|
||||||
def get_kv_buffer(self, layer_id: int):
|
def get_kv_buffer(self, layer_id: int):
|
||||||
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
||||||
|
|
||||||
@@ -761,6 +789,7 @@ class MLATokenToKVPool(KVCache):
|
|||||||
for _ in range(layer_num)
|
for _ in range(layer_num)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self.token_stride = kv_lora_rank + qk_rope_head_dim
|
||||||
self.layer_transfer_counter = None
|
self.layer_transfer_counter = None
|
||||||
|
|
||||||
kv_size = self.get_kv_size_bytes()
|
kv_size = self.get_kv_size_bytes()
|
||||||
@@ -846,21 +875,37 @@ class MLATokenToKVPool(KVCache):
|
|||||||
self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
|
self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_flat_data(self, indices):
|
def load_from_host_per_layer(
|
||||||
# prepare a large chunk of contiguous data for efficient transfer
|
self, host_pool, host_indices, device_indices, layer_id, io_backend
|
||||||
return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
|
):
|
||||||
|
transfer_kv_per_layer_mla(
|
||||||
|
src=host_pool.kv_buffer[layer_id],
|
||||||
|
dst=self.kv_buffer[layer_id],
|
||||||
|
src_indices=host_indices,
|
||||||
|
dst_indices=device_indices,
|
||||||
|
io_backend=io_backend,
|
||||||
|
page_size=self.page_size,
|
||||||
|
item_size=self.token_stride,
|
||||||
|
)
|
||||||
|
|
||||||
@debug_timing
|
def backup_to_host_all_layer(
|
||||||
def transfer(self, indices, flat_data):
|
self, host_pool, host_indices, device_indices, io_backend
|
||||||
# transfer prepared data from host to device
|
):
|
||||||
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
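# TODO: add specialized all-layer kernels for the layer-non-contiguous memory pool. NOTE(review): the bounds check below uses layer_id - self.start_layer, but the buffers are indexed with the absolute layer_id; confirm start_layer is always 0 on this path.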
|
||||||
for i in range(self.layer_num):
|
for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
|
||||||
self.kv_buffer[i][indices] = flat_data[i]
|
if layer_id - self.start_layer >= len(host_pool.kv_buffer):
|
||||||
|
raise ValueError(
|
||||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
f"Layer ID {layer_id} exceeds the number of layers in host pool."
|
||||||
# transfer prepared data from host to device
|
)
|
||||||
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
transfer_kv_per_layer_mla(
|
||||||
self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
|
src=self.kv_buffer[layer_id],
|
||||||
|
dst=host_pool.kv_buffer[layer_id],
|
||||||
|
src_indices=device_indices,
|
||||||
|
dst_indices=host_indices,
|
||||||
|
io_backend=io_backend,
|
||||||
|
page_size=self.page_size,
|
||||||
|
item_size=self.token_stride,
|
||||||
|
)
|
||||||
|
|
||||||
def get_cpu_copy(self, indices):
|
def get_cpu_copy(self, indices):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@@ -1046,14 +1091,19 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|||||||
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||||
self.label_buffer[layer_id - self.start_layer][loc] = cache_label
|
self.label_buffer[layer_id - self.start_layer][loc] = cache_label
|
||||||
|
|
||||||
def get_flat_data(self, indices):
|
def load_from_host_per_layer(
|
||||||
pass
|
self, host_pool, host_indices, device_indices, layer_id, io_backend
|
||||||
|
):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"HiCache not supported for DoubleSparseTokenToKVPool."
|
||||||
|
)
|
||||||
|
|
||||||
def transfer(self, indices, flat_data):
|
def backup_to_host_all_layer(
|
||||||
pass
|
self, host_pool, host_indices, device_indices, io_backend
|
||||||
|
):
|
||||||
def transfer_per_layer(self, indices, flat_data, layer_id):
|
raise NotImplementedError(
|
||||||
pass
|
"HiCache not supported for DoubleSparseTokenToKVPool."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import psutil
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
|
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
|
||||||
from sglang.srt.utils import debug_timing
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -99,22 +98,6 @@ class HostKVCache(abc.ABC):
|
|||||||
def init_kv_buffer(self):
|
def init_kv_buffer(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def transfer(self, indices, flat_data):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def get_flat_data(self, indices):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def get_flat_data_by_layer(self, indices, layer_id):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def assign_flat_data(self, indices, flat_data):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@synchronized()
|
@synchronized()
|
||||||
def clear(self):
|
def clear(self):
|
||||||
# Initialize memory states and tracking structures.
|
# Initialize memory states and tracking structures.
|
||||||
@@ -243,58 +226,13 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
@debug_timing
|
@property
|
||||||
def transfer(self, indices, flat_data):
|
def k_buffer(self):
|
||||||
# backup prepared data from device to host
|
return self.kv_buffer[0]
|
||||||
self.kv_buffer[:, :, indices] = flat_data.to(
|
|
||||||
device=self.device, non_blocking=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_flat_data(self, indices):
|
@property
|
||||||
return self.kv_buffer[:, :, indices]
|
def v_buffer(self):
|
||||||
|
return self.kv_buffer[1]
|
||||||
def get_flat_data_by_layer(self, indices, layer_id):
|
|
||||||
return self.kv_buffer[:, layer_id - self.start_layer, indices]
|
|
||||||
|
|
||||||
def assign_flat_data(self, indices, flat_data):
|
|
||||||
self.kv_buffer[:, :, indices] = flat_data
|
|
||||||
|
|
||||||
def write_page_all_layers(self, host_indices, device_indices, device_pool):
|
|
||||||
device_indices_cpu = device_indices[:: self.page_size].cpu()
|
|
||||||
for i in range(len(device_indices_cpu)):
|
|
||||||
h_index = host_indices[i * self.page_size]
|
|
||||||
d_index = device_indices_cpu[i]
|
|
||||||
for j in range(self.layer_num):
|
|
||||||
self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
|
|
||||||
device_pool.k_buffer[j][d_index : d_index + self.page_size],
|
|
||||||
non_blocking=True,
|
|
||||||
)
|
|
||||||
self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
|
|
||||||
device_pool.v_buffer[j][d_index : d_index + self.page_size],
|
|
||||||
non_blocking=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
|
|
||||||
device_indices_cpu = device_indices[:: self.page_size].cpu()
|
|
||||||
for i in range(len(device_indices_cpu)):
|
|
||||||
h_index = host_indices[i * self.page_size]
|
|
||||||
d_index = device_indices_cpu[i]
|
|
||||||
device_pool.k_buffer[layer_id - self.start_layer][
|
|
||||||
d_index : d_index + self.page_size
|
|
||||||
].copy_(
|
|
||||||
self.kv_buffer[
|
|
||||||
0, layer_id - self.start_layer, h_index : h_index + self.page_size
|
|
||||||
],
|
|
||||||
non_blocking=True,
|
|
||||||
)
|
|
||||||
device_pool.v_buffer[layer_id - self.start_layer][
|
|
||||||
d_index : d_index + self.page_size
|
|
||||||
].copy_(
|
|
||||||
self.kv_buffer[
|
|
||||||
1, layer_id - self.start_layer, h_index : h_index + self.page_size
|
|
||||||
],
|
|
||||||
non_blocking=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MLATokenToKVPoolHost(HostKVCache):
|
class MLATokenToKVPoolHost(HostKVCache):
|
||||||
@@ -337,44 +275,3 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
@debug_timing
def transfer(self, indices, flat_data):
    """Back up prepared data into the host buffer at ``indices``.

    ``flat_data`` is first moved to ``self.device`` with a blocking
    transfer (``non_blocking=False``) and then assigned along the second
    axis of ``self.kv_buffer``.
    """
    staged = flat_data.to(device=self.device, non_blocking=False)
    self.kv_buffer[:, indices] = staged
|
|
||||||
|
|
||||||
def get_flat_data(self, indices):
    """Return the entries of ``self.kv_buffer`` at ``indices`` on its second axis.

    The leading axis is returned in full.
    """
    selection = (slice(None), indices)
    return self.kv_buffer[selection]
|
|
||||||
|
|
||||||
def get_flat_data_by_layer(self, indices, layer_id):
    """Return the host buffer entries at ``indices`` for one layer.

    ``layer_id`` is rebased by ``self.start_layer`` before it is used to
    address the first axis of ``self.kv_buffer``.
    """
    local_layer = layer_id - self.start_layer
    return self.kv_buffer[local_layer, indices]
|
|
||||||
|
|
||||||
def assign_flat_data(self, indices, flat_data):
    """Store ``flat_data`` into ``self.kv_buffer`` at ``indices``.

    ``indices`` selects positions along the second axis; the first axis is
    written in full.
    """
    target = (slice(None), indices)
    self.kv_buffer[target] = flat_data
|
|
||||||
|
|
||||||
def write_page_all_layers(self, host_indices, device_indices, device_pool):
    """Copy whole pages from the device pool into the host buffer, all layers.

    Every ``self.page_size``-th entry of the index arrays is taken as the
    start of a page, and ``self.page_size`` consecutive slots are copied
    from that start. For each layer, ``device_pool.kv_buffer`` is written
    to the matching layer of ``self.kv_buffer``. The copies are issued with
    ``non_blocking=True`` and this method does not wait for them.
    """
    page = self.page_size
    # Only the first index of each page is needed; move those to the CPU once.
    device_page_starts = device_indices[::page].cpu()
    for page_no, d_start in enumerate(device_page_starts):
        h_start = host_indices[page_no * page]
        host_span = slice(h_start, h_start + page)
        device_span = slice(d_start, d_start + page)
        for layer in range(self.layer_num):
            self.kv_buffer[layer, host_span].copy_(
                device_pool.kv_buffer[layer][device_span],
                non_blocking=True,
            )
|
|
||||||
|
|
||||||
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
    """Copy whole pages of one layer from the host buffer into the device pool.

    Every ``self.page_size``-th entry of the index arrays is taken as the
    start of a page, and ``self.page_size`` consecutive slots are copied
    from that start. ``layer_id`` is rebased by ``self.start_layer`` and
    used to pick the layer in both ``self.kv_buffer`` and
    ``device_pool.kv_buffer``. The copies are issued with
    ``non_blocking=True`` and this method does not wait for them.
    """
    page = self.page_size
    # Only the first index of each page is needed; move those to the CPU once.
    device_page_starts = device_indices[::page].cpu()
    for page_no, d_start in enumerate(device_page_starts):
        h_start = host_indices[page_no * page]
        host_span = slice(h_start, h_start + page)
        device_span = slice(d_start, d_start + page)
        layer = layer_id - self.start_layer
        device_pool.kv_buffer[layer][device_span].copy_(
            self.kv_buffer[layer, host_span],
            non_blocking=True,
        )
|
|
||||||
|
|||||||
@@ -196,11 +196,13 @@ class RadixCache(BasePrefixCache):
|
|||||||
|
|
||||||
if self.page_size != 1:
|
if self.page_size != 1:
|
||||||
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
||||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
||||||
|
dtype=torch.int64, copy=True
|
||||||
|
)
|
||||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||||
else:
|
else:
|
||||||
page_aligned_len = len(kv_indices)
|
page_aligned_len = len(kv_indices)
|
||||||
page_aligned_kv_indices = kv_indices.clone()
|
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||||
|
|
||||||
# Radix Cache takes one ref in memory pool
|
# Radix Cache takes one ref in memory pool
|
||||||
new_prefix_len = self.insert(
|
new_prefix_len = self.insert(
|
||||||
@@ -226,10 +228,12 @@ class RadixCache(BasePrefixCache):
|
|||||||
|
|
||||||
if self.page_size != 1:
|
if self.page_size != 1:
|
||||||
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
||||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
||||||
|
dtype=torch.int64, copy=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
page_aligned_len = len(kv_indices)
|
page_aligned_len = len(kv_indices)
|
||||||
page_aligned_kv_indices = kv_indices.clone()
|
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||||
page_aligned_token_ids = token_ids[:page_aligned_len]
|
page_aligned_token_ids = token_ids[:page_aligned_len]
|
||||||
|
|
||||||
# Radix Cache takes one ref in memory pool
|
# Radix Cache takes one ref in memory pool
|
||||||
|
|||||||
@@ -217,6 +217,7 @@ class ServerArgs:
|
|||||||
hicache_ratio: float = 2.0
|
hicache_ratio: float = 2.0
|
||||||
hicache_size: int = 0
|
hicache_size: int = 0
|
||||||
hicache_write_policy: str = "write_through_selective"
|
hicache_write_policy: str = "write_through_selective"
|
||||||
|
hicache_io_backend: str = ""
|
||||||
flashinfer_mla_disable_ragged: bool = False
|
flashinfer_mla_disable_ragged: bool = False
|
||||||
disable_shared_experts_fusion: bool = False
|
disable_shared_experts_fusion: bool = False
|
||||||
disable_chunked_prefix_cache: bool = False
|
disable_chunked_prefix_cache: bool = False
|
||||||
@@ -1530,6 +1531,13 @@ class ServerArgs:
|
|||||||
default=ServerArgs.hicache_write_policy,
|
default=ServerArgs.hicache_write_policy,
|
||||||
help="The write policy of hierarchical cache.",
|
help="The write policy of hierarchical cache.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hicache-io-backend",
|
||||||
|
type=str,
|
||||||
|
choices=["direct", "kernel"],
|
||||||
|
default=ServerArgs.hicache_io_backend,
|
||||||
|
help="The IO backend for KV cache transfer between CPU and GPU",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--flashinfer-mla-disable-ragged",
|
"--flashinfer-mla-disable-ragged",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user