From 5ea5d221703d663c6dbd8a57a313c4872a59b647 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Sun, 22 Jun 2025 02:39:05 +0800
Subject: [PATCH] Fix CPU offloading for MLA memory pool (#7409)

The chunk size used for chunked CPU offloading moves from
MHATokenToKVPool (self.chunk_size) to KVCache
(self.cpu_offloading_chunk_size).  KVCache gains get_cpu_copy() and
load_cpu_copy() stubs that raise NotImplementedError, and
MLATokenToKVPool implements both on top of its kv_buffer, copying one
chunk of indices per layer at a time.

---
 python/sglang/srt/mem_cache/memory_pool.py | 52 ++++++++++++++++++----
 1 file changed, 44 insertions(+), 8 deletions(-)

diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index b5be2bb1b..5306ce175 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -123,6 +123,9 @@ class KVCache(abc.ABC):
             enable=enable_memory_saver
         )
 
+        # used for chunked cpu-offloading
+        self.cpu_offloading_chunk_size = 8192
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
@@ -157,6 +160,12 @@ class KVCache(abc.ABC):
     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter
 
+    def get_cpu_copy(self, indices):
+        raise NotImplementedError()
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        raise NotImplementedError()
+
 
 class TokenToKVPoolAllocator:
     """An allocator managing the indices to kv cache data."""
@@ -280,8 +289,6 @@ class MHATokenToKVPool(KVCache):
 
         self._create_buffers()
 
-        # used for chunked cpu-offloading
-        self.chunk_size = 8192
         self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
@@ -378,10 +385,11 @@ class MHATokenToKVPool(KVCache):
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
         for layer_id in range(self.layer_num):
             kv_cache_cpu.append([])
-            for i in range(0, len(indices), self.chunk_size):
-                chunk_indices = indices[i : i + self.chunk_size]
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                 k_cpu = self.k_buffer[layer_id][chunk_indices].to(
                     "cpu", non_blocking=True
                 )
@@ -394,12 +402,13 @@ class MHATokenToKVPool(KVCache):
 
     def load_cpu_copy(self, kv_cache_cpu, indices):
         torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
         for layer_id in range(self.layer_num):
-            for i in range(0, len(indices), self.chunk_size):
-                chunk_indices = indices[i : i + self.chunk_size]
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                 k_cpu, v_cpu = (
-                    kv_cache_cpu[layer_id][i // self.chunk_size][0],
-                    kv_cache_cpu[layer_id][i // self.chunk_size][1],
+                    kv_cache_cpu[layer_id][i // chunk_size][0],
+                    kv_cache_cpu[layer_id][i // chunk_size][1],
                 )
                 assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
                 k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
@@ -724,6 +733,33 @@ class MLATokenToKVPool(KVCache):
         flat_data = flat_data.to(device=self.device, non_blocking=False)
         self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
 
+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append(kv_cpu)
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
+                assert kv_cpu.shape[0] == len(chunk_indices)
+                kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
+                self.kv_buffer[layer_id][chunk_indices] = kv_chunk
+        torch.cuda.synchronize()
+
 
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(