Host memory pool for hierarchical caching (#2771)
This commit is contained in:
@@ -22,12 +22,16 @@ BaseTokenToKVPool maps a token location to its KV cache data.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from enum import IntEnum
|
||||
from functools import wraps
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
from sglang.srt.utils import debug_timing, get_compiler_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -213,6 +217,26 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
del self.k_buffer
|
||||
del self.v_buffer
|
||||
|
||||
# Todo: different memory layout
|
||||
def get_flat_data(self, indices):
|
||||
# prepare a large chunk of contiguous data for efficient transfer
|
||||
flatten = torch.stack(
|
||||
[
|
||||
torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
|
||||
torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
|
||||
]
|
||||
)
|
||||
return flatten
|
||||
|
||||
@debug_timing
|
||||
def transfer(self, indices, flat_data):
|
||||
# transfer prepared data from host to device
|
||||
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
||||
k_data, v_data = flat_data[0], flat_data[1]
|
||||
for i in range(self.layer_num):
|
||||
self.k_buffer[i][indices] = k_data[i]
|
||||
self.v_buffer[i][indices] = v_data[i]
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
if self.store_dtype != self.dtype:
|
||||
return self.k_buffer[layer_id].view(self.dtype)
|
||||
@@ -361,3 +385,184 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
|
||||
self.k_buffer[layer_id][loc] = cache_k
|
||||
self.v_buffer[layer_id][loc] = cache_v
|
||||
self.label_buffer[layer_id][loc] = cache_label
|
||||
|
||||
|
||||
class MemoryStateInt(IntEnum):
|
||||
IDLE = 0
|
||||
RESERVED = 1
|
||||
PROTECTED = 2
|
||||
SYNCED = 3
|
||||
BACKUP = 4
|
||||
|
||||
|
||||
def synchronized(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
with self.lock:
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class MLATokenToKVPoolHost:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_pool: MHATokenToKVPool,
|
||||
host_to_device_ratio: float = 2.0,
|
||||
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
||||
device: str = "cpu",
|
||||
):
|
||||
assert (
|
||||
host_to_device_ratio >= 1
|
||||
), "The host memory should be larger than the device memory with the current protocol"
|
||||
# todo, other ways of configuring the size
|
||||
|
||||
self.device_pool = device_pool
|
||||
self.host_to_device_ratio = host_to_device_ratio
|
||||
self.pin_memory = pin_memory
|
||||
self.device = device
|
||||
|
||||
self.size = int(device_pool.size * host_to_device_ratio)
|
||||
self.dtype = device_pool.store_dtype
|
||||
self.head_num = device_pool.head_num
|
||||
self.head_dim = device_pool.head_dim
|
||||
self.layer_num = device_pool.layer_num
|
||||
self.size_per_token = (
|
||||
self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
|
||||
)
|
||||
|
||||
# Verify there is enough available host memory.
|
||||
host_mem = psutil.virtual_memory()
|
||||
requested_bytes = self.size * self.size_per_token
|
||||
# preserve at least 10GB for other usage
|
||||
ten_gb = 10 * (1024**3)
|
||||
if requested_bytes > host_mem.available - ten_gb:
|
||||
raise ValueError(
|
||||
f"Not enough host memory available. Requesting "
|
||||
f"{requested_bytes / 1e9:.2f} GB but only have "
|
||||
f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
|
||||
f"size of the hierarchical cache."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
|
||||
)
|
||||
|
||||
self.kv_buffer = torch.empty(
|
||||
(2, self.layer_num, self.size, self.head_num, self.head_dim),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
# Initialize memory states and tracking structures.
|
||||
self.mem_state = torch.zeros(
|
||||
(self.size,), dtype=torch.uint8, device=self.device
|
||||
)
|
||||
self.free_slots = torch.arange(self.size, dtype=torch.int32)
|
||||
self.can_use_mem_size = self.size
|
||||
|
||||
# A lock for synchronized operations on memory allocation and state transitions.
|
||||
self.lock = threading.RLock()
|
||||
|
||||
def get_flat_data(self, indices):
|
||||
return self.kv_buffer[:, :, indices]
|
||||
|
||||
@debug_timing
|
||||
def transfer(self, indices, flat_data):
|
||||
# backup prepared data from device to host
|
||||
self.kv_buffer[:, :, indices] = flat_data.to(
|
||||
device=self.device, non_blocking=False
|
||||
)
|
||||
|
||||
@synchronized
|
||||
def clear(self):
|
||||
self.mem_state.fill_(0)
|
||||
self.can_use_mem_size = self.size
|
||||
self.free_slots = torch.arange(self.size, dtype=torch.int32)
|
||||
|
||||
@synchronized
|
||||
def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
|
||||
assert len(indices) > 0, "The indices should not be empty"
|
||||
states = self.mem_state[indices]
|
||||
assert (
|
||||
states == states[0]
|
||||
).all(), "The memory slots should have the same state {}".format(states)
|
||||
return MemoryStateInt(states[0].item())
|
||||
|
||||
@synchronized
|
||||
def alloc(self, need_size: int) -> torch.Tensor:
|
||||
if need_size > self.can_use_mem_size:
|
||||
return None
|
||||
|
||||
# todo: de-fragementation
|
||||
select_index = self.free_slots[:need_size]
|
||||
self.free_slots = self.free_slots[need_size:]
|
||||
|
||||
self.mem_state[select_index] = MemoryStateInt.RESERVED
|
||||
self.can_use_mem_size -= need_size
|
||||
|
||||
return select_index
|
||||
|
||||
@synchronized
|
||||
def is_reserved(self, indices: torch.Tensor) -> bool:
|
||||
return self.get_state(indices) == MemoryStateInt.RESERVED
|
||||
|
||||
@synchronized
|
||||
def is_protected(self, indices: torch.Tensor) -> bool:
|
||||
return self.get_state(indices) == MemoryStateInt.PROTECTED
|
||||
|
||||
@synchronized
|
||||
def is_synced(self, indices: torch.Tensor) -> bool:
|
||||
return self.get_state(indices) == MemoryStateInt.SYNCED
|
||||
|
||||
@synchronized
|
||||
def is_backup(self, indices: torch.Tensor) -> bool:
|
||||
return self.get_state(indices) == MemoryStateInt.BACKUP
|
||||
|
||||
@synchronized
|
||||
def update_backup(self, indices: torch.Tensor):
|
||||
assert self.is_synced(indices), (
|
||||
f"The host memory slots should be in SYNCED state before turning into BACKUP. "
|
||||
f"Current state: {self.get_state(indices)}"
|
||||
)
|
||||
self.mem_state[indices] = MemoryStateInt.BACKUP
|
||||
|
||||
@synchronized
|
||||
def update_synced(self, indices: torch.Tensor):
|
||||
self.mem_state[indices] = MemoryStateInt.SYNCED
|
||||
|
||||
@synchronized
|
||||
def protect_write(self, indices: torch.Tensor):
|
||||
assert self.is_reserved(indices), (
|
||||
f"The host memory slots should be RESERVED before write operations. "
|
||||
f"Current state: {self.get_state(indices)}"
|
||||
)
|
||||
self.mem_state[indices] = MemoryStateInt.PROTECTED
|
||||
|
||||
@synchronized
|
||||
def protect_load(self, indices: torch.Tensor):
|
||||
assert self.is_backup(indices), (
|
||||
f"The host memory slots should be in BACKUP state before load operations. "
|
||||
f"Current state: {self.get_state(indices)}"
|
||||
)
|
||||
self.mem_state[indices] = MemoryStateInt.PROTECTED
|
||||
|
||||
@synchronized
|
||||
def complete_io(self, indices: torch.Tensor):
|
||||
assert self.is_protected(indices), (
|
||||
f"The host memory slots should be PROTECTED during I/O operations. "
|
||||
f"Current state: {self.get_state(indices)}"
|
||||
)
|
||||
self.mem_state[indices] = MemoryStateInt.SYNCED
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_slots)
|
||||
|
||||
@synchronized
|
||||
def free(self, indices: torch.Tensor) -> int:
|
||||
self.mem_state[indices] = MemoryStateInt.IDLE
|
||||
self.free_slots = torch.concat([self.free_slots, indices])
|
||||
self.can_use_mem_size += len(indices)
|
||||
return len(indices)
|
||||
|
||||
@@ -1349,3 +1349,27 @@ class MultiprocessingSerializer:
|
||||
@staticmethod
|
||||
def deserialize(data):
|
||||
return ForkingPickler.loads(data)
|
||||
|
||||
|
||||
def debug_timing(func):
|
||||
# todo: replace with a more organized instrumentation
|
||||
def wrapper(*args, **kwargs):
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
tic = torch.cuda.Event(enable_timing=True)
|
||||
toc = torch.cuda.Event(enable_timing=True)
|
||||
tic.record()
|
||||
result = func(*args, **kwargs)
|
||||
toc.record()
|
||||
torch.cuda.synchronize() # Ensure all CUDA operations are complete
|
||||
elapsed = tic.elapsed_time(toc)
|
||||
indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
|
||||
num_tokens = len(indices) if indices is not None else 0
|
||||
throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
|
||||
logger.debug(
|
||||
f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s"
|
||||
)
|
||||
return result
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
Reference in New Issue
Block a user