Fix CPU offloading for MLA memory pool (#7409)
This commit is contained in:
@@ -123,6 +123,9 @@ class KVCache(abc.ABC):
|
|||||||
enable=enable_memory_saver
|
enable=enable_memory_saver
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# used for chunked cpu-offloading
|
||||||
|
self.cpu_offloading_chunk_size = 8192
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -157,6 +160,12 @@ class KVCache(abc.ABC):
|
|||||||
def register_layer_transfer_counter(self, layer_transfer_counter):
|
def register_layer_transfer_counter(self, layer_transfer_counter):
|
||||||
self.layer_transfer_counter = layer_transfer_counter
|
self.layer_transfer_counter = layer_transfer_counter
|
||||||
|
|
||||||
|
def get_cpu_copy(self, indices):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class TokenToKVPoolAllocator:
|
class TokenToKVPoolAllocator:
|
||||||
"""An allocator managing the indices to kv cache data."""
|
"""An allocator managing the indices to kv cache data."""
|
||||||
@@ -280,8 +289,6 @@ class MHATokenToKVPool(KVCache):
|
|||||||
|
|
||||||
self._create_buffers()
|
self._create_buffers()
|
||||||
|
|
||||||
# used for chunked cpu-offloading
|
|
||||||
self.chunk_size = 8192
|
|
||||||
self.layer_transfer_counter = None
|
self.layer_transfer_counter = None
|
||||||
self.device_module = torch.get_device_module(self.device)
|
self.device_module = torch.get_device_module(self.device)
|
||||||
self.alt_stream = self.device_module.Stream() if _is_cuda else None
|
self.alt_stream = self.device_module.Stream() if _is_cuda else None
|
||||||
@@ -378,10 +385,11 @@ class MHATokenToKVPool(KVCache):
|
|||||||
def get_cpu_copy(self, indices):
|
def get_cpu_copy(self, indices):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
kv_cache_cpu = []
|
kv_cache_cpu = []
|
||||||
|
chunk_size = self.cpu_offloading_chunk_size
|
||||||
for layer_id in range(self.layer_num):
|
for layer_id in range(self.layer_num):
|
||||||
kv_cache_cpu.append([])
|
kv_cache_cpu.append([])
|
||||||
for i in range(0, len(indices), self.chunk_size):
|
for i in range(0, len(indices), chunk_size):
|
||||||
chunk_indices = indices[i : i + self.chunk_size]
|
chunk_indices = indices[i : i + chunk_size]
|
||||||
k_cpu = self.k_buffer[layer_id][chunk_indices].to(
|
k_cpu = self.k_buffer[layer_id][chunk_indices].to(
|
||||||
"cpu", non_blocking=True
|
"cpu", non_blocking=True
|
||||||
)
|
)
|
||||||
@@ -394,12 +402,13 @@ class MHATokenToKVPool(KVCache):
|
|||||||
|
|
||||||
def load_cpu_copy(self, kv_cache_cpu, indices):
|
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
chunk_size = self.cpu_offloading_chunk_size
|
||||||
for layer_id in range(self.layer_num):
|
for layer_id in range(self.layer_num):
|
||||||
for i in range(0, len(indices), self.chunk_size):
|
for i in range(0, len(indices), chunk_size):
|
||||||
chunk_indices = indices[i : i + self.chunk_size]
|
chunk_indices = indices[i : i + chunk_size]
|
||||||
k_cpu, v_cpu = (
|
k_cpu, v_cpu = (
|
||||||
kv_cache_cpu[layer_id][i // self.chunk_size][0],
|
kv_cache_cpu[layer_id][i // chunk_size][0],
|
||||||
kv_cache_cpu[layer_id][i // self.chunk_size][1],
|
kv_cache_cpu[layer_id][i // chunk_size][1],
|
||||||
)
|
)
|
||||||
assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
|
assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
|
||||||
k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
|
k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
|
||||||
@@ -724,6 +733,33 @@ class MLATokenToKVPool(KVCache):
|
|||||||
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
flat_data = flat_data.to(device=self.device, non_blocking=False)
|
||||||
self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
|
self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
|
||||||
|
|
||||||
|
def get_cpu_copy(self, indices):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
kv_cache_cpu = []
|
||||||
|
chunk_size = self.cpu_offloading_chunk_size
|
||||||
|
for layer_id in range(self.layer_num):
|
||||||
|
kv_cache_cpu.append([])
|
||||||
|
for i in range(0, len(indices), chunk_size):
|
||||||
|
chunk_indices = indices[i : i + chunk_size]
|
||||||
|
kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
|
||||||
|
"cpu", non_blocking=True
|
||||||
|
)
|
||||||
|
kv_cache_cpu[-1].append(kv_cpu)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
return kv_cache_cpu
|
||||||
|
|
||||||
|
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
chunk_size = self.cpu_offloading_chunk_size
|
||||||
|
for layer_id in range(self.layer_num):
|
||||||
|
for i in range(0, len(indices), chunk_size):
|
||||||
|
chunk_indices = indices[i : i + chunk_size]
|
||||||
|
kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
|
||||||
|
assert kv_cpu.shape[0] == len(chunk_indices)
|
||||||
|
kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
|
||||||
|
self.kv_buffer[layer_id][chunk_indices] = kv_chunk
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
class DoubleSparseTokenToKVPool(KVCache):
|
class DoubleSparseTokenToKVPool(KVCache):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Reference in New Issue
Block a user