[PD] Support decode retract and update decode.py (#7196)

Author: Byron Hsu
Date: 2025-06-14 19:48:05 -07:00
Committed by: GitHub
Parent: 349bb2c92a
Commit: db0cc57e75

6 changed files with 378 additions and 43 deletions


@@ -234,6 +234,12 @@ class TokenToKVPoolAllocator:
        self.is_not_in_free_group = True
        self.free_group = []

    def get_cpu_copy(self, indices):
        return self._kvcache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
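For orientation, a hedged caller-side sketch of the new pass-through: the allocator holds no KV data itself and simply forwards both calls to its backing KVCache. The names `allocator` and `token_indices` below are illustrative, not from this commit.

# Hypothetical usage; `allocator` is a TokenToKVPoolAllocator and
# `token_indices` a tensor of KV slot indices (both names are illustrative).
cpu_backup = allocator.get_cpu_copy(token_indices)    # forwards to _kvcache: chunked D2H copy
allocator.load_cpu_copy(cpu_backup, token_indices)    # forwards to _kvcache: chunked H2D restore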
class MHATokenToKVPool(KVCache):

@@ -265,6 +271,8 @@ class MHATokenToKVPool(KVCache):
        self.head_dim = head_dim
        self._create_buffers()

        # used for chunked cpu-offloading
        self.chunk_size = 8192
        self.layer_transfer_counter = None
        self.device_module = torch.get_device_module(self.device)
        self.alt_stream = self.device_module.Stream() if _is_cuda else None
@@ -329,6 +337,39 @@ class MHATokenToKVPool(KVCache):
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def get_cpu_copy(self, indices):
        torch.cuda.synchronize()
        kv_cache_cpu = []
        for layer_id in range(self.layer_num):
            kv_cache_cpu.append([])
            for i in range(0, len(indices), self.chunk_size):
                chunk_indices = indices[i : i + self.chunk_size]
                k_cpu = self.k_buffer[layer_id][chunk_indices].to(
                    "cpu", non_blocking=True
                )
                v_cpu = self.v_buffer[layer_id][chunk_indices].to(
                    "cpu", non_blocking=True
                )
                kv_cache_cpu[-1].append([k_cpu, v_cpu])
        torch.cuda.synchronize()
        return kv_cache_cpu

    def load_cpu_copy(self, kv_cache_cpu, indices):
        torch.cuda.synchronize()
        for layer_id in range(self.layer_num):
            for i in range(0, len(indices), self.chunk_size):
                chunk_indices = indices[i : i + self.chunk_size]
                k_cpu, v_cpu = (
                    kv_cache_cpu[layer_id][i // self.chunk_size][0],
                    kv_cache_cpu[layer_id][i // self.chunk_size][1],
                )
                assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
                k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
                v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
                self.k_buffer[layer_id][chunk_indices] = k_chunk
                self.v_buffer[layer_id][chunk_indices] = v_chunk
        torch.cuda.synchronize()

    # Todo: different memory layout
    def get_flat_data(self, indices):
        # prepare a large chunk of contiguous data for efficient transfer
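Read together, get_cpu_copy and load_cpu_copy form a chunked GPU-to-CPU-to-GPU round trip over a set of token slots, synchronizing once before and once after each batch of non-blocking copies. Below is a minimal, self-contained sketch of the same chunked pattern outside the class; every name in it (CHUNK, layer_num, k_buf, v_buf, indices) is made up for illustration, and the async/synchronize details are dropped.

import torch

# Illustrative sketch of the chunked offload/restore pattern above.
# All names here (CHUNK, layer_num, k_buf, v_buf) are invented for the demo.
CHUNK = 4
layer_num, total_slots, head_num, head_dim = 2, 16, 1, 8
device = "cuda" if torch.cuda.is_available() else "cpu"

# Per-layer K/V buffers indexed by token slot, as in the pool above.
k_buf = [torch.randn(total_slots, head_num, head_dim, device=device) for _ in range(layer_num)]
v_buf = [torch.randn(total_slots, head_num, head_dim, device=device) for _ in range(layer_num)]
indices = torch.arange(0, 10, device=device)  # slots of a retracted request

# Offload: slice each layer's buffers chunk by chunk and stage them on the CPU.
cpu_copy = []
for layer_id in range(layer_num):
    cpu_copy.append([])
    for i in range(0, len(indices), CHUNK):
        chunk = indices[i : i + CHUNK]
        cpu_copy[-1].append([k_buf[layer_id][chunk].cpu(), v_buf[layer_id][chunk].cpu()])

# Restore: write each staged chunk back into the device buffers.
for layer_id in range(layer_num):
    for i in range(0, len(indices), CHUNK):
        chunk = indices[i : i + CHUNK]
        k_cpu, v_cpu = cpu_copy[layer_id][i // CHUNK]
        k_buf[layer_id][chunk] = k_cpu.to(device)
        v_buf[layer_id][chunk] = v_cpu.to(device)

Chunking bounds the size of each staged host tensor (the real code uses chunk_size = 8192 slots), so a long retracted request is moved in pieces rather than as one allocation.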