diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 12d924998..9f7f48cda 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -22,7 +22,8 @@ from typing import List, Optional import torch -from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator +from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator +from sglang.srt.mem_cache.memory_pool_host import HostKVCache logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 4bec901aa..03a4417cf 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -9,12 +9,14 @@ import torch from sglang.srt.managers.cache_controller import HiCacheController from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, - MHATokenToKVPoolHost, MLATokenToKVPool, - MLATokenToKVPoolHost, ReqToTokenPool, TokenToKVPoolAllocator, ) +from sglang.srt.mem_cache.memory_pool_host import ( + MHATokenToKVPoolHost, + MLATokenToKVPoolHost, +) from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 8ae7ba6b1..7e08007ed 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -26,24 +26,15 @@ KVCache actually holds the physical kv cache. import abc import logging -import threading -from enum import IntEnum -from functools import wraps from typing import List, Optional, Tuple, Union import numpy as np -import psutil import torch import triton import triton.language as tl from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.utils import ( - debug_timing, - get_compiler_backend, - is_cuda, - next_power_of_2, -) +from sglang.srt.utils import debug_timing, get_compiler_backend, is_cuda logger = logging.getLogger(__name__) @@ -772,370 +763,3 @@ class DoubleSparseTokenToKVPool(KVCache): def transfer_per_layer(self, indices, flat_data, layer_id): pass - - -class MemoryStateInt(IntEnum): - IDLE = 0 - RESERVED = 1 - PROTECTED = 2 - SYNCED = 3 - BACKUP = 4 - - -def synchronized(debug_only=False): - def _decorator(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - if (not debug_only) or self.debug: - return func(self, *args, **kwargs) - with self.lock: - return func(self, *args, **kwargs) - else: - return True - - return wrapper - - return _decorator - - -class HostKVCache(abc.ABC): - - def __init__( - self, - device_pool: KVCache, - host_to_device_ratio: float, - host_size: int, - pin_memory: bool, - device: str, - page_size: int, - ): - self.device_pool = device_pool - self.dtype = device_pool.store_dtype - self.pin_memory = pin_memory - self.device = device - self.page_size = page_size - self.size_per_token = self.get_size_per_token() - if host_size > 0: - self.size = int(host_size * 1e9 // self.size_per_token) - else: - self.size = int(device_pool.size * host_to_device_ratio) - # Align the host memory pool size to the page size - self.size = self.size - (self.size % self.page_size) - self.start_layer = device_pool.start_layer - self.end_layer = device_pool.end_layer - - assert ( - self.size > device_pool.size - ), "The host memory should be larger than the device memory with the current protocol" - - # Verify there is enough available host memory. - host_mem = psutil.virtual_memory() - requested_bytes = self.size * self.size_per_token - # preserve at least 10GB for other usage - ten_gb = 10 * (1024**3) - if requested_bytes > host_mem.available - ten_gb: - raise ValueError( - f"Not enough host memory available. Requesting " - f"{requested_bytes / 1e9:.2f} GB but only have " - f"{host_mem.available / 1e9:.2f} GB free. Please reduce the " - f"size of the hierarchical cache." - ) - else: - logger.info( - f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache." - ) - - self.kv_buffer = self.init_kv_buffer() - - # A lock for synchronized operations on memory allocation and state transitions. - self.lock = threading.RLock() - self.debug = logger.isEnabledFor(logging.DEBUG) - self.clear() - - @abc.abstractmethod - def get_size_per_token(self): - raise NotImplementedError() - - @abc.abstractmethod - def init_kv_buffer(self): - raise NotImplementedError() - - @abc.abstractmethod - def transfer(self, indices, flat_data): - raise NotImplementedError() - - @abc.abstractmethod - def get_flat_data(self, indices): - raise NotImplementedError() - - @abc.abstractmethod - def get_flat_data_by_layer(self, indices, layer_id): - raise NotImplementedError() - - @abc.abstractmethod - def assign_flat_data(self, indices, flat_data): - raise NotImplementedError() - - @synchronized() - def clear(self): - # Initialize memory states and tracking structures. - self.mem_state = torch.zeros( - (self.size,), dtype=torch.uint8, device=self.device - ) - self.free_slots = torch.arange(self.size, dtype=torch.int64) - - def available_size(self): - return len(self.free_slots) - - @synchronized() - def alloc(self, need_size: int) -> torch.Tensor: - if need_size > self.available_size(): - return None - - select_index = self.free_slots[:need_size] - self.free_slots = self.free_slots[need_size:] - - if self.debug: - self.mem_state[select_index] = MemoryStateInt.RESERVED - - return select_index - - @synchronized() - def free(self, indices: torch.Tensor) -> int: - self.free_slots = torch.cat([self.free_slots, indices]) - if self.debug: - self.mem_state[indices] = MemoryStateInt.IDLE - return len(indices) - - @synchronized(debug_only=True) - def get_state(self, indices: torch.Tensor) -> MemoryStateInt: - assert len(indices) > 0, "The indices should not be empty" - states = self.mem_state[indices] - assert ( - states == states[0] - ).all(), "The memory slots should have the same state {}".format(states) - return MemoryStateInt(states[0].item()) - - @synchronized(debug_only=True) - def is_reserved(self, indices: torch.Tensor) -> bool: - return self.get_state(indices) == MemoryStateInt.RESERVED - - @synchronized(debug_only=True) - def is_protected(self, indices: torch.Tensor) -> bool: - return self.get_state(indices) == MemoryStateInt.PROTECTED - - @synchronized(debug_only=True) - def is_synced(self, indices: torch.Tensor) -> bool: - return self.get_state(indices) == MemoryStateInt.SYNCED - - @synchronized(debug_only=True) - def is_backup(self, indices: torch.Tensor) -> bool: - return self.get_state(indices) == MemoryStateInt.BACKUP - - @synchronized(debug_only=True) - def update_backup(self, indices: torch.Tensor): - if not self.is_synced(indices): - raise ValueError( - f"The host memory slots should be in SYNCED state before turning into BACKUP. " - f"Current state: {self.get_state(indices)}" - ) - self.mem_state[indices] = MemoryStateInt.BACKUP - - @synchronized(debug_only=True) - def update_synced(self, indices: torch.Tensor): - self.mem_state[indices] = MemoryStateInt.SYNCED - - @synchronized(debug_only=True) - def protect_write(self, indices: torch.Tensor): - if not self.is_reserved(indices): - raise ValueError( - f"The host memory slots should be RESERVED before write operations. " - f"Current state: {self.get_state(indices)}" - ) - self.mem_state[indices] = MemoryStateInt.PROTECTED - - @synchronized(debug_only=True) - def protect_load(self, indices: torch.Tensor): - if not self.is_backup(indices): - raise ValueError( - f"The host memory slots should be in BACKUP state before load operations. " - f"Current state: {self.get_state(indices)}" - ) - self.mem_state[indices] = MemoryStateInt.PROTECTED - - @synchronized(debug_only=True) - def complete_io(self, indices: torch.Tensor): - if not self.is_protected(indices): - raise ValueError( - f"The host memory slots should be PROTECTED during I/O operations. " - f"Current state: {self.get_state(indices)}" - ) - self.mem_state[indices] = MemoryStateInt.SYNCED - - -class MHATokenToKVPoolHost(HostKVCache): - device_pool: MHATokenToKVPool - - def __init__( - self, - device_pool: MHATokenToKVPool, - host_to_device_ratio: float, - host_size: int, - page_size: int, - pin_memory: bool = True, - device: str = "cpu", - ): - super().__init__( - device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size - ) - - def get_size_per_token(self): - self.head_num = self.device_pool.head_num - self.head_dim = self.device_pool.head_dim - self.layer_num = self.device_pool.layer_num - - return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2 - - def init_kv_buffer(self): - return torch.empty( - (2, self.layer_num, self.size, self.head_num, self.head_dim), - dtype=self.dtype, - device=self.device, - pin_memory=self.pin_memory, - ) - - @debug_timing - def transfer(self, indices, flat_data): - # backup prepared data from device to host - self.kv_buffer[:, :, indices] = flat_data.to( - device=self.device, non_blocking=False - ) - - def get_flat_data(self, indices): - return self.kv_buffer[:, :, indices] - - def get_flat_data_by_layer(self, indices, layer_id): - return self.kv_buffer[:, layer_id - self.start_layer, indices] - - def assign_flat_data(self, indices, flat_data): - self.kv_buffer[:, :, indices] = flat_data - - def write_page_all_layers(self, host_indices, device_indices, device_pool): - device_indices_cpu = device_indices[:: self.page_size].cpu() - for i in range(len(device_indices_cpu)): - h_index = host_indices[i * self.page_size] - d_index = device_indices_cpu[i] - for j in range(self.layer_num): - self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_( - device_pool.k_buffer[j][d_index : d_index + self.page_size], - non_blocking=True, - ) - self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_( - device_pool.v_buffer[j][d_index : d_index + self.page_size], - non_blocking=True, - ) - - def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id): - device_indices_cpu = device_indices[:: self.page_size].cpu() - for i in range(len(device_indices_cpu)): - h_index = host_indices[i * self.page_size] - d_index = device_indices_cpu[i] - device_pool.k_buffer[layer_id - self.start_layer][ - d_index : d_index + self.page_size - ].copy_( - self.kv_buffer[ - 0, layer_id - self.start_layer, h_index : h_index + self.page_size - ], - non_blocking=True, - ) - device_pool.v_buffer[layer_id - self.start_layer][ - d_index : d_index + self.page_size - ].copy_( - self.kv_buffer[ - 1, layer_id - self.start_layer, h_index : h_index + self.page_size - ], - non_blocking=True, - ) - - -class MLATokenToKVPoolHost(HostKVCache): - device_pool: MLATokenToKVPool - - def __init__( - self, - device_pool: MLATokenToKVPool, - host_to_device_ratio: float, - host_size: int, - page_size: int, - pin_memory: bool = True, - device: str = "cpu", - ): - super().__init__( - device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size - ) - - def get_size_per_token(self): - self.kv_lora_rank = self.device_pool.kv_lora_rank - self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim - self.layer_num = self.device_pool.layer_num - - return ( - (self.kv_lora_rank + self.qk_rope_head_dim) - * 1 - * self.dtype.itemsize - * self.layer_num - ) - - def init_kv_buffer(self): - return torch.empty( - ( - self.layer_num, - self.size, - 1, - self.kv_lora_rank + self.qk_rope_head_dim, - ), - dtype=self.dtype, - device=self.device, - pin_memory=self.pin_memory, - ) - - @debug_timing - def transfer(self, indices, flat_data): - # backup prepared data from device to host - self.kv_buffer[:, indices] = flat_data.to( - device=self.device, non_blocking=False - ) - - def get_flat_data(self, indices): - return self.kv_buffer[:, indices] - - def get_flat_data_by_layer(self, indices, layer_id): - return self.kv_buffer[layer_id - self.start_layer, indices] - - def assign_flat_data(self, indices, flat_data): - self.kv_buffer[:, indices] = flat_data - - def write_page_all_layers(self, host_indices, device_indices, device_pool): - device_indices_cpu = device_indices[:: self.page_size].cpu() - for i in range(len(device_indices_cpu)): - h_index = host_indices[i * self.page_size] - d_index = device_indices_cpu[i] - for j in range(self.layer_num): - self.kv_buffer[j, h_index : h_index + self.page_size].copy_( - device_pool.kv_buffer[j][d_index : d_index + self.page_size], - non_blocking=True, - ) - - def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id): - device_indices_cpu = device_indices[:: self.page_size].cpu() - for i in range(len(device_indices_cpu)): - h_index = host_indices[i * self.page_size] - d_index = device_indices_cpu[i] - device_pool.kv_buffer[layer_id - self.start_layer][ - d_index : d_index + self.page_size - ].copy_( - self.kv_buffer[ - layer_id - self.start_layer, h_index : h_index + self.page_size - ], - non_blocking=True, - ) diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py new file mode 100644 index 000000000..2a9d4c6e5 --- /dev/null +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -0,0 +1,380 @@ +import abc +import logging +import threading +from enum import IntEnum +from functools import wraps + +import psutil +import torch + +from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool +from sglang.srt.utils import debug_timing + +logger = logging.getLogger(__name__) + + +class MemoryStateInt(IntEnum): + IDLE = 0 + RESERVED = 1 + PROTECTED = 2 + SYNCED = 3 + BACKUP = 4 + + +def synchronized(debug_only=False): + def _decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if (not debug_only) or self.debug: + return func(self, *args, **kwargs) + with self.lock: + return func(self, *args, **kwargs) + else: + return True + + return wrapper + + return _decorator + + +class HostKVCache(abc.ABC): + + def __init__( + self, + device_pool: KVCache, + host_to_device_ratio: float, + host_size: int, + pin_memory: bool, + device: str, + page_size: int, + ): + self.device_pool = device_pool + self.dtype = device_pool.store_dtype + self.pin_memory = pin_memory + self.device = device + self.page_size = page_size + self.size_per_token = self.get_size_per_token() + if host_size > 0: + self.size = int(host_size * 1e9 // self.size_per_token) + else: + self.size = int(device_pool.size * host_to_device_ratio) + # Align the host memory pool size to the page size + self.size = self.size - (self.size % self.page_size) + self.start_layer = device_pool.start_layer + self.end_layer = device_pool.end_layer + + assert ( + self.size > device_pool.size + ), "The host memory should be larger than the device memory with the current protocol" + + # Verify there is enough available host memory. + host_mem = psutil.virtual_memory() + requested_bytes = self.size * self.size_per_token + # preserve at least 10GB for other usage + ten_gb = 10 * (1024**3) + if requested_bytes > host_mem.available - ten_gb: + raise ValueError( + f"Not enough host memory available. Requesting " + f"{requested_bytes / 1e9:.2f} GB but only have " + f"{host_mem.available / 1e9:.2f} GB free. Please reduce the " + f"size of the hierarchical cache." + ) + else: + logger.info( + f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache." + ) + + self.kv_buffer = self.init_kv_buffer() + + # A lock for synchronized operations on memory allocation and state transitions. + self.lock = threading.RLock() + self.debug = logger.isEnabledFor(logging.DEBUG) + self.clear() + + @abc.abstractmethod + def get_size_per_token(self): + raise NotImplementedError() + + @abc.abstractmethod + def init_kv_buffer(self): + raise NotImplementedError() + + @abc.abstractmethod + def transfer(self, indices, flat_data): + raise NotImplementedError() + + @abc.abstractmethod + def get_flat_data(self, indices): + raise NotImplementedError() + + @abc.abstractmethod + def get_flat_data_by_layer(self, indices, layer_id): + raise NotImplementedError() + + @abc.abstractmethod + def assign_flat_data(self, indices, flat_data): + raise NotImplementedError() + + @synchronized() + def clear(self): + # Initialize memory states and tracking structures. + self.mem_state = torch.zeros( + (self.size,), dtype=torch.uint8, device=self.device + ) + self.free_slots = torch.arange(self.size, dtype=torch.int64) + + def available_size(self): + return len(self.free_slots) + + @synchronized() + def alloc(self, need_size: int) -> torch.Tensor: + if need_size > self.available_size(): + return None + + select_index = self.free_slots[:need_size] + self.free_slots = self.free_slots[need_size:] + + if self.debug: + self.mem_state[select_index] = MemoryStateInt.RESERVED + + return select_index + + @synchronized() + def free(self, indices: torch.Tensor) -> int: + self.free_slots = torch.cat([self.free_slots, indices]) + if self.debug: + self.mem_state[indices] = MemoryStateInt.IDLE + return len(indices) + + @synchronized(debug_only=True) + def get_state(self, indices: torch.Tensor) -> MemoryStateInt: + assert len(indices) > 0, "The indices should not be empty" + states = self.mem_state[indices] + assert ( + states == states[0] + ).all(), "The memory slots should have the same state {}".format(states) + return MemoryStateInt(states[0].item()) + + @synchronized(debug_only=True) + def is_reserved(self, indices: torch.Tensor) -> bool: + return self.get_state(indices) == MemoryStateInt.RESERVED + + @synchronized(debug_only=True) + def is_protected(self, indices: torch.Tensor) -> bool: + return self.get_state(indices) == MemoryStateInt.PROTECTED + + @synchronized(debug_only=True) + def is_synced(self, indices: torch.Tensor) -> bool: + return self.get_state(indices) == MemoryStateInt.SYNCED + + @synchronized(debug_only=True) + def is_backup(self, indices: torch.Tensor) -> bool: + return self.get_state(indices) == MemoryStateInt.BACKUP + + @synchronized(debug_only=True) + def update_backup(self, indices: torch.Tensor): + if not self.is_synced(indices): + raise ValueError( + f"The host memory slots should be in SYNCED state before turning into BACKUP. " + f"Current state: {self.get_state(indices)}" + ) + self.mem_state[indices] = MemoryStateInt.BACKUP + + @synchronized(debug_only=True) + def update_synced(self, indices: torch.Tensor): + self.mem_state[indices] = MemoryStateInt.SYNCED + + @synchronized(debug_only=True) + def protect_write(self, indices: torch.Tensor): + if not self.is_reserved(indices): + raise ValueError( + f"The host memory slots should be RESERVED before write operations. " + f"Current state: {self.get_state(indices)}" + ) + self.mem_state[indices] = MemoryStateInt.PROTECTED + + @synchronized(debug_only=True) + def protect_load(self, indices: torch.Tensor): + if not self.is_backup(indices): + raise ValueError( + f"The host memory slots should be in BACKUP state before load operations. " + f"Current state: {self.get_state(indices)}" + ) + self.mem_state[indices] = MemoryStateInt.PROTECTED + + @synchronized(debug_only=True) + def complete_io(self, indices: torch.Tensor): + if not self.is_protected(indices): + raise ValueError( + f"The host memory slots should be PROTECTED during I/O operations. " + f"Current state: {self.get_state(indices)}" + ) + self.mem_state[indices] = MemoryStateInt.SYNCED + + +class MHATokenToKVPoolHost(HostKVCache): + device_pool: MHATokenToKVPool + + def __init__( + self, + device_pool: MHATokenToKVPool, + host_to_device_ratio: float, + host_size: int, + page_size: int, + pin_memory: bool = True, + device: str = "cpu", + ): + super().__init__( + device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size + ) + + def get_size_per_token(self): + self.head_num = self.device_pool.head_num + self.head_dim = self.device_pool.head_dim + self.layer_num = self.device_pool.layer_num + + return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2 + + def init_kv_buffer(self): + return torch.empty( + (2, self.layer_num, self.size, self.head_num, self.head_dim), + dtype=self.dtype, + device=self.device, + pin_memory=self.pin_memory, + ) + + @debug_timing + def transfer(self, indices, flat_data): + # backup prepared data from device to host + self.kv_buffer[:, :, indices] = flat_data.to( + device=self.device, non_blocking=False + ) + + def get_flat_data(self, indices): + return self.kv_buffer[:, :, indices] + + def get_flat_data_by_layer(self, indices, layer_id): + return self.kv_buffer[:, layer_id - self.start_layer, indices] + + def assign_flat_data(self, indices, flat_data): + self.kv_buffer[:, :, indices] = flat_data + + def write_page_all_layers(self, host_indices, device_indices, device_pool): + device_indices_cpu = device_indices[:: self.page_size].cpu() + for i in range(len(device_indices_cpu)): + h_index = host_indices[i * self.page_size] + d_index = device_indices_cpu[i] + for j in range(self.layer_num): + self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_( + device_pool.k_buffer[j][d_index : d_index + self.page_size], + non_blocking=True, + ) + self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_( + device_pool.v_buffer[j][d_index : d_index + self.page_size], + non_blocking=True, + ) + + def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id): + device_indices_cpu = device_indices[:: self.page_size].cpu() + for i in range(len(device_indices_cpu)): + h_index = host_indices[i * self.page_size] + d_index = device_indices_cpu[i] + device_pool.k_buffer[layer_id - self.start_layer][ + d_index : d_index + self.page_size + ].copy_( + self.kv_buffer[ + 0, layer_id - self.start_layer, h_index : h_index + self.page_size + ], + non_blocking=True, + ) + device_pool.v_buffer[layer_id - self.start_layer][ + d_index : d_index + self.page_size + ].copy_( + self.kv_buffer[ + 1, layer_id - self.start_layer, h_index : h_index + self.page_size + ], + non_blocking=True, + ) + + +class MLATokenToKVPoolHost(HostKVCache): + device_pool: MLATokenToKVPool + + def __init__( + self, + device_pool: MLATokenToKVPool, + host_to_device_ratio: float, + host_size: int, + page_size: int, + pin_memory: bool = True, + device: str = "cpu", + ): + super().__init__( + device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size + ) + + def get_size_per_token(self): + self.kv_lora_rank = self.device_pool.kv_lora_rank + self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim + self.layer_num = self.device_pool.layer_num + + return ( + (self.kv_lora_rank + self.qk_rope_head_dim) + * 1 + * self.dtype.itemsize + * self.layer_num + ) + + def init_kv_buffer(self): + return torch.empty( + ( + self.layer_num, + self.size, + 1, + self.kv_lora_rank + self.qk_rope_head_dim, + ), + dtype=self.dtype, + device=self.device, + pin_memory=self.pin_memory, + ) + + @debug_timing + def transfer(self, indices, flat_data): + # backup prepared data from device to host + self.kv_buffer[:, indices] = flat_data.to( + device=self.device, non_blocking=False + ) + + def get_flat_data(self, indices): + return self.kv_buffer[:, indices] + + def get_flat_data_by_layer(self, indices, layer_id): + return self.kv_buffer[layer_id - self.start_layer, indices] + + def assign_flat_data(self, indices, flat_data): + self.kv_buffer[:, indices] = flat_data + + def write_page_all_layers(self, host_indices, device_indices, device_pool): + device_indices_cpu = device_indices[:: self.page_size].cpu() + for i in range(len(device_indices_cpu)): + h_index = host_indices[i * self.page_size] + d_index = device_indices_cpu[i] + for j in range(self.layer_num): + self.kv_buffer[j, h_index : h_index + self.page_size].copy_( + device_pool.kv_buffer[j][d_index : d_index + self.page_size], + non_blocking=True, + ) + + def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id): + device_indices_cpu = device_indices[:: self.page_size].cpu() + for i in range(len(device_indices_cpu)): + h_index = host_indices[i * self.page_size] + d_index = device_indices_cpu[i] + device_pool.kv_buffer[layer_id - self.start_layer][ + d_index : d_index + self.page_size + ].copy_( + self.kv_buffer[ + layer_id - self.start_layer, h_index : h_index + self.page_size + ], + non_blocking=True, + ) diff --git a/test/srt/test_hicache_page.py b/test/srt/test_hicache_page.py index c110d054e..b1e1459c2 100644 --- a/test/srt/test_hicache_page.py +++ b/test/srt/test_hicache_page.py @@ -26,7 +26,7 @@ class TestHiCachePage(CustomTestCase): "--page-size", 32, "--hicache-write-policy", - "write-back", + "write_back", ], )