Support chunked prefill when radix cache is disabled (#811)

This commit is contained in:
Liangsheng Yin
2024-08-01 00:29:01 -07:00
committed by GitHub
parent ca600e8cd6
commit c020f9ceda
9 changed files with 163 additions and 26 deletions

View File

@@ -28,6 +28,7 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -486,15 +487,33 @@ class Batch:
req = self.reqs[idx]
retracted_reqs.append(req)
# TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][last_uncached_pos : seq_lens_cpu[idx]]
self.token_to_kv_pool.free(token_indices)
if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][: seq_lens_cpu[idx]]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
del self.tree_cache.entries[req.rid]
else:
# TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][last_uncached_pos : seq_lens_cpu[idx]]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
# release the last node
self.tree_cache.dec_lock_ref(req.last_node)
# release the last node
self.tree_cache.dec_lock_ref(req.last_node)
# NOTE(lsyin): we should use the newly evictable memory instantly.
residual_size = (
len(sorted_indices) * global_config.retract_decode_steps
- self.token_to_kv_pool.available_size()
)
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
req.prefix_indices = None
req.last_node = None
@@ -575,6 +594,7 @@ class Batch:
if req_pool_indices_cpu is None:
req_pool_indices_cpu = self.req_pool_indices.tolist()
self.tree_cache.cache_req(
rid=req.rid,
token_ids=cur_all_ids,
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],