# File: sglang/python/sglang/srt/mem_cache/memory_pool_host.py
# Timestamp: 2025-06-14 21:31:42 -07:00
# 381 lines, 13 KiB, Python

import abc
import logging
import threading
from enum import IntEnum
from functools import wraps
import psutil
import torch
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
from sglang.srt.utils import debug_timing
logger = logging.getLogger(__name__)
class MemoryStateInt(IntEnum):
    """Integer codes for the per-slot state kept in ``HostKVCache.mem_state``."""

    # Written by HostKVCache.free() (and the initial value after clear()).
    IDLE = 0
    # Written by HostKVCache.alloc().
    RESERVED = 1
    # Written by protect_write() / protect_load() while I/O is in flight.
    PROTECTED = 2
    # Written by complete_io() and update_synced().
    SYNCED = 3
    # Written by update_backup(); required by protect_load().
    BACKUP = 4
def synchronized(debug_only=False):
    """Build a method decorator that runs the method while holding ``self.lock``.

    The decorated method's instance must expose ``lock`` (a context manager)
    and, when ``debug_only`` is true, a boolean ``debug`` attribute.

    Args:
        debug_only: When true, the wrapped method only runs if ``self.debug``
            is truthy; otherwise the call is skipped and ``True`` is returned.

    Returns:
        A decorator that wraps a method with the locking behavior above.
    """

    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if debug_only and not self.debug:
                # Debug-only bookkeeping is skipped entirely; returning True
                # keeps boolean state checks (is_reserved, ...) passing.
                return True
            # The original returned before reaching this ``with`` statement,
            # so the lock was never taken. HostKVCache uses an RLock, so
            # decorated methods may call each other on the same thread.
            with self.lock:
                return func(self, *args, **kwargs)

        return wrapper

    return _decorator
