SWA Prefix Cache (#7367)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
Hanming Lu
2025-07-13 12:31:07 -07:00
committed by GitHub
parent 0c55cbcfc5
commit 9379da77de
16 changed files with 1742 additions and 158 deletions

View File

@@ -57,11 +57,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
def debug_print(self) -> str:
return ""
def log_usage(self, evictable_size: int = 0):
num_used = self.size - (self.available_size() + evictable_size)
msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, "
return msg, num_used
def available_size(self):
return len(self.free_pages) * self.page_size
@@ -190,7 +185,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping
def available_size(self):
return min(self.full_available_size(), self.swa_available_size())
raise NotImplementedError()
def full_available_size(self):
return self.full_attn_allocator.available_size()
@@ -214,16 +209,6 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
)
return msg
def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0):
used_full = self.size_full - (self.full_available_size() + full_evictable_size)
used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size)
msg = (
f"#token: full={used_full}, swa={used_swa}, "
f"token usage: full={used_full / self.size_full:.2f}, "
f"swa={used_swa / self.size_swa:.2f}, "
)
return msg, used_full
def get_kvcache(self):
return self._kvcache

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Tuple
import torch
@@ -56,15 +56,27 @@ class BasePrefixCache(ABC):
pass
@abstractmethod
def dec_lock_ref(self, node: Any):
def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
pass
def evictable_size(self):
return 0
def full_evictable_size(self):
return 0
def swa_evictable_size(self):
return 0
def protected_size(self):
return 0
def full_protected_size(self):
return 0
def swa_protected_size(self):
return 0
def total_size(self):
raise NotImplementedError()

View File

@@ -61,7 +61,7 @@ class ChunkCache(BasePrefixCache):
def inc_lock_ref(self, node: Any):
return 0
def dec_lock_ref(self, node: Any):
def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
return 0
def pretty_print(self):
@@ -80,7 +80,7 @@ class SWAChunkCache(ChunkCache):
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
def evict(
def evict_swa(
self,
req: Req,
prelen: int,
@@ -95,3 +95,6 @@ class SWAChunkCache(ChunkCache):
]
self.token_to_kv_pool_allocator.free_swa(free_slots)
req.evicted_seqlen_local = new_evicted_seqlen_local
def evict(self, num_tokens: int):
pass

File diff suppressed because it is too large Load Diff