RadixCache method adjust (#977)
This commit is contained in:
@@ -124,7 +124,7 @@ class Req:
|
||||
# For vision input
|
||||
self.pixel_values = None
|
||||
self.image_size = None
|
||||
self.image_offset = 0
|
||||
self.image_offset = None
|
||||
self.pad_value = None
|
||||
|
||||
# Prefix info
|
||||
@@ -162,6 +162,13 @@ class Req:
|
||||
def finished(self) -> bool:
|
||||
return self.finished_reason is not None
|
||||
|
||||
def adjust_max_prefix_ids(self):
|
||||
max_prefix_ids = self.input_ids
|
||||
if self.return_logprob:
|
||||
max_prefix_ids = self.input_ids[: self.logprob_start_len]
|
||||
|
||||
return max_prefix_ids
|
||||
|
||||
# Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
|
||||
def init_incremental_detokenize(self):
|
||||
first_iter = self.surr_offset is None or self.read_offset is None
|
||||
@@ -444,7 +451,8 @@ class ScheduleBatch:
|
||||
self.pixel_values = [r.pixel_values for r in reqs]
|
||||
self.image_sizes = [r.image_size for r in reqs]
|
||||
self.image_offsets = [
|
||||
r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
|
||||
(r.image_offset - p_len) if r.image_offset is not None else 0
|
||||
for r, p_len in zip(reqs, prefix_lens)
|
||||
]
|
||||
self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
@@ -596,15 +604,7 @@ class ScheduleBatch:
|
||||
req.vid += 1
|
||||
|
||||
# insert the old request into tree_cache
|
||||
self.tree_cache.cache_req(
|
||||
rid=req.rid,
|
||||
token_ids=cur_all_ids,
|
||||
last_uncached_pos=len(req.prefix_indices),
|
||||
req_pool_idx=req.req_pool_idx,
|
||||
)
|
||||
|
||||
# unlock the last node
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
self.tree_cache.cache_finished_req(req, cur_all_ids)
|
||||
|
||||
# re-applying image padding
|
||||
if req.pixel_values is not None:
|
||||
@@ -621,8 +621,7 @@ class ScheduleBatch:
|
||||
jump_forward_reqs.append(req)
|
||||
filter_indices.remove(i)
|
||||
|
||||
if len(filter_indices) < len(self.reqs):
|
||||
self.filter_batch(filter_indices)
|
||||
self.filter_batch(filter_indices)
|
||||
|
||||
return jump_forward_reqs
|
||||
|
||||
@@ -644,6 +643,15 @@ class ScheduleBatch:
|
||||
] = self.out_cache_loc
|
||||
|
||||
def filter_batch(self, unfinished_indices: List[int]):
|
||||
if unfinished_indices is None or len(unfinished_indices) == 0:
|
||||
# Filter out all requests
|
||||
self.reqs = []
|
||||
return
|
||||
|
||||
if len(unfinished_indices) == len(self.reqs):
|
||||
# No need to filter
|
||||
return
|
||||
|
||||
self.reqs = [self.reqs[i] for i in unfinished_indices]
|
||||
new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
|
||||
self.seq_lens = self.seq_lens[new_indices]
|
||||
@@ -711,6 +719,7 @@ class ScheduleBatch:
|
||||
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
|
||||
|
||||
def sample(self, logits: torch.Tensor):
|
||||
# TODO(lsyin): move this into a part of layer and run with CUDA Graph
|
||||
# Post process logits
|
||||
logits = logits.contiguous()
|
||||
logits.div_(self.temperatures)
|
||||
|
||||
@@ -232,8 +232,6 @@ class ModelTpServer:
|
||||
if new_batch is not None:
|
||||
# Run a new prefill batch
|
||||
self.forward_prefill_batch(new_batch)
|
||||
self.cache_filled_batch(new_batch)
|
||||
self.filter_out_inflight(new_batch)
|
||||
|
||||
if not new_batch.is_empty():
|
||||
if self.running_batch is None:
|
||||
@@ -353,26 +351,20 @@ class ModelTpServer:
|
||||
self.waiting_queue.append(req)
|
||||
|
||||
def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
|
||||
# TODO(lsyin): organize this function
|
||||
running_bs = (
|
||||
len(self.running_batch.reqs) if self.running_batch is not None else 0
|
||||
)
|
||||
if running_bs >= self.max_running_requests:
|
||||
return
|
||||
return None
|
||||
|
||||
# Compute matched prefix length
|
||||
for req in self.waiting_queue:
|
||||
req.input_ids = req.origin_input_ids + req.output_ids
|
||||
try_match_ids = req.input_ids
|
||||
if req.return_logprob:
|
||||
try_match_ids = req.input_ids[: req.logprob_start_len]
|
||||
# NOTE: the prefix_indices must always be aligned with last_node
|
||||
prefix_indices, last_node = self.tree_cache.match_prefix(
|
||||
rid=req.rid, key=try_match_ids
|
||||
req.prefix_indices, req.last_node = self.tree_cache.match_prefix(
|
||||
rid=req.rid, key=req.adjust_max_prefix_ids()
|
||||
)
|
||||
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
|
||||
req.prefix_indices = prefix_indices
|
||||
req.last_node = last_node
|
||||
req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
|
||||
|
||||
# Get priority queue
|
||||
self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
|
||||
@@ -394,6 +386,24 @@ class ModelTpServer:
|
||||
)
|
||||
|
||||
for req in self.waiting_queue:
|
||||
|
||||
# FIXME: Move this code into adjust_max_prefix_len
|
||||
if req.return_logprob and req.normalized_prompt_logprob is None:
|
||||
# Need at least two tokens to compute normalized logprob
|
||||
if req.extend_input_len < 2:
|
||||
delta = 2 - req.extend_input_len
|
||||
req.extend_input_len += delta
|
||||
req.prefix_indices = req.prefix_indices[:-delta]
|
||||
if req.image_offset is not None:
|
||||
req.image_offset += delta
|
||||
|
||||
if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
|
||||
# Need at least one token to compute logits
|
||||
req.extend_input_len = 1
|
||||
req.prefix_indices = req.prefix_indices[:-1]
|
||||
if req.image_offset is not None:
|
||||
req.image_offset += 1
|
||||
|
||||
res = adder.add_one_req(req)
|
||||
if (
|
||||
not res
|
||||
@@ -470,10 +480,20 @@ class ModelTpServer:
|
||||
pt = 0
|
||||
for i, req in enumerate(batch.reqs):
|
||||
if req is not self.current_inflight_req:
|
||||
# Inflight reqs' prefill is not finished
|
||||
req.completion_tokens_wo_jump_forward += 1
|
||||
req.output_ids.append(next_token_ids[i])
|
||||
req.check_finished()
|
||||
|
||||
if req.finished():
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
else:
|
||||
self.tree_cache.cache_unfinished_req(req)
|
||||
|
||||
if req is self.current_inflight_req:
|
||||
# Inflight request would get a new req idx
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
if req.return_logprob:
|
||||
self.add_logprob_return_values(i, req, pt, next_token_ids, output)
|
||||
pt += req.extend_input_len
|
||||
@@ -529,22 +549,6 @@ class ModelTpServer:
|
||||
)
|
||||
req.output_top_logprobs.append(output.output_top_logprobs[i])
|
||||
|
||||
def cache_filled_batch(self, batch: ScheduleBatch):
|
||||
for i, req in enumerate(batch.reqs):
|
||||
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
|
||||
rid=req.rid,
|
||||
token_ids=tuple(req.input_ids),
|
||||
last_uncached_pos=len(req.prefix_indices),
|
||||
req_pool_idx=req.req_pool_idx,
|
||||
del_in_memory_pool=False,
|
||||
old_last_node=req.last_node,
|
||||
)
|
||||
req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
|
||||
|
||||
if req is self.current_inflight_req:
|
||||
# inflight request would get a new req idx
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
def forward_decode_batch(self, batch: ScheduleBatch):
|
||||
# Check if decode out of memory
|
||||
if not batch.check_decode_mem():
|
||||
@@ -595,6 +599,9 @@ class ModelTpServer:
|
||||
req.output_ids.append(next_token_id)
|
||||
req.check_finished()
|
||||
|
||||
if req.finished():
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
|
||||
if req.return_logprob:
|
||||
req.output_token_logprobs.append(
|
||||
(next_token_logprobs[i], next_token_id)
|
||||
@@ -614,12 +621,9 @@ class ModelTpServer:
|
||||
output_spaces_between_special_tokens = []
|
||||
output_meta_info = []
|
||||
output_finished_reason: List[BaseFinishReason] = []
|
||||
finished_indices = []
|
||||
unfinished_indices = []
|
||||
for i, req in enumerate(batch.reqs):
|
||||
if req.finished():
|
||||
finished_indices.append(i)
|
||||
else:
|
||||
if not req.finished() and req is not self.current_inflight_req:
|
||||
unfinished_indices.append(i)
|
||||
|
||||
if req.finished() or (
|
||||
@@ -683,34 +687,7 @@ class ModelTpServer:
|
||||
)
|
||||
)
|
||||
|
||||
# Remove finished reqs
|
||||
if finished_indices:
|
||||
# Update radix cache
|
||||
for i in finished_indices:
|
||||
req = batch.reqs[i]
|
||||
self.tree_cache.cache_req(
|
||||
rid=req.rid,
|
||||
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
|
||||
last_uncached_pos=len(req.prefix_indices),
|
||||
req_pool_idx=req.req_pool_idx,
|
||||
)
|
||||
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
|
||||
# Update batch tensors
|
||||
if unfinished_indices:
|
||||
batch.filter_batch(unfinished_indices)
|
||||
else:
|
||||
batch.reqs = []
|
||||
|
||||
def filter_out_inflight(self, batch: ScheduleBatch):
|
||||
# TODO(lsyin): reduce the overhead, make a special version for this
|
||||
if self.current_inflight_req is None:
|
||||
return
|
||||
|
||||
to_remove = batch.reqs.index(self.current_inflight_req)
|
||||
unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
|
||||
|
||||
# Remove finished reqs: update batch tensors
|
||||
batch.filter_batch(unfinished_indices)
|
||||
|
||||
def flush_cache(self):
|
||||
|
||||
@@ -17,7 +17,11 @@ class BasePrefixCache(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_req(self, **kwargs):
|
||||
def cache_finished_req(self, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cache_unfinished_req(self, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -1,6 +1,11 @@
|
||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||
|
||||
from sglang.srt.mem_cache.base_cache import BasePrefixCache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
|
||||
|
||||
class ChunkCacheEntry:
|
||||
@@ -27,22 +32,31 @@ class ChunkCache(BasePrefixCache):
|
||||
entry = self.entries[rid]
|
||||
return entry.value, entry
|
||||
|
||||
def cache_req(
|
||||
self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
|
||||
):
|
||||
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
|
||||
if del_in_memory_pool:
|
||||
assert rid in self.entries
|
||||
self.req_to_token_pool.free(req_pool_idx)
|
||||
self.token_to_kv_pool.free(indices)
|
||||
return
|
||||
def cache_finished_req(self, req: "Req", token_ids=None):
|
||||
if token_ids is None:
|
||||
token_ids = (req.input_ids + req.output_ids)[:-1]
|
||||
|
||||
if rid not in self.entries:
|
||||
self.entries[rid] = ChunkCacheEntry(rid, indices)
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
assert req.rid in self.entries
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
self.token_to_kv_pool.free(kv_indices)
|
||||
|
||||
entry = self.entries[rid]
|
||||
entry.value = indices
|
||||
return indices, entry
|
||||
def cache_unfinished_req(self, req: "Req", token_ids=None):
|
||||
if token_ids is None:
|
||||
token_ids = req.input_ids
|
||||
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
|
||||
if req.rid not in self.entries:
|
||||
self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
|
||||
|
||||
entry = self.entries[req.rid]
|
||||
entry.value = kv_indices
|
||||
return kv_indices, entry
|
||||
|
||||
def insert(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -20,10 +20,14 @@ The radix tree data structure for managing the KV cache.
|
||||
import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.base_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
|
||||
|
||||
class TreeNode:
|
||||
@@ -85,40 +89,54 @@ class RadixCache(BasePrefixCache):
|
||||
value = [x for x in key]
|
||||
return self._insert_helper(self.root_node, key, value)
|
||||
|
||||
def cache_req(
|
||||
self,
|
||||
token_ids,
|
||||
last_uncached_pos,
|
||||
req_pool_idx,
|
||||
del_in_memory_pool=True,
|
||||
old_last_node=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Insert the request into radix cache
|
||||
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
|
||||
new_prefix_len = self.insert(token_ids, indices.clone())
|
||||
def cache_finished_req(self, req: "Req", token_ids=None):
|
||||
"""Cache request when it finishes."""
|
||||
if token_ids is None:
|
||||
token_ids = (req.input_ids + req.output_ids)[:-1]
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
|
||||
if self.disable:
|
||||
if del_in_memory_pool:
|
||||
self.token_to_kv_pool.free(indices)
|
||||
else:
|
||||
return torch.tensor([], dtype=torch.int32), self.root_node
|
||||
self.token_to_kv_pool.free(kv_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
return
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
|
||||
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
||||
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
|
||||
|
||||
if del_in_memory_pool:
|
||||
self.req_to_token_pool.free(req_pool_idx)
|
||||
else:
|
||||
cached_indices, new_last_node = self.match_prefix(token_ids)
|
||||
assert len(cached_indices) == len(token_ids)
|
||||
# Remove req slot release the cache lock
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
self.dec_lock_ref(req.last_node)
|
||||
|
||||
self.req_to_token_pool.req_to_token[
|
||||
req_pool_idx, last_uncached_pos : len(cached_indices)
|
||||
] = cached_indices[last_uncached_pos:]
|
||||
self.dec_lock_ref(old_last_node)
|
||||
self.inc_lock_ref(new_last_node)
|
||||
return cached_indices, new_last_node
|
||||
def cache_unfinished_req(self, req: "Req", token_ids=None):
|
||||
"""Cache request when it is unfinished."""
|
||||
if self.disable:
|
||||
return
|
||||
|
||||
if token_ids is None:
|
||||
token_ids = req.input_ids
|
||||
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
]
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
||||
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node = self.match_prefix(token_ids)
|
||||
assert len(new_indices) == len(token_ids)
|
||||
self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
|
||||
] = new_indices[len(req.prefix_indices) :]
|
||||
|
||||
self.dec_lock_ref(req.last_node)
|
||||
self.inc_lock_ref(new_last_node)
|
||||
req.prefix_indices = new_indices
|
||||
req.last_node = new_last_node
|
||||
|
||||
def pretty_print(self):
|
||||
self._print_helper(self.root_node, 0)
|
||||
|
||||
Reference in New Issue
Block a user