Fix CPU offloading for MLA memory pool (#7409)
@@ -123,6 +123,9 @@ class KVCache(abc.ABC):
             enable=enable_memory_saver
         )
 
+        # used for chunked cpu-offloading
+        self.cpu_offloading_chunk_size = 8192
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
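Note: `cpu_offloading_chunk_size` bounds how many KV slots each host/device copy touches. A minimal, hypothetical sketch of the chunked index walk that the offloading paths below build on (the index tensor and sizes are illustrative, not from the patch):

import torch

chunk_size = 8192  # mirrors self.cpu_offloading_chunk_size
indices = torch.arange(20000)  # stand-in for KV cache slot indices

# Walk the indices in fixed-size windows so each transfer stays bounded.
chunks = [indices[i : i + chunk_size] for i in range(0, len(indices), chunk_size)]
assert sum(len(c) for c in chunks) == len(indices)
print([len(c) for c in chunks])  # -> [8192, 8192, 3616]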
@@ -157,6 +160,12 @@ class KVCache(abc.ABC):
     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter
 
+    def get_cpu_copy(self, indices):
+        raise NotImplementedError()
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        raise NotImplementedError()
+
 
 class TokenToKVPoolAllocator:
     """An allocator managing the indices to kv cache data."""
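The base class only declares the offloading hooks; each concrete pool overrides them. A rough sketch of the expected contract, using a toy single-buffer pool (the class and buffer here are illustrative, not part of the patch):

import torch

class ToyKVPool:
    # Illustrative pool following the get_cpu_copy / load_cpu_copy contract.
    def __init__(self, num_tokens: int, dim: int):
        self.buffer = torch.zeros(num_tokens, dim)
        self.cpu_offloading_chunk_size = 8192

    def get_cpu_copy(self, indices):
        # Return a host-side copy of the selected slots.
        return self.buffer[indices].to("cpu")

    def load_cpu_copy(self, kv_cache_cpu, indices):
        # Write the host copy back into the same slots.
        self.buffer[indices] = kv_cache_cpu.to(self.buffer.device)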
@@ -280,8 +289,6 @@ class MHATokenToKVPool(KVCache):
 
         self._create_buffers()
 
-        # used for chunked cpu-offloading
-        self.chunk_size = 8192
         self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
@@ -378,10 +385,11 @@ class MHATokenToKVPool(KVCache):
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
         for layer_id in range(self.layer_num):
             kv_cache_cpu.append([])
-            for i in range(0, len(indices), self.chunk_size):
-                chunk_indices = indices[i : i + self.chunk_size]
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                 k_cpu = self.k_buffer[layer_id][chunk_indices].to(
                     "cpu", non_blocking=True
                 )
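The copies are issued with non_blocking=True and bracketed by torch.cuda.synchronize(), so pending writes finish before the snapshot starts and every async copy has landed before the CPU tensors are returned. A stand-alone sketch of that pattern with a stand-in buffer (not the pool's actual buffers):

import torch

def snapshot_to_cpu(buffer: torch.Tensor, indices: torch.Tensor, chunk_size: int = 8192):
    # Illustrative chunked device-to-host snapshot.
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for kernels that may still be writing the buffer
    chunks = []
    for i in range(0, len(indices), chunk_size):
        chunk_indices = indices[i : i + chunk_size]
        chunks.append(buffer[chunk_indices].to("cpu", non_blocking=True))
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # ensure the async copies have landed in host memory
    return chunks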
@@ -394,12 +402,13 @@ class MHATokenToKVPool(KVCache):
 
     def load_cpu_copy(self, kv_cache_cpu, indices):
         torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
         for layer_id in range(self.layer_num):
-            for i in range(0, len(indices), self.chunk_size):
-                chunk_indices = indices[i : i + self.chunk_size]
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                 k_cpu, v_cpu = (
-                    kv_cache_cpu[layer_id][i // self.chunk_size][0],
-                    kv_cache_cpu[layer_id][i // self.chunk_size][1],
+                    kv_cache_cpu[layer_id][i // chunk_size][0],
+                    kv_cache_cpu[layer_id][i // chunk_size][1],
                 )
                 assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
                 k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
@@ -724,6 +733,33 @@ class MLATokenToKVPool(KVCache):
         flat_data = flat_data.to(device=self.device, non_blocking=False)
         self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
 
+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append(kv_cpu)
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
+                assert kv_cpu.shape[0] == len(chunk_indices)
+                kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
+                self.kv_buffer[layer_id][chunk_indices] = kv_chunk
+        torch.cuda.synchronize()
+
 
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
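A hedged sketch of how a caller might round-trip MLA KV entries through host memory with the two new methods; `pool` stands for an initialized MLATokenToKVPool and the call site is hypothetical:

import torch

def offload_and_restore(pool, indices: torch.Tensor):
    # Copy the selected slots to CPU, then load them back.
    # In between, the device slots referenced by `indices` could be reused.
    kv_cache_cpu = pool.get_cpu_copy(indices)  # per-layer list of CPU chunks
    # ... device slots may be overwritten / reallocated here ...
    pool.load_cpu_copy(kv_cache_cpu, indices)  # restore the original contents
    return kv_cache_cpu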