class HostKVCache(abc.ABC):
    """Abstract host-side KV pool paired with a device ``KVCache``.

    The constructor sizes the pool, checks that enough host memory is
    available, asks the subclass to allocate ``kv_buffer`` and resets the
    free-slot list. Subclasses provide the buffer layout and the data
    movement primitives.

    Per-slot state tracking (``mem_state``) is only maintained when the
    module logger was enabled for DEBUG at construction time (``self.debug``).
    """

    def __init__(
        self,
        device_pool: KVCache,
        host_to_device_ratio: float,
        host_size: int,
        pin_memory: bool,
        device: str,
        page_size: int,
    ):
        """Size and allocate the host pool.

        Args:
            device_pool: Device cache this pool mirrors; its ``store_dtype``,
                ``size``, ``start_layer`` and ``end_layer`` are read here.
            host_to_device_ratio: Host/device token ratio, used only when
                ``host_size`` is not positive.
            host_size: When positive, pool capacity in units of 1e9 bytes.
            pin_memory: Stored for the subclass buffer allocation.
            device: Device string for the host buffer and ``mem_state``.
            page_size: The token count is rounded down to a multiple of this.

        Raises:
            AssertionError: If the host pool is not larger than the device pool.
            ValueError: If the request exceeds available host memory minus
                a 10 GiB reserve.
        """
        self.device_pool = device_pool
        self.dtype = device_pool.store_dtype
        self.pin_memory = pin_memory
        self.device = device
        self.page_size = page_size
        # Subclass hook; it may also cache layout attributes on ``self``.
        self.size_per_token = self.get_size_per_token()

        unaligned_size = (
            int(host_size * 1e9 // self.size_per_token)
            if host_size > 0
            else int(device_pool.size * host_to_device_ratio)
        )
        # Align the host memory pool size to the page size
        self.size = unaligned_size - (unaligned_size % self.page_size)
        self.start_layer = device_pool.start_layer
        self.end_layer = device_pool.end_layer

        assert (
            self.size > device_pool.size
        ), "The host memory should be larger than the device memory with the current protocol"

        # Verify there is enough available host memory.
        host_mem = psutil.virtual_memory()
        requested_bytes = self.size * self.size_per_token
        # preserve at least 10GB for other usage
        reserved_bytes = 10 * (1024**3)
        if requested_bytes > host_mem.available - reserved_bytes:
            raise ValueError(
                f"Not enough host memory available. Requesting "
                f"{requested_bytes / 1e9:.2f} GB but only have "
                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
                f"size of the hierarchical cache."
            )
        logger.info(
            f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
        )

        self.kv_buffer = self.init_kv_buffer()

        # A lock for synchronized operations on memory allocation and state transitions.
        self.lock = threading.RLock()
        self.debug = logger.isEnabledFor(logging.DEBUG)
        self.clear()

    @abc.abstractmethod
    def get_size_per_token(self):
        """Return the number of bytes one token occupies across all layers."""
        raise NotImplementedError()

    @abc.abstractmethod
    def init_kv_buffer(self):
        """Allocate and return the host buffer stored as ``self.kv_buffer``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def transfer(self, indices, flat_data):
        """Move ``flat_data`` to the host device and store it at ``indices``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_flat_data(self, indices):
        """Return the host data at ``indices`` for all layers."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_flat_data_by_layer(self, indices, layer_id):
        """Return the host data at ``indices`` for the layer ``layer_id``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def assign_flat_data(self, indices, flat_data):
        """Store ``flat_data`` at ``indices`` for all layers."""
        raise NotImplementedError()

    @synchronized()
    def clear(self):
        """Reset every slot state to 0 and mark the whole pool as free."""
        # Initialize memory states and tracking structures.
        self.mem_state = torch.zeros(
            (self.size,), dtype=torch.uint8, device=self.device
        )
        self.free_slots = torch.arange(self.size, dtype=torch.int64)

    def available_size(self):
        """Return the number of slots currently on the free list."""
        return len(self.free_slots)

    @synchronized()
    def alloc(self, need_size: int) -> torch.Tensor:
        """Take ``need_size`` slots from the front of the free list.

        Returns:
            The allocated slot indices, or ``None`` when fewer than
            ``need_size`` slots are free.
        """
        if need_size > self.available_size():
            return None

        free_list = self.free_slots
        granted = free_list[:need_size]
        self.free_slots = free_list[need_size:]

        if self.debug:
            self.mem_state[granted] = MemoryStateInt.RESERVED

        return granted

    @synchronized()
    def free(self, indices: torch.Tensor) -> int:
        """Append ``indices`` to the free list and return how many were freed."""
        self.free_slots = torch.cat((self.free_slots, indices))
        if self.debug:
            self.mem_state[indices] = MemoryStateInt.IDLE
        return len(indices)

    @synchronized(debug_only=True)
    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
        """Return the single state shared by all slots in ``indices``.

        Asserts that ``indices`` is non-empty and that every slot holds the
        same state. Skipped (the wrapper returns ``True``) when not debugging.
        """
        assert len(indices) > 0, "The indices should not be empty"
        states = self.mem_state[indices]
        first_state = states[0]
        assert (
            states == first_state
        ).all(), "The memory slots should have the same state {}".format(states)
        return MemoryStateInt(first_state.item())

    @synchronized(debug_only=True)
    def is_reserved(self, indices: torch.Tensor) -> bool:
        """Return whether the slots are in the RESERVED state."""
        return self.get_state(indices) == MemoryStateInt.RESERVED

    @synchronized(debug_only=True)
    def is_protected(self, indices: torch.Tensor) -> bool:
        """Return whether the slots are in the PROTECTED state."""
        return self.get_state(indices) == MemoryStateInt.PROTECTED

    @synchronized(debug_only=True)
    def is_synced(self, indices: torch.Tensor) -> bool:
        """Return whether the slots are in the SYNCED state."""
        return self.get_state(indices) == MemoryStateInt.SYNCED

    @synchronized(debug_only=True)
    def is_backup(self, indices: torch.Tensor) -> bool:
        """Return whether the slots are in the BACKUP state."""
        return self.get_state(indices) == MemoryStateInt.BACKUP

    @synchronized(debug_only=True)
    def update_backup(self, indices: torch.Tensor):
        """Move SYNCED slots to BACKUP; raise ``ValueError`` otherwise."""
        if not self.is_synced(indices):
            raise ValueError(
                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
                f"Current state: {self.get_state(indices)}"
            )
        self.mem_state[indices] = MemoryStateInt.BACKUP

    @synchronized(debug_only=True)
    def update_synced(self, indices: torch.Tensor):
        """Mark the slots SYNCED without checking their current state."""
        self.mem_state[indices] = MemoryStateInt.SYNCED

    @synchronized(debug_only=True)
    def protect_write(self, indices: torch.Tensor):
        """Move RESERVED slots to PROTECTED; raise ``ValueError`` otherwise."""
        if not self.is_reserved(indices):
            raise ValueError(
                f"The host memory slots should be RESERVED before write operations. "
                f"Current state: {self.get_state(indices)}"
            )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized(debug_only=True)
    def protect_load(self, indices: torch.Tensor):
        """Move BACKUP slots to PROTECTED; raise ``ValueError`` otherwise."""
        if not self.is_backup(indices):
            raise ValueError(
                f"The host memory slots should be in BACKUP state before load operations. "
                f"Current state: {self.get_state(indices)}"
            )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized(debug_only=True)
    def complete_io(self, indices: torch.Tensor):
        """Move PROTECTED slots to SYNCED; raise ``ValueError`` otherwise."""
        if not self.is_protected(indices):
            raise ValueError(
                f"The host memory slots should be PROTECTED during I/O operations. "
                f"Current state: {self.get_state(indices)}"
            )
        self.mem_state[indices] = MemoryStateInt.SYNCED
class MHATokenToKVPoolHost(HostKVCache):
    """Host pool for an ``MHATokenToKVPool`` with separate K and V planes.

    ``kv_buffer`` has shape ``(2, layer_num, size, head_num, head_dim)``;
    index 0 on the first axis is filled from the device ``k_buffer`` and
    index 1 from the device ``v_buffer``.
    """

    device_pool: MHATokenToKVPool

    def __init__(
        self,
        device_pool: MHATokenToKVPool,
        host_to_device_ratio: float,
        host_size: int,
        page_size: int,
        pin_memory: bool = True,
        device: str = "cpu",
    ):
        """Forward to ``HostKVCache.__init__`` (whose argument order differs)."""
        super().__init__(
            device_pool=device_pool,
            host_to_device_ratio=host_to_device_ratio,
            host_size=host_size,
            pin_memory=pin_memory,
            device=device,
            page_size=page_size,
        )

    def get_size_per_token(self):
        """Cache the head/layer geometry and return bytes per token (K plus V)."""
        pool = self.device_pool
        self.head_num = pool.head_num
        self.head_dim = pool.head_dim
        self.layer_num = pool.layer_num

        # The trailing factor of 2 accounts for the K and V planes.
        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2

    def init_kv_buffer(self):
        """Allocate the uninitialized host buffer for both planes."""
        shape = (2, self.layer_num, self.size, self.head_num, self.head_dim)
        return torch.empty(
            shape,
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

    @debug_timing
    def transfer(self, indices, flat_data):
        """Copy ``flat_data`` to the host device (blocking) and store it at ``indices``."""
        # backup prepared data from device to host
        host_copy = flat_data.to(device=self.device, non_blocking=False)
        self.kv_buffer[:, :, indices] = host_copy

    def get_flat_data(self, indices):
        """Return K and V for all layers at ``indices``."""
        return self.kv_buffer[:, :, indices]

    def get_flat_data_by_layer(self, indices, layer_id):
        """Return K and V at ``indices`` for ``layer_id`` (offset by ``start_layer``)."""
        return self.kv_buffer[:, layer_id - self.start_layer, indices]

    def assign_flat_data(self, indices, flat_data):
        """Store ``flat_data`` at ``indices`` for all layers."""
        self.kv_buffer[:, :, indices] = flat_data

    def write_page_all_layers(self, host_indices, device_indices, device_pool):
        """Copy whole pages from the device K/V buffers into the host buffer.

        Only every ``page_size``-th entry of each index tensor is read; each
        page is copied as ``page_size`` consecutive slots starting there.
        Copies are issued with ``non_blocking=True``.
        """
        page = self.page_size
        device_page_starts = device_indices[::page].cpu()
        for page_no, d_start in enumerate(device_page_starts):
            h_start = host_indices[page_no * page]
            host_span = slice(h_start, h_start + page)
            device_span = slice(d_start, d_start + page)
            for layer in range(self.layer_num):
                self.kv_buffer[0, layer, host_span].copy_(
                    device_pool.k_buffer[layer][device_span],
                    non_blocking=True,
                )
                self.kv_buffer[1, layer, host_span].copy_(
                    device_pool.v_buffer[layer][device_span],
                    non_blocking=True,
                )

    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
        """Copy whole pages of one layer from the host buffer to the device.

        ``layer_id`` is offset by ``start_layer`` to index both the host and
        the device buffers. Copies are issued with ``non_blocking=True``.
        """
        page = self.page_size
        device_page_starts = device_indices[::page].cpu()
        for page_no, d_start in enumerate(device_page_starts):
            h_start = host_indices[page_no * page]
            layer = layer_id - self.start_layer
            host_span = slice(h_start, h_start + page)
            device_span = slice(d_start, d_start + page)
            device_pool.k_buffer[layer][device_span].copy_(
                self.kv_buffer[0, layer, host_span],
                non_blocking=True,
            )
            device_pool.v_buffer[layer][device_span].copy_(
                self.kv_buffer[1, layer, host_span],
                non_blocking=True,
            )
class MLATokenToKVPoolHost(HostKVCache):
    """Host pool for an ``MLATokenToKVPool`` with one combined KV buffer.

    ``kv_buffer`` has shape
    ``(layer_num, size, 1, kv_lora_rank + qk_rope_head_dim)``.
    """

    device_pool: MLATokenToKVPool

    def __init__(
        self,
        device_pool: MLATokenToKVPool,
        host_to_device_ratio: float,
        host_size: int,
        page_size: int,
        pin_memory: bool = True,
        device: str = "cpu",
    ):
        """Forward to ``HostKVCache.__init__`` (whose argument order differs)."""
        super().__init__(
            device_pool=device_pool,
            host_to_device_ratio=host_to_device_ratio,
            host_size=host_size,
            pin_memory=pin_memory,
            device=device,
            page_size=page_size,
        )

    def get_size_per_token(self):
        """Cache the MLA geometry and return bytes per token across all layers."""
        pool = self.device_pool
        self.kv_lora_rank = pool.kv_lora_rank
        self.qk_rope_head_dim = pool.qk_rope_head_dim
        self.layer_num = pool.layer_num

        token_width = self.kv_lora_rank + self.qk_rope_head_dim
        return token_width * 1 * self.dtype.itemsize * self.layer_num

    def init_kv_buffer(self):
        """Allocate the uninitialized host buffer."""
        shape = (
            self.layer_num,
            self.size,
            1,
            self.kv_lora_rank + self.qk_rope_head_dim,
        )
        return torch.empty(
            shape,
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

    @debug_timing
    def transfer(self, indices, flat_data):
        """Copy ``flat_data`` to the host device (blocking) and store it at ``indices``."""
        # backup prepared data from device to host
        host_copy = flat_data.to(device=self.device, non_blocking=False)
        self.kv_buffer[:, indices] = host_copy

    def get_flat_data(self, indices):
        """Return the data at ``indices`` for all layers."""
        return self.kv_buffer[:, indices]

    def get_flat_data_by_layer(self, indices, layer_id):
        """Return the data at ``indices`` for ``layer_id`` (offset by ``start_layer``)."""
        return self.kv_buffer[layer_id - self.start_layer, indices]

    def assign_flat_data(self, indices, flat_data):
        """Store ``flat_data`` at ``indices`` for all layers."""
        self.kv_buffer[:, indices] = flat_data

    def write_page_all_layers(self, host_indices, device_indices, device_pool):
        """Copy whole pages from the device KV buffer into the host buffer.

        Only every ``page_size``-th entry of each index tensor is read; each
        page is copied as ``page_size`` consecutive slots starting there.
        Copies are issued with ``non_blocking=True``.
        """
        page = self.page_size
        device_page_starts = device_indices[::page].cpu()
        for page_no, d_start in enumerate(device_page_starts):
            h_start = host_indices[page_no * page]
            host_span = slice(h_start, h_start + page)
            device_span = slice(d_start, d_start + page)
            for layer in range(self.layer_num):
                self.kv_buffer[layer, host_span].copy_(
                    device_pool.kv_buffer[layer][device_span],
                    non_blocking=True,
                )

    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
        """Copy whole pages of one layer from the host buffer to the device.

        ``layer_id`` is offset by ``start_layer`` to index both the host and
        the device buffers. Copies are issued with ``non_blocking=True``.
        """
        page = self.page_size
        device_page_starts = device_indices[::page].cpu()
        for page_no, d_start in enumerate(device_page_starts):
            h_start = host_indices[page_no * page]
            layer = layer_id - self.start_layer
            host_span = slice(h_start, h_start + page)
            device_span = slice(d_start, d_start + page)
            device_pool.kv_buffer[layer][device_span].copy_(
                self.kv_buffer[layer, host_span],
                non_blocking=True,
            )