upstream hicache fixes (#5570)
@@ -571,6 +571,14 @@ class Req:
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
+        elif enable_hierarchical_cache:
+            # in case last_node is evicted during scheduling, we need to update the prefix_indices
+            while self.last_node.evicted:
+                self.prefix_indices = self.prefix_indices[
+                    : -len(self.last_node.host_value)
+                ]
+                self.last_node = self.last_node.parent
+
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

     def adjust_max_prefix_ids(self):

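Note: the block added above handles the case where the matched last_node was evicted from device memory while the request waited to be scheduled; it walks up the tree past evicted nodes and drops the indices they contributed from prefix_indices. Below is a minimal, standalone sketch of that walk using a hypothetical Node class, not the real TreeNode/Req types:

class Node:
    def __init__(self, parent=None, host_len=0, evicted=False):
        self.parent = parent
        self.host_value = list(range(host_len))  # stand-in for the host indices
        self.evicted = evicted

def trim_evicted_prefix(prefix_indices, last_node):
    # Walk up past evicted nodes, trimming the indices they contributed.
    while last_node.evicted:
        prefix_indices = prefix_indices[: -len(last_node.host_value)]
        last_node = last_node.parent
    return prefix_indices, last_node

root = Node()
mid = Node(parent=root, host_len=4)
leaf = Node(parent=mid, host_len=3, evicted=True)
print(trim_evicted_prefix(list(range(10)), leaf)[0])  # [0, 1, 2, 3, 4, 5, 6]
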
@@ -489,6 +489,8 @@ class Scheduler(
                 tp_cache_group=self.tp_cpu_group,
                 page_size=self.page_size,
                 hicache_ratio=server_args.hicache_ratio,
+                hicache_size=server_args.hicache_size,
+                hicache_write_policy=server_args.hicache_write_policy,
             )
         else:
             self.tree_cache = RadixCache(

@@ -29,15 +29,17 @@ class HiRadixCache(RadixCache):
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
+        hicache_size: int,
+        hicache_write_policy: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

@@ -50,6 +52,7 @@ class HiRadixCache(RadixCache):
             self.token_to_kv_pool_host,
             page_size,
             load_cache_event=self.load_cache_event,
+            write_policy=hicache_write_policy,
         )

         # record the nodes with ongoing write through

@@ -57,7 +60,9 @@ class HiRadixCache(RadixCache):
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
         # todo: dynamically adjust the threshold
-        self.write_through_threshold = 1
+        self.write_through_threshold = (
+            1 if hicache_write_policy == "write_through" else 3
+        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False

@@ -76,7 +81,7 @@ class HiRadixCache(RadixCache):
             height += 1
         return height

-    def write_backup(self, node: TreeNode):
+    def write_backup(self, node: TreeNode, write_back=False):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
             node_id=node.id,

@@ -90,21 +95,29 @@ class HiRadixCache(RadixCache):
         if host_indices is not None:
             node.host_value = host_indices
             self.ongoing_write_through[node.id] = node
-            self.inc_lock_ref(node)
+            if not write_back:
+                # no need to lock nodes if write back
+                self.inc_lock_ref(node)
         else:
             return 0

         return len(host_indices)

     def inc_hit_count(self, node: TreeNode):
-        if self.cache_controller.write_policy != "write_through_selective":
+        if node.backuped or self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-        if node.host_value is None and node.hit_count > self.write_through_threshold:
+        if node.hit_count >= self.write_through_threshold:
             self.write_backup(node)
             node.hit_count = 0

-    def writing_check(self):
+    def writing_check(self, write_back=False):
+        if write_back:
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                ack_id = self.cache_controller.ack_write_queue.get()
+                del self.ongoing_write_through[ack_id]
+            return
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )

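Note: after this change, hit counting is skipped for nodes that are already backed up and whenever the policy is write_back; a node is written through to host once its hit count reaches write_through_threshold (1 for write_through, 3 for the other policies, per the constructor change above). A rough standalone sketch of that gating, with hypothetical names rather than the real HiRadixCache API:

def make_hit_counter(write_policy: str):
    threshold = 1 if write_policy == "write_through" else 3
    def should_backup(node) -> bool:
        # Skip nodes already on host, and skip counting entirely under write_back.
        if node.get("backuped") or write_policy == "write_back":
            return False
        node["hit_count"] += 1
        if node["hit_count"] >= threshold:
            node["hit_count"] = 0
            return True
        return False
    return should_backup

should_backup = make_hit_counter("write_through_selective")
node = {"backuped": False, "hit_count": 0}
print([should_backup(node) for _ in range(4)])  # [False, False, True, False]
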
@@ -143,29 +156,25 @@ class HiRadixCache(RadixCache):
         heapq.heapify(leaves)

         num_evicted = 0
-        pending_nodes = []
+        write_back_nodes = []
         while num_evicted < num_tokens and len(leaves):
             x = heapq.heappop(leaves)

             if x.lock_ref > 0:
                 continue

-            if x.host_value is None:
+            if not x.backuped:
                 if self.cache_controller.write_policy == "write_back":
-                    num_evicted += self.write_backup(x)
-                    pending_nodes.append(x)
-                elif self.cache_controller.write_policy == "write_through_selective":
-                    num_evicted += self._evict_write_through_selective(x)
+                    # write to host if the node is not backuped
+                    num_evicted += self.write_backup(x, write_back=True)
+                    write_back_nodes.append(x)
                 else:
-                    assert (
-                        self.cache_controller.write_policy != "write_through"
-                    ), "write_through should be inclusive"
-                    raise NotImplementedError
+                    num_evicted += self._evict_regular(x)
             else:
-                num_evicted += self._evict_write_through(x)
+                num_evicted += self._evict_backuped(x)

             for child in x.parent.children.values():
-                if child in pending_nodes:
+                if child in write_back_nodes:
                     continue
                 if not child.evicted:
                     break

@@ -174,15 +183,12 @@ class HiRadixCache(RadixCache):
                 heapq.heappush(leaves, x.parent)

         if self.cache_controller.write_policy == "write_back":
-            # blocking till all write back complete
-            while len(self.ongoing_write_through) > 0:
-                self.writing_check()
-                time.sleep(0.1)
-            for node in pending_nodes:
-                assert node.host_value is not None
-                self._evict_write_through(node)
+            self.writing_check(write_back=True)
+            for node in write_back_nodes:
+                assert node.backuped
+                self._evict_backuped(node)

-    def _evict_write_through(self, node: TreeNode):
+    def _evict_backuped(self, node: TreeNode):
         # evict a node already written to host
         num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
         assert num_evicted > 0

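Note: eviction under the write_back policy now calls writing_check(write_back=True), which blocks on the controller's ack queue until every outstanding write has been acknowledged, replacing the old sleep-and-poll loop. A minimal standalone sketch of that drain pattern with queue.Queue (illustrative stand-ins, not the real HiCacheController):

import queue
import threading

# Hypothetical stand-ins for the controller's ack queue and the pending-node map.
ack_write_queue: "queue.Queue[int]" = queue.Queue()
ongoing_write_through = {1: "node-1", 2: "node-2"}

def worker():
    # Pretend the controller finishes both writes and acks them by node id.
    for node_id in (1, 2):
        ack_write_queue.put(node_id)

threading.Thread(target=worker).start()

# Blocking drain: pop acks until no write-through is outstanding.
while len(ongoing_write_through) > 0:
    ack_id = ack_write_queue.get()
    del ongoing_write_through[ack_id]

print("all write-backs acknowledged")
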
@@ -190,7 +196,7 @@ class HiRadixCache(RadixCache):
         node.value = None
         return num_evicted

-    def _evict_write_through_selective(self, node: TreeNode):
+    def _evict_regular(self, node: TreeNode):
         # evict a node not initiated write to host
         self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)

@@ -339,11 +345,13 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
                 node = new_node
                 break
             else:
+                self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
                 node = child

@@ -369,7 +377,7 @@ class HiRadixCache(RadixCache):
         else:
             new_node.value = child.value[:split_len]
             child.value = child.value[split_len:]
-        if child.host_value is not None:
+        if child.backuped:
             new_node.host_value = child.host_value[:split_len]
             child.host_value = child.host_value[split_len:]
         child.parent = new_node

@@ -426,8 +434,8 @@ class HiRadixCache(RadixCache):
             node.children[child_key] = new_node
             self.evictable_size_ += len(value)

-            if self.cache_controller.write_policy == "write_through":
-                self.write_backup(new_node)
+            if self.cache_controller.write_policy != "write_back":
+                self.inc_hit_count(new_node)
         return total_prefix_length

     def _collect_leaves_device(self):

@@ -624,26 +624,27 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         pin_memory: bool,
         device: str,
         page_size: int,
     ):
-        assert (
-            host_to_device_ratio >= 1
-        ), "The host memory should be larger than the device memory with the current protocol"
-        # todo, other ways of configuring the size
-
         self.device_pool = device_pool
         self.host_to_device_ratio = host_to_device_ratio
+        self.dtype = device_pool.store_dtype
         self.pin_memory = pin_memory
         self.device = device
         self.page_size = page_size

-        self.size = int(device_pool.size * host_to_device_ratio)
+        self.size_per_token = self.get_size_per_token()
+        if host_size > 0:
+            self.size = int(host_size * 1e9 // self.size_per_token)
+        else:
+            self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
         self.size = self.size - (self.size % self.page_size)
-        self.dtype = device_pool.store_dtype
-        self.size_per_token = self.get_size_per_token()

+        assert (
+            self.size > device_pool.size
+        ), "The host memory should be larger than the device memory with the current protocol"

         # Verify there is enough available host memory.
         host_mem = psutil.virtual_memory()

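Note: with the change above, the host pool size in tokens comes from the new host_size argument (interpreted as gigabytes) when it is positive, and from device_pool.size * host_to_device_ratio otherwise, then is aligned down to a multiple of page_size. A standalone sketch of that calculation (hypothetical helper, not the real HostKVCache):

def host_pool_tokens(device_tokens: int, size_per_token: int, page_size: int,
                     host_to_device_ratio: float = 2.0, host_size_gb: int = 0) -> int:
    # Prefer the explicit host budget in GB; otherwise scale the device pool.
    if host_size_gb > 0:
        size = int(host_size_gb * 1e9 // size_per_token)
    else:
        size = int(device_tokens * host_to_device_ratio)
    # Align down to the page size, mirroring the diff above.
    return size - (size % page_size)

# e.g. a 64 GB host budget, 160 KB of KV data per token, page size 64:
print(host_pool_tokens(device_tokens=200_000, size_per_token=160_000,
                       page_size=64, host_size_gb=64))  # 400000
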
@@ -795,12 +796,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )

     def get_size_per_token(self):

@@ -869,12 +871,13 @@ class MLATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )

     def get_size_per_token(self):

@@ -180,6 +180,8 @@ class ServerArgs:
     tool_call_parser: Optional[str] = None
     enable_hierarchical_cache: bool = False
     hicache_ratio: float = 2.0
+    hicache_size: int = 0
+    hicache_write_policy: str = "write_through_selective"
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
     moe_dense_tp_size: Optional[int] = None

@@ -1116,10 +1118,22 @@ class ServerArgs:
         parser.add_argument(
             "--hicache-ratio",
             type=float,
-            required=False,
             default=ServerArgs.hicache_ratio,
             help="The ratio of the size of host KV cache memory pool to the size of device pool.",
         )
         parser.add_argument(
+            "--hicache-size",
+            type=int,
+            default=ServerArgs.hicache_size,
+            help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
+        )
+        parser.add_argument(
+            "--hicache-write-policy",
+            type=str,
+            choices=["write_back", "write_through", "write_through_selective"],
+            default=ServerArgs.hicache_write_policy,
+            help="The write policy of hierarchical cache.",
+        )
+        parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",

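Note: --hicache-ratio sizes the host KV cache pool relative to the device pool, --hicache-size (in GB) overrides the ratio when set to a non-zero value, and --hicache-write-policy selects among write_back, write_through, and write_through_selective (the default). A stripped-down argparse sketch of the same flags, not the real ServerArgs parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--hicache-ratio", type=float, default=2.0,
                    help="Host KV cache pool size as a multiple of the device pool.")
parser.add_argument("--hicache-size", type=int, default=0,
                    help="Host KV cache pool size in GB; overrides --hicache-ratio if set.")
parser.add_argument("--hicache-write-policy", type=str,
                    choices=["write_back", "write_through", "write_through_selective"],
                    default="write_through_selective",
                    help="Write policy of the hierarchical cache.")

args = parser.parse_args(["--hicache-size", "64", "--hicache-write-policy", "write_back"])
print(args.hicache_size, args.hicache_write_policy)  # 64 write_back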