diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 5463d3d77..2d29126b3 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -92,7 +92,7 @@ class HiRadixCache(RadixCache): self.ongoing_write_through[node.id] = node self.inc_lock_ref(node) else: - return None + return 0 return len(host_indices) @@ -153,6 +153,7 @@ class HiRadixCache(RadixCache): if x.host_value is None: if self.cache_controller.write_policy == "write_back": num_evicted += self.write_backup(x) + pending_nodes.append(x) elif self.cache_controller.write_policy == "write_through_selective": num_evicted += self._evict_write_through_selective(x) else: @@ -177,6 +178,9 @@ class HiRadixCache(RadixCache): while len(self.ongoing_write_through) > 0: self.writing_check() time.sleep(0.1) + for node in pending_nodes: + assert node.host_value is not None + self._evict_write_through(node) def _evict_write_through(self, node: TreeNode): # evict a node already written to host