[Minor] Refactors KV memory pool (#9842)
This commit is contained in:
@@ -130,6 +130,29 @@ class KVCache(abc.ABC):
|
|||||||
# used for chunked cpu-offloading
|
# used for chunked cpu-offloading
|
||||||
self.cpu_offloading_chunk_size = 8192
|
self.cpu_offloading_chunk_size = 8192
|
||||||
|
|
||||||
|
# default state for optional layer-wise transfer control
|
||||||
|
self.layer_transfer_counter = None
|
||||||
|
|
||||||
|
def _finalize_allocation_log(self, num_tokens: int):
    """Log the size of the allocated KV cache and record it in ``self.mem_usage``.

    The size is obtained from ``self.get_kv_size_bytes()``, which may return
    either a ``(K, V)`` tuple of sizes or one combined KV size. Each value is
    divided by ``GB`` before being logged; ``self.mem_usage`` is set to the
    combined result (K + V for the tuple form).

    Args:
        num_tokens: Token count reported in the log message.
    """
    size_in_bytes = self.get_kv_size_bytes()

    if not isinstance(size_in_bytes, tuple):
        # Single combined KV size.
        kv_size_GB = size_in_bytes / GB
        logger.info(
            f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
        )
        self.mem_usage = kv_size_GB
        return

    # Separate K and V sizes: report each one, store their sum.
    k_bytes, v_bytes = size_in_bytes
    k_size_GB = k_bytes / GB
    v_size_GB = v_bytes / GB
    logger.info(
        f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
    )
    self.mem_usage = k_size_GB + v_size_GB
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -205,15 +228,9 @@ class MHATokenToKVPool(KVCache):
|
|||||||
|
|
||||||
self._create_buffers()
|
self._create_buffers()
|
||||||
|
|
||||||
self.layer_transfer_counter = None
|
|
||||||
self.device_module = torch.get_device_module(self.device)
|
self.device_module = torch.get_device_module(self.device)
|
||||||
self.alt_stream = self.device_module.Stream() if _is_cuda else None
|
self.alt_stream = self.device_module.Stream() if _is_cuda else None
|
||||||
|
self._finalize_allocation_log(size)
|
||||||
k_size, v_size = self.get_kv_size_bytes()
|
|
||||||
logger.info(
|
|
||||||
f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
|
|
||||||
)
|
|
||||||
self.mem_usage = (k_size + v_size) / GB
|
|
||||||
|
|
||||||
def _create_buffers(self):
|
def _create_buffers(self):
|
||||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||||
@@ -427,43 +444,30 @@ class SWAKVPool(KVCache):
|
|||||||
self,
|
self,
|
||||||
size: int,
|
size: int,
|
||||||
size_swa: int,
|
size_swa: int,
|
||||||
dtype: torch.dtype,
|
|
||||||
head_num: int,
|
|
||||||
head_dim: int,
|
|
||||||
swa_attention_layer_ids: List[int],
|
swa_attention_layer_ids: List[int],
|
||||||
full_attention_layer_ids: List[int],
|
full_attention_layer_ids: List[int],
|
||||||
enable_kvcache_transpose: bool,
|
enable_kvcache_transpose: bool,
|
||||||
device: str,
|
token_to_kv_pool_class: KVCache = MHATokenToKVPool,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.size_swa = size_swa
|
self.size_swa = size_swa
|
||||||
self.dtype = dtype
|
|
||||||
self.device = device
|
|
||||||
self.swa_layer_nums = len(swa_attention_layer_ids)
|
self.swa_layer_nums = len(swa_attention_layer_ids)
|
||||||
self.full_layer_nums = len(full_attention_layer_ids)
|
self.full_layer_nums = len(full_attention_layer_ids)
|
||||||
self.page_size = 1
|
kwargs["page_size"] = 1
|
||||||
|
kwargs["enable_memory_saver"] = False
|
||||||
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
|
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
|
||||||
assert not enable_kvcache_transpose
|
assert not enable_kvcache_transpose
|
||||||
TokenToKVPoolClass = MHATokenToKVPool
|
|
||||||
self.swa_kv_pool = TokenToKVPoolClass(
|
self.swa_kv_pool = token_to_kv_pool_class(
|
||||||
size=size_swa,
|
size=size_swa,
|
||||||
page_size=self.page_size,
|
|
||||||
dtype=dtype,
|
|
||||||
head_num=head_num,
|
|
||||||
head_dim=head_dim,
|
|
||||||
layer_num=self.swa_layer_nums,
|
layer_num=self.swa_layer_nums,
|
||||||
device=device,
|
**kwargs,
|
||||||
enable_memory_saver=False,
|
|
||||||
)
|
)
|
||||||
self.full_kv_pool = TokenToKVPoolClass(
|
self.full_kv_pool = token_to_kv_pool_class(
|
||||||
size=size,
|
size=size,
|
||||||
page_size=self.page_size,
|
|
||||||
dtype=dtype,
|
|
||||||
head_num=head_num,
|
|
||||||
head_dim=head_dim,
|
|
||||||
layer_num=self.full_layer_nums,
|
layer_num=self.full_layer_nums,
|
||||||
device=device,
|
**kwargs,
|
||||||
enable_memory_saver=False,
|
|
||||||
)
|
)
|
||||||
self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
|
self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
|
||||||
for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
|
for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
|
||||||
@@ -768,13 +772,7 @@ class MLATokenToKVPool(KVCache):
|
|||||||
dtype=torch.uint64,
|
dtype=torch.uint64,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
self.layer_transfer_counter = None
|
self._finalize_allocation_log(size)
|
||||||
|
|
||||||
kv_size = self.get_kv_size_bytes()
|
|
||||||
logger.info(
|
|
||||||
f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
|
|
||||||
)
|
|
||||||
self.mem_usage = kv_size / GB
|
|
||||||
|
|
||||||
def get_kv_size_bytes(self):
|
def get_kv_size_bytes(self):
|
||||||
assert hasattr(self, "kv_buffer")
|
assert hasattr(self, "kv_buffer")
|
||||||
@@ -936,13 +934,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.layer_transfer_counter = None
|
self._finalize_allocation_log(size)
|
||||||
|
|
||||||
kv_size = self.get_kv_size_bytes()
|
|
||||||
logger.info(
|
|
||||||
f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
|
|
||||||
)
|
|
||||||
self.mem_usage = kv_size / GB
|
|
||||||
|
|
||||||
def get_kv_size_bytes(self):
|
def get_kv_size_bytes(self):
|
||||||
assert hasattr(self, "k_buffer")
|
assert hasattr(self, "k_buffer")
|
||||||
|
|||||||
@@ -31,16 +31,18 @@ class TestSWA(unittest.TestCase):
|
|||||||
i for i in range(num_layers) if i not in full_attention_layer_ids_set
|
i for i in range(num_layers) if i not in full_attention_layer_ids_set
|
||||||
]
|
]
|
||||||
pool = SWAKVPool(
|
pool = SWAKVPool(
|
||||||
size,
|
size=size,
|
||||||
size_swa,
|
size_swa=size_swa,
|
||||||
dtype,
|
dtype=dtype,
|
||||||
num_head,
|
num_head=num_head,
|
||||||
head_dim,
|
head_dim=head_dim,
|
||||||
swa_attention_layer_ids,
|
swa_attention_layer_ids=swa_attention_layer_ids,
|
||||||
full_attention_layer_ids,
|
full_attention_layer_ids=full_attention_layer_ids,
|
||||||
device,
|
device=device,
|
||||||
|
)
|
||||||
|
alloc = SWATokenToKVPoolAllocator(
|
||||||
|
size=size, size_swa=size_swa, dtype=dtype, device=device, kvcache=pool
|
||||||
)
|
)
|
||||||
alloc = SWATokenToKVPoolAllocator(size, size_swa, dtype, device, pool)
|
|
||||||
assert alloc.available_size() == size + size_swa
|
assert alloc.available_size() == size + size_swa
|
||||||
index = alloc.alloc(1)
|
index = alloc.alloc(1)
|
||||||
assert alloc.available_size() == size_swa + size_swa - 2
|
assert alloc.available_size() == size_swa + size_swa - 2
|
||||||
@@ -75,18 +77,22 @@ class TestSWA(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
# setup kv pool
|
# setup kv pool
|
||||||
kv_pool = SWAKVPool(
|
kv_pool = SWAKVPool(
|
||||||
kv_size,
|
size=kv_size,
|
||||||
kv_size_swa,
|
size_swa=kv_size_swa,
|
||||||
dtype,
|
dtype=dtype,
|
||||||
num_head,
|
num_head=num_head,
|
||||||
head_dim,
|
head_dim=head_dim,
|
||||||
swa_attention_layer_ids,
|
swa_attention_layer_ids=swa_attention_layer_ids,
|
||||||
full_attention_layer_ids,
|
full_attention_layer_ids=full_attention_layer_ids,
|
||||||
device,
|
device=device,
|
||||||
)
|
)
|
||||||
# setup token to kv pool allocator
|
# setup token to kv pool allocator
|
||||||
allocator = SWATokenToKVPoolAllocator(
|
allocator = SWATokenToKVPoolAllocator(
|
||||||
kv_size, kv_size_swa, dtype, device, kv_pool
|
size=kv_size,
|
||||||
|
size_swa=kv_size_swa,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
kvcache=kv_pool,
|
||||||
)
|
)
|
||||||
# setup radix cache
|
# setup radix cache
|
||||||
tree = SWARadixCache(
|
tree = SWARadixCache(
|
||||||
|
|||||||
Reference in New Issue
Block a user