""" Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ """ Memory pool. SGLang has two levels of memory pool. ReqToTokenPool maps a a request to its token locations. BaseTokenToKVPool maps a token location to its KV cache data. """ import logging import threading from enum import IntEnum from functools import wraps from typing import List, Tuple, Union import numpy as np import psutil import torch from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.utils import debug_timing, get_compiler_backend logger = logging.getLogger(__name__) GB = 1024 * 1024 * 1024 class ReqToTokenPool: """A memory pool that maps a request to its token locations.""" def __init__(self, size: int, max_context_len: int, device: str, use_records: bool): self.size = size self.max_context_len = max_context_len self.device = device self.req_to_token = torch.zeros( (size, max_context_len), dtype=torch.int32, device=device ) self.free_slots = list(range(size)) self.write_records = [] self.use_records = use_records if self.use_records: self.write = self.write_with_records else: self.write = self.write_without_records def write(self, indices, values): # Keep the signature for type checking. It will be assigned during runtime. raise NotImplementedError() def available_size(self): return len(self.free_slots) def alloc(self, need_size: int) -> List[int]: if need_size > len(self.free_slots): return None select_index = self.free_slots[:need_size] self.free_slots = self.free_slots[need_size:] return select_index def free(self, free_index: Union[int, List[int]]): if isinstance(free_index, (int,)): self.free_slots.append(free_index) else: self.free_slots.extend(free_index) def clear(self): self.free_slots = list(range(self.size)) self.write_records = [] def write_without_records(self, indices, values): self.req_to_token[indices] = values def write_with_records(self, indices, values): self.req_to_token[indices] = values self.write_records.append((indices, values)) def get_write_records(self): ret = self.write_records self.write_records = [] return ret def apply_write_records(self, write_records: List[Tuple]): for indices, values in write_records: self.req_to_token[indices] = values class BaseTokenToKVPool: """A memory pool that maps a token location to its kv cache data.""" def __init__( self, size: int, dtype: torch.dtype, device: str, ): self.size = size self.dtype = dtype if dtype in (torch.float8_e5m2, torch.float8_e4m3fn): # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2 self.store_dtype = torch.uint8 else: self.store_dtype = dtype self.device = device self.free_slots = None self.is_not_in_free_group = True self.free_group = [] self.clear() def available_size(self): return len(self.free_slots) def alloc(self, need_size: int): if need_size > len(self.free_slots): return None select_index = self.free_slots[:need_size] self.free_slots = self.free_slots[need_size:] return select_index.to(self.device, non_blocking=True) def free(self, free_index: torch.Tensor): if 


class BaseTokenToKVPool:
    """A memory pool that maps a token location to its kv cache data."""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        device: str,
    ):
        self.size = size
        self.dtype = dtype
        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
            self.store_dtype = torch.uint8
        else:
            self.store_dtype = dtype
        self.device = device

        self.free_slots = None
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()

    def available_size(self):
        return len(self.free_slots)

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        return select_index.to(self.device, non_blocking=True)

    def free(self, free_index: torch.Tensor):
        if free_index.numel() == 0:
            return

        if self.is_not_in_free_group:
            self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
        else:
            self.free_group.append(free_index)

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.concat(self.free_group))

    def clear(self):
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
        self.is_not_in_free_group = True
        self.free_group = []

    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError()

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ) -> None:
        raise NotImplementedError()
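

# Illustrative sketch of grouped freeing on a BaseTokenToKVPool subclass (the pool and
# `finished_request_locs` are hypothetical names). free_group_begin()/free_group_end()
# defer per-request frees and return them with a single torch.concat, which helps when
# many requests finish in the same scheduler step:
#
#   token_pool.free_group_begin()
#   for locs in finished_request_locs:   # each `locs` is an int32 tensor of slot indices
#       token_pool.free(locs)            # deferred: appended to free_group, not concatenated yet
#   token_pool.free_group_end()          # one concat returns all slots to free_slots
#
# Note that alloc() hands out indices starting from 1: slot 0 is reserved as the dummy
# write target for padded tokens.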


class MHATokenToKVPool(BaseTokenToKVPool):

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
    ):
        super().__init__(size, dtype, device)

        self.head_num = head_num
        self.head_dim = head_dim
        self.layer_num = layer_num
        self._create_buffers()

        k_size, v_size = self.get_kv_size_bytes()
        logger.info(
            f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
        )

    def _create_buffers(self):
        # [size, head_num, head_dim] for each layer
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.k_buffer = [
            torch.empty(
                (self.size + 1, self.head_num, self.head_dim),
                dtype=self.store_dtype,
                device=self.device,
            )
            for _ in range(self.layer_num)
        ]
        self.v_buffer = [
            torch.empty(
                (self.size + 1, self.head_num, self.head_dim),
                dtype=self.store_dtype,
                device=self.device,
            )
            for _ in range(self.layer_num)
        ]

    def _clear_buffers(self):
        del self.k_buffer
        del self.v_buffer

    def get_kv_size_bytes(self):
        assert hasattr(self, "k_buffer")
        assert hasattr(self, "v_buffer")
        k_size_bytes = 0
        for k_cache in self.k_buffer:
            k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
        v_size_bytes = 0
        for v_cache in self.v_buffer:
            v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
        return k_size_bytes, v_size_bytes

    # Todo: different memory layout
    def get_flat_data(self, indices):
        # prepare a large chunk of contiguous data for efficient transfer
        flatten = torch.stack(
            [
                torch.stack(
                    [self.k_buffer[i][indices] for i in range(self.layer_num)]
                ),
                torch.stack(
                    [self.v_buffer[i][indices] for i in range(self.layer_num)]
                ),
            ]
        )
        return flatten

    @debug_timing
    def transfer(self, indices, flat_data):
        # transfer prepared data from host to device
        flat_data = flat_data.to(device=self.device, non_blocking=False)
        k_data, v_data = flat_data[0], flat_data[1]
        for i in range(self.layer_num):
            self.k_buffer[i][indices] = k_data[i]
            self.v_buffer[i][indices] = v_data[i]

    def get_key_buffer(self, layer_id: int):
        if self.store_dtype != self.dtype:
            return self.k_buffer[layer_id].view(self.dtype)
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        if self.store_dtype != self.dtype:
            return self.v_buffer[layer_id].view(self.dtype)
        return self.v_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
    ):
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            cache_k = (cache_k / k_scale).to(self.dtype)
            cache_v = (cache_v / v_scale).to(self.dtype)
        if self.store_dtype != self.dtype:
            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
        else:
            self.k_buffer[layer_id][loc] = cache_k
            self.v_buffer[layer_id][loc] = cache_v
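

# Worked size example for MHATokenToKVPool (a hypothetical configuration, not a default
# from this file): with size = 100_000 tokens, head_num = 8, head_dim = 128,
# layer_num = 32, and fp16 storage (itemsize 2), each of k_buffer and v_buffer holds
#
#   (100_000 + 1) * 8 * 128 * 32 * 2 bytes ~= 6.55e9 bytes,
#
# which get_kv_size_bytes() returns per side and __init__ logs as about 6.10 GB
# (GB here is 1024 ** 3).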


# This compiled version is slower in the unit test
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
@torch.compile(dynamic=True, backend=get_compiler_backend())
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
    dst_1[loc] = src_1.to(dtype).view(store_dtype)
    dst_2[loc] = src_2.to(dtype).view(store_dtype)


class MLATokenToKVPool(BaseTokenToKVPool):

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
    ):
        super().__init__(size, dtype, device)

        self.kv_lora_rank = kv_lora_rank
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.kv_buffer = [
            torch.empty(
                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
                dtype=self.store_dtype,
                device=device,
            )
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
        if self.store_dtype != self.dtype:
            return self.kv_buffer[layer_id].view(self.dtype)
        return self.kv_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        if self.store_dtype != self.dtype:
            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]

    def get_kv_buffer(self, layer_id: int):
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if self.store_dtype != self.dtype:
            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
        else:
            self.kv_buffer[layer_id][loc] = cache_k


class DoubleSparseTokenToKVPool(BaseTokenToKVPool):

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        heavy_channel_num: int,
    ):
        super().__init__(size, dtype, device)

        # [size, head_num, head_dim] for each layer
        self.k_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]
        self.v_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]

        # [size, head_num, heavy_channel_num] for each layer
        self.label_buffer = [
            torch.empty(
                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
            )
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        return self.v_buffer[layer_id]

    def get_label_buffer(self, layer_id: int):
        return self.label_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        return self.k_buffer[layer_id], self.v_buffer[layer_id]

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_label: torch.Tensor,
    ):
        # NOTE(Andy): ignore the dtype check
        layer_id = layer.layer_id
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
        self.label_buffer[layer_id][loc] = cache_label


class MemoryStateInt(IntEnum):
    IDLE = 0
    RESERVED = 1
    PROTECTED = 2
    SYNCED = 3
    BACKUP = 4


def synchronized(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            return func(self, *args, **kwargs)

    return wrapper
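

# Sketch of the intended slot life cycle in the host pool below, read off the assertions
# in MLATokenToKVPoolHost (a summary, not an additional API):
#
#   IDLE --alloc()--> RESERVED --protect_write()--> PROTECTED --complete_io()--> SYNCED
#   SYNCED --update_backup()--> BACKUP --protect_load()--> PROTECTED --complete_io()--> SYNCED
#   any state --free()--> IDLE
#
# Every transition helper is wrapped with @synchronized, so the state check and update
# happen while holding self.lock (an RLock created in __init__).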


class MLATokenToKVPoolHost:

    def __init__(
        self,
        device_pool: MHATokenToKVPool,
        host_to_device_ratio: float = 2.0,
        pin_memory: bool = False,  # no need to use pin memory with the double buffering
        device: str = "cpu",
    ):
        assert (
            host_to_device_ratio >= 1
        ), "The host memory should be larger than the device memory with the current protocol"
        # TODO: other ways of configuring the size

        self.device_pool = device_pool
        self.host_to_device_ratio = host_to_device_ratio
        self.pin_memory = pin_memory
        self.device = device

        self.size = int(device_pool.size * host_to_device_ratio)
        self.dtype = device_pool.store_dtype
        self.head_num = device_pool.head_num
        self.head_dim = device_pool.head_dim
        self.layer_num = device_pool.layer_num
        self.size_per_token = (
            self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
        )

        # Verify there is enough available host memory.
        host_mem = psutil.virtual_memory()
        requested_bytes = self.size * self.size_per_token
        # preserve at least 10GB for other usage
        ten_gb = 10 * (1024**3)
        if requested_bytes > host_mem.available - ten_gb:
            raise ValueError(
                f"Not enough host memory available. Requesting "
                f"{requested_bytes / 1e9:.2f} GB but only have "
                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
                f"size of the hierarchical cache."
            )
        else:
            logger.info(
                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
            )

        self.kv_buffer = torch.empty(
            (2, self.layer_num, self.size, self.head_num, self.head_dim),
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

        # Initialize memory states and tracking structures.
        self.mem_state = torch.zeros(
            (self.size,), dtype=torch.uint8, device=self.device
        )
        self.free_slots = torch.arange(self.size, dtype=torch.int32)
        self.can_use_mem_size = self.size

        # A lock for synchronized operations on memory allocation and state transitions.
        self.lock = threading.RLock()

    def get_flat_data(self, indices):
        return self.kv_buffer[:, :, indices]

    @debug_timing
    def transfer(self, indices, flat_data):
        # backup prepared data from device to host
        self.kv_buffer[:, :, indices] = flat_data.to(
            device=self.device, non_blocking=False
        )

    @synchronized
    def clear(self):
        self.mem_state.fill_(0)
        self.can_use_mem_size = self.size
        self.free_slots = torch.arange(self.size, dtype=torch.int32)

    @synchronized
    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
        assert len(indices) > 0, "The indices should not be empty"
        states = self.mem_state[indices]
        assert (
            states == states[0]
        ).all(), "The memory slots should have the same state {}".format(states)
        return MemoryStateInt(states[0].item())

    @synchronized
    def alloc(self, need_size: int) -> torch.Tensor:
        if need_size > self.can_use_mem_size:
            return None

        # TODO: de-fragmentation
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        self.mem_state[select_index] = MemoryStateInt.RESERVED
        self.can_use_mem_size -= need_size

        return select_index

    @synchronized
    def is_reserved(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.RESERVED

    @synchronized
    def is_protected(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.PROTECTED

    @synchronized
    def is_synced(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.SYNCED

    @synchronized
    def is_backup(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.BACKUP

    @synchronized
    def update_backup(self, indices: torch.Tensor):
        assert self.is_synced(indices), (
            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.BACKUP

    @synchronized
    def update_synced(self, indices: torch.Tensor):
        self.mem_state[indices] = MemoryStateInt.SYNCED

    @synchronized
    def protect_write(self, indices: torch.Tensor):
        assert self.is_reserved(indices), (
            f"The host memory slots should be RESERVED before write operations. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized
    def protect_load(self, indices: torch.Tensor):
        assert self.is_backup(indices), (
            f"The host memory slots should be in BACKUP state before load operations. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized
    def complete_io(self, indices: torch.Tensor):
        assert self.is_protected(indices), (
            f"The host memory slots should be PROTECTED during I/O operations. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.SYNCED

    def available_size(self):
        return len(self.free_slots)

    @synchronized
    def free(self, indices: torch.Tensor) -> int:
        self.mem_state[indices] = MemoryStateInt.IDLE
        self.free_slots = torch.concat([self.free_slots, indices])
        self.can_use_mem_size += len(indices)
        return len(indices)
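

# Illustrative host-offload sketch (made-up names and sizes; assumes an MHATokenToKVPool
# already exists as `device_pool` and that the tokens at `device_locs` should be backed
# up to host memory). It only shows how the methods above compose; callers such as the
# hierarchical cache may orchestrate this differently:
#
#   host_pool = MLATokenToKVPoolHost(device_pool, host_to_device_ratio=2.0)
#   host_locs = host_pool.alloc(len(device_locs))                    # IDLE -> RESERVED
#   host_pool.protect_write(host_locs)                               # RESERVED -> PROTECTED
#   host_pool.transfer(host_locs, device_pool.get_flat_data(device_locs))
#   host_pool.complete_io(host_locs)                                 # PROTECTED -> SYNCED
#   host_pool.update_backup(host_locs)                               # SYNCED -> BACKUP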