Move host memory pools into a separate file (#7200)
@@ -22,7 +22,8 @@ from typing import List, Optional

 import torch

-from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool_host import HostKVCache

 logger = logging.getLogger(__name__)

@@ -9,12 +9,14 @@ import torch
 from sglang.srt.managers.cache_controller import HiCacheController
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
-    MHATokenToKVPoolHost,
     MLATokenToKVPool,
-    MLATokenToKVPoolHost,
     ReqToTokenPool,
     TokenToKVPoolAllocator,
 )
+from sglang.srt.mem_cache.memory_pool_host import (
+    MHATokenToKVPoolHost,
+    MLATokenToKVPoolHost,
+)
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

 logger = logging.getLogger(__name__)
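With this split, callers import the device-side pools from memory_pool and the host-side pools from the new memory_pool_host module. A minimal sketch of the resulting import pattern (module and class names are taken from this diff; which names a given caller actually needs is illustrative only):

# Device-side pools and allocators stay in memory_pool.
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, TokenToKVPoolAllocator

# Host-side (CPU) mirrors of the KV pools now live in memory_pool_host.
from sglang.srt.mem_cache.memory_pool_host import (
    HostKVCache,
    MHATokenToKVPoolHost,
    MLATokenToKVPoolHost,
)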
@@ -26,24 +26,15 @@ KVCache actually holds the physical kv cache.

 import abc
 import logging
-import threading
-from enum import IntEnum
-from functools import wraps
 from typing import List, Optional, Tuple, Union

 import numpy as np
-import psutil
 import torch
 import triton
 import triton.language as tl

 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import (
-    debug_timing,
-    get_compiler_backend,
-    is_cuda,
-    next_power_of_2,
-)
+from sglang.srt.utils import debug_timing, get_compiler_backend, is_cuda

 logger = logging.getLogger(__name__)
@@ -772,370 +763,3 @@ class DoubleSparseTokenToKVPool(KVCache):
     def transfer_per_layer(self, indices, flat_data, layer_id):
         pass

-
-class MemoryStateInt(IntEnum):
-    IDLE = 0
-    RESERVED = 1
-    PROTECTED = 2
-    SYNCED = 3
-    BACKUP = 4
-
-
-def synchronized(debug_only=False):
-    def _decorator(func):
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            if (not debug_only) or self.debug:
-                return func(self, *args, **kwargs)
-                with self.lock:
-                    return func(self, *args, **kwargs)
-            else:
-                return True
-
-        return wrapper
-
-    return _decorator
-
-
-class HostKVCache(abc.ABC):
-
-    def __init__(
-        self,
-        device_pool: KVCache,
-        host_to_device_ratio: float,
-        host_size: int,
-        pin_memory: bool,
-        device: str,
-        page_size: int,
-    ):
-        self.device_pool = device_pool
-        self.dtype = device_pool.store_dtype
-        self.pin_memory = pin_memory
-        self.device = device
-        self.page_size = page_size
-        self.size_per_token = self.get_size_per_token()
-        if host_size > 0:
-            self.size = int(host_size * 1e9 // self.size_per_token)
-        else:
-            self.size = int(device_pool.size * host_to_device_ratio)
-        # Align the host memory pool size to the page size
-        self.size = self.size - (self.size % self.page_size)
-        self.start_layer = device_pool.start_layer
-        self.end_layer = device_pool.end_layer
-
-        assert (
-            self.size > device_pool.size
-        ), "The host memory should be larger than the device memory with the current protocol"
-
-        # Verify there is enough available host memory.
-        host_mem = psutil.virtual_memory()
-        requested_bytes = self.size * self.size_per_token
-        # preserve at least 10GB for other usage
-        ten_gb = 10 * (1024**3)
-        if requested_bytes > host_mem.available - ten_gb:
-            raise ValueError(
-                f"Not enough host memory available. Requesting "
-                f"{requested_bytes / 1e9:.2f} GB but only have "
-                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
-                f"size of the hierarchical cache."
-            )
-        else:
-            logger.info(
-                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
-            )
-
-        self.kv_buffer = self.init_kv_buffer()
-
-        # A lock for synchronized operations on memory allocation and state transitions.
-        self.lock = threading.RLock()
-        self.debug = logger.isEnabledFor(logging.DEBUG)
-        self.clear()
-
-    @abc.abstractmethod
-    def get_size_per_token(self):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def init_kv_buffer(self):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def transfer(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data(self, indices):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data_by_layer(self, indices, layer_id):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def assign_flat_data(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @synchronized()
-    def clear(self):
-        # Initialize memory states and tracking structures.
-        self.mem_state = torch.zeros(
-            (self.size,), dtype=torch.uint8, device=self.device
-        )
-        self.free_slots = torch.arange(self.size, dtype=torch.int64)
-
-    def available_size(self):
-        return len(self.free_slots)
-
-    @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
-        if need_size > self.available_size():
-            return None
-
-        select_index = self.free_slots[:need_size]
-        self.free_slots = self.free_slots[need_size:]
-
-        if self.debug:
-            self.mem_state[select_index] = MemoryStateInt.RESERVED
-
-        return select_index
-
-    @synchronized()
-    def free(self, indices: torch.Tensor) -> int:
-        self.free_slots = torch.cat([self.free_slots, indices])
-        if self.debug:
-            self.mem_state[indices] = MemoryStateInt.IDLE
-        return len(indices)
-
-    @synchronized(debug_only=True)
-    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
-
-    @synchronized(debug_only=True)
-    def is_reserved(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.RESERVED
-
-    @synchronized(debug_only=True)
-    def is_protected(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def is_synced(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def is_backup(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_backup(self, indices: torch.Tensor):
-        if not self.is_synced(indices):
-            raise ValueError(
-                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_synced(self, indices: torch.Tensor):
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def protect_write(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be RESERVED before write operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def protect_load(self, indices: torch.Tensor):
-        if not self.is_backup(indices):
-            raise ValueError(
-                f"The host memory slots should be in BACKUP state before load operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def complete_io(self, indices: torch.Tensor):
-        if not self.is_protected(indices):
-            raise ValueError(
-                f"The host memory slots should be PROTECTED during I/O operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-
-class MHATokenToKVPoolHost(HostKVCache):
-    device_pool: MHATokenToKVPool
-
-    def __init__(
-        self,
-        device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float,
-        host_size: int,
-        page_size: int,
-        pin_memory: bool = True,
-        device: str = "cpu",
-    ):
-        super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
-        )
-
-    def get_size_per_token(self):
-        self.head_num = self.device_pool.head_num
-        self.head_dim = self.device_pool.head_dim
-        self.layer_num = self.device_pool.layer_num
-
-        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
-
-    def init_kv_buffer(self):
-        return torch.empty(
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
-            dtype=self.dtype,
-            device=self.device,
-            pin_memory=self.pin_memory,
-        )
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, :, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, :, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, :, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
-            device_pool.v_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
-
-
-class MLATokenToKVPoolHost(HostKVCache):
-    device_pool: MLATokenToKVPool
-
-    def __init__(
-        self,
-        device_pool: MLATokenToKVPool,
-        host_to_device_ratio: float,
-        host_size: int,
-        page_size: int,
-        pin_memory: bool = True,
-        device: str = "cpu",
-    ):
-        super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
-        )
-
-    def get_size_per_token(self):
-        self.kv_lora_rank = self.device_pool.kv_lora_rank
-        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
-        self.layer_num = self.device_pool.layer_num
-
-        return (
-            (self.kv_lora_rank + self.qk_rope_head_dim)
-            * 1
-            * self.dtype.itemsize
-            * self.layer_num
-        )
-
-    def init_kv_buffer(self):
-        return torch.empty(
-            (
-                self.layer_num,
-                self.size,
-                1,
-                self.kv_lora_rank + self.qk_rope_head_dim,
-            ),
-            dtype=self.dtype,
-            device=self.device,
-            pin_memory=self.pin_memory,
-        )
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
-                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.kv_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
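A quick worked example of the sizing logic removed above (HostKVCache.__init__ together with MHATokenToKVPoolHost.get_size_per_token); the model dimensions below are hypothetical and not taken from this diff:

# Hypothetical MHA model: 32 layers, 8 KV heads, head_dim 128, fp16 storage (2-byte items).
head_dim, head_num, layer_num, itemsize = 128, 8, 32, 2
size_per_token = head_dim * head_num * layer_num * itemsize * 2  # K and V planes -> 131072 bytes

host_size_gb = 64  # explicit host budget in GB, i.e. the host_size > 0 branch
page_size = 64
size = int(host_size_gb * 1e9 // size_per_token)  # 488281 token slots
size -= size % page_size  # align down to the page size -> 488256
print(size_per_token, size)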
python/sglang/srt/mem_cache/memory_pool_host.py (new file, 380 lines)
@@ -0,0 +1,380 @@
+import abc
+import logging
+import threading
+from enum import IntEnum
+from functools import wraps
+
+import psutil
+import torch
+
+from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.utils import debug_timing
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryStateInt(IntEnum):
+    IDLE = 0
+    RESERVED = 1
+    PROTECTED = 2
+    SYNCED = 3
+    BACKUP = 4
+
+
+def synchronized(debug_only=False):
+    def _decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if (not debug_only) or self.debug:
+                return func(self, *args, **kwargs)
+                with self.lock:
+                    return func(self, *args, **kwargs)
+            else:
+                return True
+
+        return wrapper
+
+    return _decorator
+
+
+class HostKVCache(abc.ABC):
+
+    def __init__(
+        self,
+        device_pool: KVCache,
+        host_to_device_ratio: float,
+        host_size: int,
+        pin_memory: bool,
+        device: str,
+        page_size: int,
+    ):
+        self.device_pool = device_pool
+        self.dtype = device_pool.store_dtype
+        self.pin_memory = pin_memory
+        self.device = device
+        self.page_size = page_size
+        self.size_per_token = self.get_size_per_token()
+        if host_size > 0:
+            self.size = int(host_size * 1e9 // self.size_per_token)
+        else:
+            self.size = int(device_pool.size * host_to_device_ratio)
+        # Align the host memory pool size to the page size
+        self.size = self.size - (self.size % self.page_size)
+        self.start_layer = device_pool.start_layer
+        self.end_layer = device_pool.end_layer
+
+        assert (
+            self.size > device_pool.size
+        ), "The host memory should be larger than the device memory with the current protocol"
+
+        # Verify there is enough available host memory.
+        host_mem = psutil.virtual_memory()
+        requested_bytes = self.size * self.size_per_token
+        # preserve at least 10GB for other usage
+        ten_gb = 10 * (1024**3)
+        if requested_bytes > host_mem.available - ten_gb:
+            raise ValueError(
+                f"Not enough host memory available. Requesting "
+                f"{requested_bytes / 1e9:.2f} GB but only have "
+                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
+                f"size of the hierarchical cache."
+            )
+        else:
+            logger.info(
+                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
+            )
+
+        self.kv_buffer = self.init_kv_buffer()
+
+        # A lock for synchronized operations on memory allocation and state transitions.
+        self.lock = threading.RLock()
+        self.debug = logger.isEnabledFor(logging.DEBUG)
+        self.clear()
+
+    @abc.abstractmethod
+    def get_size_per_token(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def init_kv_buffer(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_flat_data(self, indices):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_flat_data_by_layer(self, indices, layer_id):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def assign_flat_data(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @synchronized()
+    def clear(self):
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
+        self.free_slots = torch.arange(self.size, dtype=torch.int64)
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    @synchronized()
+    def alloc(self, need_size: int) -> torch.Tensor:
+        if need_size > self.available_size():
+            return None
+
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+
+        if self.debug:
+            self.mem_state[select_index] = MemoryStateInt.RESERVED
+
+        return select_index
+
+    @synchronized()
+    def free(self, indices: torch.Tensor) -> int:
+        self.free_slots = torch.cat([self.free_slots, indices])
+        if self.debug:
+            self.mem_state[indices] = MemoryStateInt.IDLE
+        return len(indices)
+
+    @synchronized(debug_only=True)
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized(debug_only=True)
+    def is_reserved(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.RESERVED
+
+    @synchronized(debug_only=True)
+    def is_protected(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.PROTECTED
+
+    @synchronized(debug_only=True)
+    def is_synced(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.SYNCED
+
+    @synchronized(debug_only=True)
+    def is_backup(self, indices: torch.Tensor) -> bool:
+        return self.get_state(indices) == MemoryStateInt.BACKUP
+
+    @synchronized(debug_only=True)
+    def update_backup(self, indices: torch.Tensor):
+        if not self.is_synced(indices):
+            raise ValueError(
+                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.BACKUP
+
+    @synchronized(debug_only=True)
+    def update_synced(self, indices: torch.Tensor):
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+    @synchronized(debug_only=True)
+    def protect_write(self, indices: torch.Tensor):
+        if not self.is_reserved(indices):
+            raise ValueError(
+                f"The host memory slots should be RESERVED before write operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized(debug_only=True)
+    def protect_load(self, indices: torch.Tensor):
+        if not self.is_backup(indices):
+            raise ValueError(
+                f"The host memory slots should be in BACKUP state before load operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+    @synchronized(debug_only=True)
+    def complete_io(self, indices: torch.Tensor):
+        if not self.is_protected(indices):
+            raise ValueError(
+                f"The host memory slots should be PROTECTED during I/O operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
+        self.mem_state[indices] = MemoryStateInt.SYNCED
+
+
+class MHATokenToKVPoolHost(HostKVCache):
+    device_pool: MHATokenToKVPool
+
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float,
+        host_size: int,
+        page_size: int,
+        pin_memory: bool = True,
+        device: str = "cpu",
+    ):
+        super().__init__(
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+        )
+
+    def get_size_per_token(self):
+        self.head_num = self.device_pool.head_num
+        self.head_dim = self.device_pool.head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id - self.start_layer, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, :, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.k_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
+                non_blocking=True,
+            )
+            device_pool.v_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
+                non_blocking=True,
+            )
+
+
+class MLATokenToKVPoolHost(HostKVCache):
+    device_pool: MLATokenToKVPool
+
+    def __init__(
+        self,
+        device_pool: MLATokenToKVPool,
+        host_to_device_ratio: float,
+        host_size: int,
+        page_size: int,
+        pin_memory: bool = True,
+        device: str = "cpu",
+    ):
+        super().__init__(
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+        )
+
+    def get_size_per_token(self):
+        self.kv_lora_rank = self.device_pool.kv_lora_rank
+        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return (
+            (self.kv_lora_rank + self.qk_rope_head_dim)
+            * 1
+            * self.dtype.itemsize
+            * self.layer_num
+        )
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (
+                self.layer_num,
+                self.size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[layer_id - self.start_layer, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
+                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.kv_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
+                non_blocking=True,
+            )
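The HostKVCache added above tracks each host slot through a small state machine (IDLE, RESERVED, PROTECTED, SYNCED, BACKUP), and the transition checks are only exercised when debug logging is enabled. The standalone sketch below mirrors the intended call order for one write-back and one reload; it is an illustration of the protocol, not sglang code, and the simplified ToyHostPool class is invented for this example:

from enum import IntEnum


class State(IntEnum):
    IDLE = 0
    RESERVED = 1
    PROTECTED = 2
    SYNCED = 3
    BACKUP = 4


class ToyHostPool:
    """Plain-Python stand-in for HostKVCache's slot state tracking (no tensors, no locks)."""

    def __init__(self, size: int):
        self.state = [State.IDLE] * size
        self.free_slots = list(range(size))

    def alloc(self, n: int):  # IDLE -> RESERVED
        slots, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        for s in slots:
            self.state[s] = State.RESERVED
        return slots

    def protect_write(self, slots):  # RESERVED -> PROTECTED (device-to-host copy in flight)
        for s in slots:
            assert self.state[s] == State.RESERVED
            self.state[s] = State.PROTECTED

    def protect_load(self, slots):  # BACKUP -> PROTECTED (host-to-device copy in flight)
        for s in slots:
            assert self.state[s] == State.BACKUP
            self.state[s] = State.PROTECTED

    def complete_io(self, slots):  # PROTECTED -> SYNCED (copy finished)
        for s in slots:
            assert self.state[s] == State.PROTECTED
            self.state[s] = State.SYNCED

    def update_backup(self, slots):  # SYNCED -> BACKUP (host copy serves as the backup)
        for s in slots:
            assert self.state[s] == State.SYNCED
            self.state[s] = State.BACKUP


pool = ToyHostPool(size=8)
slots = pool.alloc(4)
pool.protect_write(slots)  # write-back: device -> host
pool.complete_io(slots)
pool.update_backup(slots)
pool.protect_load(slots)   # later reload: host -> device
pool.complete_io(slots)
print([pool.state[i].name for i in slots])  # all SYNCED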