[PD] Support decode retract and update decode.py (#7196)
This commit is contained in:
@@ -234,6 +234,12 @@ class TokenToKVPoolAllocator:
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
|
||||
def get_cpu_copy(self, indices):
    """Forward a CPU-copy request to the underlying KV cache.

    Args:
        indices: Passed through unchanged to ``self._kvcache.get_cpu_copy``.

    Returns:
        Whatever ``self._kvcache.get_cpu_copy(indices)`` returns.
    """
    kv_pool = self._kvcache
    return kv_pool.get_cpu_copy(indices)
|
||||
|
||||
def load_cpu_copy(self, kv_cache_cpu, indices):
    """Forward a CPU-copy restore request to the underlying KV cache.

    Args:
        kv_cache_cpu: Passed through unchanged as the first argument of
            ``self._kvcache.load_cpu_copy``.
        indices: Passed through unchanged as the second argument.

    Returns:
        Whatever ``self._kvcache.load_cpu_copy(kv_cache_cpu, indices)``
        returns.
    """
    kv_pool = self._kvcache
    return kv_pool.load_cpu_copy(kv_cache_cpu, indices)
|
||||
|
||||
|
||||
class MHATokenToKVPool(KVCache):
|
||||
|
||||
@@ -265,6 +271,8 @@ class MHATokenToKVPool(KVCache):
|
||||
self.head_dim = head_dim
|
||||
self._create_buffers()
|
||||
|
||||
# used for chunked cpu-offloading
|
||||
self.chunk_size = 8192
|
||||
self.layer_transfer_counter = None
|
||||
self.device_module = torch.get_device_module(self.device)
|
||||
self.alt_stream = self.device_module.Stream() if _is_cuda else None
|
||||
@@ -329,6 +337,39 @@ class MHATokenToKVPool(KVCache):
|
||||
]
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
def get_cpu_copy(self, indices):
|
||||
torch.cuda.synchronize()
|
||||
kv_cache_cpu = []
|
||||
for layer_id in range(self.layer_num):
|
||||
kv_cache_cpu.append([])
|
||||
for i in range(0, len(indices), self.chunk_size):
|
||||
chunk_indices = indices[i : i + self.chunk_size]
|
||||
k_cpu = self.k_buffer[layer_id][chunk_indices].to(
|
||||
"cpu", non_blocking=True
|
||||
)
|
||||
v_cpu = self.v_buffer[layer_id][chunk_indices].to(
|
||||
"cpu", non_blocking=True
|
||||
)
|
||||
kv_cache_cpu[-1].append([k_cpu, v_cpu])
|
||||
torch.cuda.synchronize()
|
||||
return kv_cache_cpu
|
||||
|
||||
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||
torch.cuda.synchronize()
|
||||
for layer_id in range(self.layer_num):
|
||||
for i in range(0, len(indices), self.chunk_size):
|
||||
chunk_indices = indices[i : i + self.chunk_size]
|
||||
k_cpu, v_cpu = (
|
||||
kv_cache_cpu[layer_id][i // self.chunk_size][0],
|
||||
kv_cache_cpu[layer_id][i // self.chunk_size][1],
|
||||
)
|
||||
assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
|
||||
k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
|
||||
v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
|
||||
self.k_buffer[layer_id][chunk_indices] = k_chunk
|
||||
self.v_buffer[layer_id][chunk_indices] = v_chunk
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# TODO: support different memory layouts
|
||||
def get_flat_data(self, indices):
|
||||
# prepare a large chunk of contiguous data for efficient transfer
|
||||
|
||||
Reference in New Issue
Block a user