diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index f2271fd7a..cd7d653fc 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -94,6 +94,33 @@ class ReqToTokenPool:
 
 
 class KVCache(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        layer_num: int,
+        device: str,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        self.size = size
+        self.page_size = page_size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
+        self.layer_num = layer_num
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
 
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
@@ -217,25 +244,20 @@ class MHATokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
         )
 
         self.head_num = head_num
         self.head_dim = head_dim
-        self.layer_num = layer_num
         self._create_buffers()
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
 
         self.layer_transfer_counter = None
         self.capture_mode = False
@@ -493,26 +515,21 @@ class MLATokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        self.kv_lora_rank = kv_lora_rank
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.layer_num = layer_num
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
-
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
         )
-        with memory_saver_adapter.region():
+        self.kv_lora_rank = kv_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+
+        with self.memory_saver_adapter.region():
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
             self.kv_buffer = [
                 torch.zeros(
@@ -636,20 +653,18 @@ class DoubleSparseTokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
         )
-        with memory_saver_adapter.region():
+        with self.memory_saver_adapter.region():
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
                 torch.zeros(
@@ -672,9 +687,6 @@ class DoubleSparseTokenToKVPool(KVCache):
                 for _ in range(layer_num)
             ]
 
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
-
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id - self.start_layer]
 
@@ -742,7 +754,7 @@ class HostKVCache(abc.ABC):
 
     def __init__(
         self,
-        device_pool: MHATokenToKVPool,
+        device_pool: KVCache,
         host_to_device_ratio: float,
         host_size: int,
         pin_memory: bool,
@@ -914,6 +926,8 @@ class HostKVCache(abc.ABC):
 
 
 class MHATokenToKVPoolHost(HostKVCache):
+    device_pool: MHATokenToKVPool
+
     def __init__(
         self,
         device_pool: MHATokenToKVPool,
@@ -997,6 +1011,8 @@ class MHATokenToKVPoolHost(HostKVCache):
 
 
 class MLATokenToKVPoolHost(HostKVCache):
+    device_pool: MLATokenToKVPool
+
     def __init__(
        self,
        device_pool: MLATokenToKVPool,
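
Note on the pattern above: the duplicated pool bookkeeping moves into an abstract `KVCache.__init__`, and each concrete pool delegates to it via `super().__init__(...)` before setting up its own buffers. A minimal, self-contained sketch of the same pattern (hypothetical `BasePool`/`MHAPool` names with a reduced field set; not the sglang classes themselves):

```python
import abc
from typing import Optional


class BasePool(abc.ABC):
    # Marking __init__ abstract keeps BasePool itself non-instantiable,
    # while subclasses still reuse its body through super().__init__.
    @abc.abstractmethod
    def __init__(
        self,
        size: int,
        layer_num: int,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        self.size = size
        self.layer_num = layer_num
        # `or` substitutes the default when the argument is None.
        self.start_layer = start_layer or 0
        self.end_layer = end_layer or layer_num - 1


class MHAPool(BasePool):
    def __init__(self, size: int, layer_num: int, head_num: int):
        # Shared fields first, subclass-specific state after.
        super().__init__(size, layer_num)
        self.head_num = head_num


pool = MHAPool(size=1024, layer_num=32, head_num=8)
print(pool.start_layer, pool.end_layer)  # 0 31
# BasePool(1024, 32) would raise TypeError: abstract __init__.
```

One subtlety carried over verbatim from the original code: `end_layer or layer_num - 1` tests falsiness rather than `is None`, so an explicit `end_layer=0` is also replaced by `layer_num - 1`.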
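The `store_dtype` branch that now lives in the base class works around `Tensor.index_put` not being implemented for `torch.float8_e5m2`: fp8 caches are allocated as `torch.uint8` and bit-reinterpreted at the read/write boundary. A hedged sketch of that round-trip using `Tensor.view(dtype)`, which reinterprets between dtypes of the same element width (assumes a PyTorch build with fp8 dtypes, roughly >= 2.1):

```python
import torch

dtype = torch.float8_e5m2   # compute dtype
store_dtype = torch.uint8   # storage dtype, same 1-byte element size

# The cache buffer is allocated in the storage dtype.
buf = torch.zeros(8, 4, dtype=store_dtype)

# Write path: cast values to fp8, then reinterpret the raw bytes as
# uint8 so advanced indexing / index_put stays on a supported dtype.
vals = torch.randn(2, 4).to(dtype)
slots = torch.tensor([0, 3])
buf[slots] = vals.view(store_dtype)

# Read path: reinterpret the stored bytes back as fp8.
out = buf[slots].view(dtype)
print(out.dtype)  # torch.float8_e5m2
```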
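Lastly, `HostKVCache.__init__` is widened to accept any `KVCache`, while the two host pools pin the attribute back down with class-level annotations (`device_pool: MHATokenToKVPool` / `device_pool: MLATokenToKVPool`). The re-declaration matters only to static type checkers: the runtime value is unchanged, but pool-specific members resolve without casts inside each subclass. A small illustration of the narrowing pattern, with hypothetical names:

```python
class Device:
    """Stand-in for the generic KVCache base."""


class MHADevice(Device):
    def mha_only(self) -> str:
        return "mha-specific data"


class Host:
    device_pool: Device

    def __init__(self, device_pool: Device):
        self.device_pool = device_pool


class MHAHost(Host):
    # Narrow the inherited annotation; mypy/pyright then accept
    # MHA-specific attribute access on self.device_pool.
    device_pool: MHADevice

    def __init__(self, device_pool: MHADevice):
        super().__init__(device_pool)

    def use(self) -> str:
        return self.device_pool.mha_only()


print(MHAHost(MHADevice()).use())  # mha-specific data
```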