Support chunked prefill when radix cache is disabled (#811)
This commit is contained in:
@@ -28,6 +28,7 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
|
||||
@@ -486,15 +487,33 @@ class Batch:
|
||||
req = self.reqs[idx]
|
||||
retracted_reqs.append(req)
|
||||
|
||||
# TODO: apply more fine-grained retraction
|
||||
last_uncached_pos = len(req.prefix_indices)
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[idx]
|
||||
][last_uncached_pos : seq_lens_cpu[idx]]
|
||||
self.token_to_kv_pool.free(token_indices)
|
||||
if isinstance(self.tree_cache, ChunkCache):
|
||||
# ChunkCache does not have eviction
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[idx]
|
||||
][: seq_lens_cpu[idx]]
|
||||
self.token_to_kv_pool.free(token_indices)
|
||||
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
|
||||
del self.tree_cache.entries[req.rid]
|
||||
else:
|
||||
# TODO: apply more fine-grained retraction
|
||||
last_uncached_pos = len(req.prefix_indices)
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[idx]
|
||||
][last_uncached_pos : seq_lens_cpu[idx]]
|
||||
self.token_to_kv_pool.free(token_indices)
|
||||
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
|
||||
|
||||
# release the last node
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
# release the last node
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
|
||||
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
||||
residual_size = (
|
||||
len(sorted_indices) * global_config.retract_decode_steps
|
||||
- self.token_to_kv_pool.available_size()
|
||||
)
|
||||
residual_size = max(0, residual_size)
|
||||
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
|
||||
|
||||
req.prefix_indices = None
|
||||
req.last_node = None
|
||||
@@ -575,6 +594,7 @@ class Batch:
|
||||
if req_pool_indices_cpu is None:
|
||||
req_pool_indices_cpu = self.req_pool_indices.tolist()
|
||||
self.tree_cache.cache_req(
|
||||
rid=req.rid,
|
||||
token_ids=cur_all_ids,
|
||||
last_uncached_pos=len(req.prefix_indices),
|
||||
req_pool_idx=req_pool_indices_cpu[i],
|
||||
|
||||
Reference in New Issue
Block a user