diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 06b9f6268..b67f085b2 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -22,12 +22,16 @@
 BaseTokenToKVPool maps a token location to its KV cache data.
 """
 
 import logging
+import threading
+from enum import IntEnum
+from functools import wraps
 from typing import List, Tuple, Union
 
+import psutil
 import torch
 
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import debug_timing, get_compiler_backend
 
 logger = logging.getLogger(__name__)
@@ -213,6 +217,26 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         del self.k_buffer
         del self.v_buffer
 
+    # Todo: different memory layout
+    def get_flat_data(self, indices):
+        # prepare a large chunk of contiguous data for efficient transfer
+        flatten = torch.stack(
+            [
+                torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
+                torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
+            ]
+        )
+        return flatten
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        k_data, v_data = flat_data[0], flat_data[1]
+        for i in range(self.layer_num):
+            self.k_buffer[i][indices] = k_data[i]
+            self.v_buffer[i][indices] = v_data[i]
+
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
@@ -361,3 +385,186 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         self.k_buffer[layer_id][loc] = cache_k
         self.v_buffer[layer_id][loc] = cache_v
         self.label_buffer[layer_id][loc] = cache_label
+
+
+class MemoryStateInt(IntEnum):
+    IDLE = 0
+    RESERVED = 1
+    PROTECTED = 2
+    SYNCED = 3
+    BACKUP = 4
+
+
+def synchronized(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.lock:
+            return func(self, *args, **kwargs)
+
+    return wrapper
+
+
+class MLATokenToKVPoolHost:
+    """Host (CPU) staging pool that backs up device KV cache entries for hierarchical caching."""
+
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float = 2.0,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        assert (
+            host_to_device_ratio >= 1
+        ), "The host memory should be larger than the device memory with the current protocol"
+        # todo, other ways of configuring the size
+
+        self.device_pool = device_pool
+        self.host_to_device_ratio = host_to_device_ratio
+        self.pin_memory = pin_memory
+        self.device = device
+
+        self.size = int(device_pool.size * host_to_device_ratio)
+        self.dtype = device_pool.store_dtype
+        self.head_num = device_pool.head_num
+        self.head_dim = device_pool.head_dim
+        self.layer_num = device_pool.layer_num
+        self.size_per_token = (
+            self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+        )
+
+        # Verify there is enough available host memory.
+        host_mem = psutil.virtual_memory()
+        requested_bytes = self.size * self.size_per_token
+        # preserve at least 10GB for other usage
+        ten_gb = 10 * (1024**3)
+        if requested_bytes > host_mem.available - ten_gb:
+            raise ValueError(
+                f"Not enough host memory available. Requesting "
+                f"{requested_bytes / 1e9:.2f} GB but only have "
+                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
+                f"size of the hierarchical cache."
+            )
+        else:
+            logger.info(
+                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
+            )
+
+        self.kv_buffer = torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
+        self.free_slots = torch.arange(self.size, dtype=torch.int32)
+        self.can_use_mem_size = self.size
+
+        # A lock for synchronized operations on memory allocation and state transitions.
+        self.lock = threading.RLock()
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    @synchronized
+    def clear(self):
+        self.mem_state.fill_(0)
+        self.can_use_mem_size = self.size
+        self.free_slots = torch.arange(self.size, dtype=torch.int32)
+
+    @synchronized
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized
+    def alloc(self, need_size: int) -> torch.Tensor:
+        if need_size > self.can_use_mem_size:
+            return None
+
+        # todo: de-fragementation
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        self.mem_state[select_index] = MemoryStateInt.RESERVED
+        self.can_use_mem_size -= need_size
+
+        return select_index
+
+    @synchronized
+    def is_reserved(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.RESERVED
+
+    @synchronized
+    def is_protected(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.PROTECTED
+
+    @synchronized
+    def is_synced(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.SYNCED
+
+    @synchronized
+    def is_backup(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.BACKUP
+
+    @synchronized
+    def update_backup(self, indices: torch.Tensor):
+        assert self.is_synced(indices), (
+            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
+    @synchronized
+    def update_synced(self, indices: torch.Tensor):
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    @synchronized
+    def protect_write(self, indices: torch.Tensor):
+        assert self.is_reserved(indices), (
+            f"The host memory slots should be RESERVED before write operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized
+    def protect_load(self, indices: torch.Tensor):
+        assert self.is_backup(indices), (
+            f"The host memory slots should be in BACKUP state before load operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized
+    def complete_io(self, indices: torch.Tensor):
+        assert self.is_protected(indices), (
+            f"The host memory slots should be PROTECTED during I/O operations. "
+            f"Current state: {self.get_state(indices)}"
+        )
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    @synchronized
+    def available_size(self):
+        return len(self.free_slots)
+
+    @synchronized
+    def free(self, indices: torch.Tensor) -> int:
+        self.mem_state[indices] = MemoryStateInt.IDLE
+        self.free_slots = torch.concat([self.free_slots, indices])
+        self.can_use_mem_size += len(indices)
+        return len(indices)
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 6f3144ca6..44a5e41a4 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1349,3 +1349,27 @@ class MultiprocessingSerializer:
     @staticmethod
     def deserialize(data):
         return ForkingPickler.loads(data)
+
+
+def debug_timing(func):
+    # todo: replace with a more organized instrumentation
+    def wrapper(*args, **kwargs):
+        if logger.isEnabledFor(logging.DEBUG) and torch.cuda.is_available():
+            tic = torch.cuda.Event(enable_timing=True)
+            toc = torch.cuda.Event(enable_timing=True)
+            tic.record()
+            result = func(*args, **kwargs)
+            toc.record()
+            torch.cuda.synchronize()  # Ensure all CUDA operations are complete
+            elapsed = tic.elapsed_time(toc)
+            indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
+            num_tokens = len(indices) if indices is not None else 0
+            throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
+            logger.debug(
+                f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s"
+            )
+            return result
+        else:
+            return func(*args, **kwargs)
+
+    return wrapper