From 0e0ec702007b79e7f45cd8efc48292b087369d3f Mon Sep 17 00:00:00 2001 From: Lu Changqi <58518876+zeroorhero@users.noreply.github.com> Date: Fri, 14 Mar 2025 11:42:14 +0800 Subject: [PATCH] Hierarchical Caching supports MLA (#4009) Signed-off-by: Changqi Lu Co-authored-by: Zhiqiang Xie --- .../sglang/srt/managers/cache_controller.py | 7 +- python/sglang/srt/mem_cache/hiradix_cache.py | 16 +- python/sglang/srt/mem_cache/memory_pool.py | 190 +++++++++++++++--- test/srt/test_hierarchical_mla.py | 56 ++++++ 4 files changed, 231 insertions(+), 38 deletions(-) create mode 100644 test/srt/test_hierarchical_mla.py diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index f8857fab4..92d509732 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -22,10 +22,7 @@ from typing import List, Optional import torch -from sglang.srt.mem_cache.memory_pool import ( - MHATokenToKVPoolHost, - TokenToKVPoolAllocator, -) +from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator logger = logging.getLogger(__name__) @@ -151,7 +148,7 @@ class HiCacheController: def __init__( self, token_to_kv_pool_allocator: TokenToKVPoolAllocator, - mem_pool_host: MHATokenToKVPoolHost, + mem_pool_host: HostKVCache, load_cache_event: threading.Event = None, write_policy: str = "write_through_selective", ): diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 754a88c07..e16e8e3ad 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -8,7 +8,10 @@ import torch from sglang.srt.managers.cache_controller import HiCacheController from sglang.srt.mem_cache.memory_pool import ( + MHATokenToKVPool, MHATokenToKVPoolHost, + MLATokenToKVPool, + MLATokenToKVPoolHost, ReqToTokenPool, TokenToKVPoolAllocator, ) @@ -31,9 +34,14 @@ class HiRadixCache(RadixCache): raise 
ValueError( "Page size larger than 1 is not yet supported in HiRadixCache." ) - self.token_to_kv_pool_host = MHATokenToKVPoolHost( - token_to_kv_pool_allocator.get_kvcache() - ) + self.kv_cache = token_to_kv_pool_allocator.get_kvcache() + if isinstance(self.kv_cache, MHATokenToKVPool): + self.token_to_kv_pool_host = MHATokenToKVPoolHost(self.kv_cache) + elif isinstance(self.kv_cache, MLATokenToKVPool): + self.token_to_kv_pool_host = MLATokenToKVPoolHost(self.kv_cache) + else: + raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.") + self.tp_group = tp_cache_group self.page_size = page_size @@ -317,13 +325,11 @@ class HiRadixCache(RadixCache): prefix_len = _key_match(child.key, key) if prefix_len < len(child.key): new_node = self._split_node(child.key, child, prefix_len) - self.inc_hit_count(new_node) if not new_node.evicted: value.append(new_node.value) node = new_node break else: - self.inc_hit_count(child) if not child.evicted: value.append(child.value) node = child diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 4f48c27b4..16bb8eb60 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -115,6 +115,21 @@ class KVCache(abc.ABC): ) -> None: raise NotImplementedError() + @abc.abstractmethod + def get_flat_data(self, indices): + raise NotImplementedError() + + @abc.abstractmethod + def transfer(self, indices, flat_data): + raise NotImplementedError() + + @abc.abstractmethod + def transfer_per_layer(self, indices, flat_data, layer_id): + raise NotImplementedError() + + def register_layer_transfer_counter(self, layer_transfer_counter): + self.layer_transfer_counter = layer_transfer_counter + class TokenToKVPoolAllocator: """An allocator managing the indices to kv cache data.""" @@ -275,9 +290,6 @@ class MHATokenToKVPool(KVCache): self.k_buffer[i][indices] = k_data[i] self.v_buffer[i][indices] = v_data[i] - def 
register_layer_transfer_counter(self, layer_transfer_counter): - self.layer_transfer_counter = layer_transfer_counter - def transfer_per_layer(self, indices, flat_data, layer_id): # transfer prepared data from host to device flat_data = flat_data.to(device=self.device, non_blocking=False) @@ -388,6 +400,8 @@ class MLATokenToKVPool(KVCache): else: self.store_dtype = dtype self.kv_lora_rank = kv_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.layer_num = layer_num memory_saver_adapter = TorchMemorySaverAdapter.create( enable=enable_memory_saver @@ -404,12 +418,20 @@ class MLATokenToKVPool(KVCache): for _ in range(layer_num) ] + self.layer_transfer_counter = None + def get_key_buffer(self, layer_id: int): + if self.layer_transfer_counter is not None: + self.layer_transfer_counter.wait_until(layer_id) + if self.store_dtype != self.dtype: return self.kv_buffer[layer_id].view(self.dtype) return self.kv_buffer[layer_id] def get_value_buffer(self, layer_id: int): + if self.layer_transfer_counter is not None: + self.layer_transfer_counter.wait_until(layer_id) + if self.store_dtype != self.dtype: return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype) return self.kv_buffer[layer_id][..., : self.kv_lora_rank] @@ -432,6 +454,22 @@ class MLATokenToKVPool(KVCache): else: self.kv_buffer[layer_id][loc] = cache_k + def get_flat_data(self, indices): + # prepare a large chunk of contiguous data for efficient transfer + return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)]) + + @debug_timing + def transfer(self, indices, flat_data): + # transfer prepared data from host to device + flat_data = flat_data.to(device=self.device, non_blocking=False) + for i in range(self.layer_num): + self.kv_buffer[i][indices] = flat_data[i] + + def transfer_per_layer(self, indices, flat_data, layer_id): + # transfer prepared data from host to device + flat_data = flat_data.to(device=self.device, non_blocking=False) + 
self.kv_buffer[layer_id][indices] = flat_data + class DoubleSparseTokenToKVPool(KVCache): def __init__( @@ -508,6 +546,15 @@ class DoubleSparseTokenToKVPool(KVCache): self.v_buffer[layer_id][loc] = cache_v self.label_buffer[layer_id][loc] = cache_label + def get_flat_data(self, indices): + pass + + def transfer(self, indices, flat_data): + pass + + def transfer_per_layer(self, indices, flat_data, layer_id): + pass + class MemoryStateInt(IntEnum): IDLE = 0 @@ -526,7 +573,7 @@ def synchronized(func): return wrapper -class MHATokenToKVPoolHost: +class HostKVCache(abc.ABC): def __init__( self, @@ -547,12 +594,7 @@ class MHATokenToKVPoolHost: self.size = int(device_pool.size * host_to_device_ratio) self.dtype = device_pool.store_dtype - self.head_num = device_pool.head_num - self.head_dim = device_pool.head_dim - self.layer_num = device_pool.layer_num - self.size_per_token = ( - self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2 - ) + self.size_per_token = self.get_size_per_token() # Verify there is enough available host memory. host_mem = psutil.virtual_memory() @@ -571,12 +613,7 @@ class MHATokenToKVPoolHost: f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache." ) - self.kv_buffer = torch.zeros( - (2, self.layer_num, self.size, self.head_num, self.head_dim), - dtype=self.dtype, - device=self.device, - pin_memory=self.pin_memory, - ) + self.kv_buffer = self.init_kv_buffer() # Initialize memory states and tracking structures. self.mem_state = torch.zeros( @@ -588,21 +625,29 @@ class MHATokenToKVPoolHost: # A lock for synchronized operations on memory allocation and state transitions. 
self.lock = threading.RLock() - def get_flat_data(self, indices): - return self.kv_buffer[:, :, indices] + @abc.abstractmethod + def get_size_per_token(self): + raise NotImplementedError() - def get_flat_data_by_layer(self, indices, layer_id): - return self.kv_buffer[:, layer_id, indices] + @abc.abstractmethod + def init_kv_buffer(self): + raise NotImplementedError() - def assign_flat_data(self, indices, flat_data): - self.kv_buffer[:, :, indices] = flat_data - - @debug_timing + @abc.abstractmethod def transfer(self, indices, flat_data): - # backup prepared data from device to host - self.kv_buffer[:, :, indices] = flat_data.to( - device=self.device, non_blocking=False - ) + raise NotImplementedError() + + @abc.abstractmethod + def get_flat_data(self, indices): + raise NotImplementedError() + + @abc.abstractmethod + def get_flat_data_by_layer(self, indices, layer_id): + raise NotImplementedError() + + @abc.abstractmethod + def assign_flat_data(self, indices, flat_data): + raise NotImplementedError() @synchronized def clear(self): @@ -694,3 +739,92 @@ class MHATokenToKVPoolHost: self.free_slots = torch.concat([self.free_slots, indices]) self.can_use_mem_size += len(indices) return len(indices) + + +class MHATokenToKVPoolHost(HostKVCache): + def __init__( + self, + device_pool: MHATokenToKVPool, + host_to_device_ratio: float = 3.0, + pin_memory: bool = False, # no need to use pin memory with the double buffering + device: str = "cpu", + ): + super().__init__(device_pool, host_to_device_ratio, pin_memory, device) + + def get_size_per_token(self): + self.head_num = self.device_pool.head_num + self.head_dim = self.device_pool.head_dim + self.layer_num = self.device_pool.layer_num + + return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2 + + def init_kv_buffer(self): + return torch.empty( + (2, self.layer_num, self.size, self.head_num, self.head_dim), + dtype=self.dtype, + device=self.device, + pin_memory=self.pin_memory, + ) + + @debug_timing + 
class MLATokenToKVPoolHost(HostKVCache):
    """Host (CPU) backup pool mirroring an MLA device KV cache.

    MLA (Multi-head Latent Attention) stores a single compressed KV vector of
    width (kv_lora_rank + qk_rope_head_dim) per token per layer, so the host
    buffer layout is (layer_num, size, 1, kv_lora_rank + qk_rope_head_dim) —
    unlike the MHA host pool there is no separate K/V leading dimension.
    """

    def __init__(
        self,
        device_pool: MLATokenToKVPool,
        host_to_device_ratio: float = 4.0,
        pin_memory: bool = False,  # no need to use pin memory with the double buffering
        device: str = "cpu",
    ):
        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)

    def get_size_per_token(self):
        """Return the host-memory cost (bytes) of caching one token.

        Also caches the device-pool geometry on `self`, which
        `init_kv_buffer()` relies on (the base class calls this method
        before allocating the buffer).
        """
        self.kv_lora_rank = self.device_pool.kv_lora_rank
        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
        self.layer_num = self.device_pool.layer_num

        # One compressed KV vector per token *per layer*. The layer_num
        # factor must be included here, otherwise the host-memory
        # availability check in HostKVCache.__init__ undercounts the real
        # allocation of init_kv_buffer() by a factor of layer_num.
        return (
            (self.kv_lora_rank + self.qk_rope_head_dim)
            * 1
            * self.dtype.itemsize
            * self.layer_num
        )

    def init_kv_buffer(self):
        """Allocate the host-side buffer: (layer, token, 1, lora+rope)."""
        return torch.empty(
            (
                self.layer_num,
                self.size,
                1,
                self.kv_lora_rank + self.qk_rope_head_dim,
            ),
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

    @debug_timing
    def transfer(self, indices, flat_data):
        # backup prepared data from device to host
        self.kv_buffer[:, indices] = flat_data.to(
            device=self.device, non_blocking=False
        )

    def get_flat_data(self, indices):
        # All layers for the given token slots, as one contiguous chunk.
        return self.kv_buffer[:, indices]

    def get_flat_data_by_layer(self, indices, layer_id):
        # Single layer's data for the given token slots (layer-by-layer load).
        return self.kv_buffer[layer_id, indices]

    def assign_flat_data(self, indices, flat_data):
        # Write already-host-resident data into the given token slots.
        self.kv_buffer[:, indices] = flat_data
class TestHierarchicalMLA(unittest.TestCase):
    """End-to-end accuracy checks for hierarchical KV caching with an MLA model.

    One server is launched with `--enable-hierarchical-cache` and shared by
    every test; each test runs an eval suite against it and asserts a
    minimum score.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--trust-remote-code", "--enable-hierarchical-cache"],
        )

    @classmethod
    def tearDownClass(cls):
        # Tear down the whole server process tree, not just the parent.
        kill_process_tree(cls.process.pid)

    def _eval_score(self, eval_name, num_examples, num_threads):
        """Run one eval suite against the shared server and return its score."""
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name=eval_name,
            num_examples=num_examples,
            num_threads=num_threads,
        )
        return run_eval(eval_args)["score"]

    def test_mmlu(self):
        score = self._eval_score("mmlu", 64, 32)
        self.assertGreater(score, 0.5)

    def test_mgsm_en(self):
        score = self._eval_score("mgsm_en", None, 1024)
        self.assertGreater(score, 0.8)