diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index bd2158789..53e390b88 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -124,7 +124,7 @@ class Req:
         # For vision input
         self.pixel_values = None
         self.image_size = None
-        self.image_offset = 0
+        self.image_offset = None
         self.pad_value = None

         # Prefix info
@@ -162,6 +162,13 @@ class Req:
     def finished(self) -> bool:
         return self.finished_reason is not None

+    def adjust_max_prefix_ids(self):
+        max_prefix_ids = self.input_ids
+        if self.return_logprob:
+            max_prefix_ids = self.input_ids[: self.logprob_start_len]
+
+        return max_prefix_ids
+
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
@@ -444,7 +451,8 @@ class ScheduleBatch:
         self.pixel_values = [r.pixel_values for r in reqs]
         self.image_sizes = [r.image_size for r in reqs]
         self.image_offsets = [
-            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
+            (r.image_offset - p_len) if r.image_offset is not None else 0
+            for r, p_len in zip(reqs, prefix_lens)
         ]
         self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
         self.extend_num_tokens = extend_num_tokens
@@ -596,15 +604,7 @@ class ScheduleBatch:
                     req.vid += 1

                     # insert the old request into tree_cache
-                    self.tree_cache.cache_req(
-                        rid=req.rid,
-                        token_ids=cur_all_ids,
-                        last_uncached_pos=len(req.prefix_indices),
-                        req_pool_idx=req.req_pool_idx,
-                    )
-
-                    # unlock the last node
-                    self.tree_cache.dec_lock_ref(req.last_node)
+                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                     # re-applying image padding
                     if req.pixel_values is not None:
@@ -621,8 +621,7 @@ class ScheduleBatch:
                     jump_forward_reqs.append(req)
                     filter_indices.remove(i)

-        if len(filter_indices) < len(self.reqs):
-            self.filter_batch(filter_indices)
+        self.filter_batch(filter_indices)

         return jump_forward_reqs

@@ -644,6 +643,15 @@ class ScheduleBatch:
         ] = self.out_cache_loc

     def filter_batch(self, unfinished_indices: List[int]):
+        if unfinished_indices is None or len(unfinished_indices) == 0:
+            # Filter out all requests
+            self.reqs = []
+            return
+
+        if len(unfinished_indices) == len(self.reqs):
+            # No need to filter
+            return
+
         self.reqs = [self.reqs[i] for i in unfinished_indices]
         new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
         self.seq_lens = self.seq_lens[new_indices]
@@ -711,6 +719,7 @@ class ScheduleBatch:
         self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

     def sample(self, logits: torch.Tensor):
+        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 7d9091157..f3ab1d624 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -232,8 +232,6 @@ class ModelTpServer:
         if new_batch is not None:
             # Run a new prefill batch
             self.forward_prefill_batch(new_batch)
-            self.cache_filled_batch(new_batch)
-            self.filter_out_inflight(new_batch)

             if not new_batch.is_empty():
                 if self.running_batch is None:
@@ -353,26 +351,20 @@ class ModelTpServer:
         self.waiting_queue.append(req)

     def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
-        # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
-            return
+            return None

         # Compute matched prefix length
         for req in self.waiting_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
-            try_match_ids = req.input_ids
-            if req.return_logprob:
-                try_match_ids = req.input_ids[: req.logprob_start_len]
             # NOTE: the prefix_indices must always be aligned with last_node
-            prefix_indices, last_node = self.tree_cache.match_prefix(
-                rid=req.rid, key=try_match_ids
+            req.prefix_indices, req.last_node = self.tree_cache.match_prefix(
+                rid=req.rid, key=req.adjust_max_prefix_ids()
             )
-            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
-            req.prefix_indices = prefix_indices
-            req.last_node = last_node
+            req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)

         # Get priority queue
         self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
@@ -394,6 +386,24 @@ class ModelTpServer:
         )

         for req in self.waiting_queue:
+
+            # FIXME: Move this code into adjust_max_prefix_len
+            if req.return_logprob and req.normalized_prompt_logprob is None:
+                # Need at least two tokens to compute normalized logprob
+                if req.extend_input_len < 2:
+                    delta = 2 - req.extend_input_len
+                    req.extend_input_len += delta
+                    req.prefix_indices = req.prefix_indices[:-delta]
+                    if req.image_offset is not None:
+                        req.image_offset += delta
+
+            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
+                # Need at least one token to compute logits
+                req.extend_input_len = 1
+                req.prefix_indices = req.prefix_indices[:-1]
+                if req.image_offset is not None:
+                    req.image_offset += 1
+
             res = adder.add_one_req(req)
             if (
                 not res
@@ -470,10 +480,20 @@ class ModelTpServer:
         pt = 0
         for i, req in enumerate(batch.reqs):
             if req is not self.current_inflight_req:
+                # Inflight reqs' prefill is not finished
                 req.completion_tokens_wo_jump_forward += 1
                 req.output_ids.append(next_token_ids[i])
                 req.check_finished()

+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+            else:
+                self.tree_cache.cache_unfinished_req(req)
+
+            if req is self.current_inflight_req:
+                # Inflight request would get a new req idx
+                self.req_to_token_pool.free(req.req_pool_idx)
+
             if req.return_logprob:
                 self.add_logprob_return_values(i, req, pt, next_token_ids, output)
                 pt += req.extend_input_len
@@ -529,22 +549,6 @@ class ModelTpServer:
                 )
             req.output_top_logprobs.append(output.output_top_logprobs[i])

-    def cache_filled_batch(self, batch: ScheduleBatch):
-        for i, req in enumerate(batch.reqs):
-            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                rid=req.rid,
-                token_ids=tuple(req.input_ids),
-                last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=req.req_pool_idx,
-                del_in_memory_pool=False,
-                old_last_node=req.last_node,
-            )
-            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
-
-            if req is self.current_inflight_req:
-                # inflight request would get a new req idx
-                self.req_to_token_pool.free(req.req_pool_idx)
-
     def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
@@ -595,6 +599,9 @@ class ModelTpServer:
             req.output_ids.append(next_token_id)
             req.check_finished()
+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+
             if req.return_logprob:
                 req.output_token_logprobs.append(
                     (next_token_logprobs[i], next_token_id)
                 )
@@ -614,12 +621,9 @@ class ModelTpServer:
         output_spaces_between_special_tokens = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
-        finished_indices = []
         unfinished_indices = []
         for i, req in enumerate(batch.reqs):
-            if req.finished():
-                finished_indices.append(i)
-            else:
+            if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)

             if req.finished() or (
@@ -683,34 +687,7 @@ class ModelTpServer:
                 )
             )

-        # Remove finished reqs
-        if finished_indices:
-            # Update radix cache
-            for i in finished_indices:
-                req = batch.reqs[i]
-                self.tree_cache.cache_req(
-                    rid=req.rid,
-                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
-                    last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req.req_pool_idx,
-                )
-
-                self.tree_cache.dec_lock_ref(req.last_node)
-
-            # Update batch tensors
-            if unfinished_indices:
-                batch.filter_batch(unfinished_indices)
-            else:
-                batch.reqs = []
-
-    def filter_out_inflight(self, batch: ScheduleBatch):
-        # TODO(lsyin): reduce the overhead, make a special version for this
-        if self.current_inflight_req is None:
-            return
-
-        to_remove = batch.reqs.index(self.current_inflight_req)
-        unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
-
+        # Remove finished reqs: update batch tensors
         batch.filter_batch(unfinished_indices)

     def flush_cache(self):
diff --git a/python/sglang/srt/mem_cache/base_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py
similarity index 85%
rename from python/sglang/srt/mem_cache/base_cache.py
rename to python/sglang/srt/mem_cache/base_prefix_cache.py
index fe7e0b23a..fb2b7a627 100644
--- a/python/sglang/srt/mem_cache/base_cache.py
+++ b/python/sglang/srt/mem_cache/base_prefix_cache.py
@@ -17,7 +17,11 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def cache_req(self, **kwargs):
+    def cache_finished_req(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def cache_unfinished_req(self, **kwargs):
         pass

     @abstractmethod
diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py
index 3509bd1cd..7e3b39450 100644
--- a/python/sglang/srt/mem_cache/chunk_cache.py
+++ b/python/sglang/srt/mem_cache/chunk_cache.py
@@ -1,6 +1,11 @@
 """Cache for chunked prefill, used when RadixCache is disabled."""

-from sglang.srt.mem_cache.base_cache import BasePrefixCache
+from typing import TYPE_CHECKING
+
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req


 class ChunkCacheEntry:
@@ -27,22 +32,31 @@ class ChunkCache(BasePrefixCache):
         entry = self.entries[rid]
         return entry.value, entry

-    def cache_req(
-        self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
-    ):
-        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
-        if del_in_memory_pool:
-            assert rid in self.entries
-            self.req_to_token_pool.free(req_pool_idx)
-            self.token_to_kv_pool.free(indices)
-            return
+    def cache_finished_req(self, req: "Req", token_ids=None):
+        if token_ids is None:
+            token_ids = (req.input_ids + req.output_ids)[:-1]

-        if rid not in self.entries:
-            self.entries[rid] = ChunkCacheEntry(rid, indices)
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+        assert req.rid in self.entries
+        self.req_to_token_pool.free(req.req_pool_idx)
+        self.token_to_kv_pool.free(kv_indices)

-        entry = self.entries[rid]
-        entry.value = indices
-        return indices, entry
+    def cache_unfinished_req(self, req: "Req", token_ids=None):
+        if token_ids is None:
+            token_ids = req.input_ids
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        if req.rid not in self.entries:
+            self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
+
+        entry = self.entries[req.rid]
+        entry.value = kv_indices
+        return kv_indices, entry

     def insert(self):
         raise NotImplementedError
diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py
index 347ae002e..c23812049 100644
--- a/python/sglang/srt/mem_cache/radix_cache.py
+++ b/python/sglang/srt/mem_cache/radix_cache.py
@@ -20,10 +20,14 @@ The radix tree data structure for managing the KV cache.
 import heapq
 import time
 from collections import defaultdict
+from typing import TYPE_CHECKING

 import torch

-from sglang.srt.mem_cache.base_cache import BasePrefixCache
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req


 class TreeNode:
@@ -85,40 +89,54 @@ class RadixCache(BasePrefixCache):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

-    def cache_req(
-        self,
-        token_ids,
-        last_uncached_pos,
-        req_pool_idx,
-        del_in_memory_pool=True,
-        old_last_node=None,
-        **kwargs,
-    ):
-        # Insert the request into radix cache
-        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
-        new_prefix_len = self.insert(token_ids, indices.clone())
+    def cache_finished_req(self, req: "Req", token_ids=None):
+        """Cache request when it finishes."""
+        if token_ids is None:
+            token_ids = (req.input_ids + req.output_ids)[:-1]
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]

         if self.disable:
-            if del_in_memory_pool:
-                self.token_to_kv_pool.free(indices)
-            else:
-                return torch.tensor([], dtype=torch.int32), self.root_node
+            self.token_to_kv_pool.free(kv_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+            return

         # Radix Cache takes one ref in memory pool
-        self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
+        new_prefix_len = self.insert(token_ids, kv_indices.clone())
+        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])

-        if del_in_memory_pool:
-            self.req_to_token_pool.free(req_pool_idx)
-        else:
-            cached_indices, new_last_node = self.match_prefix(token_ids)
-            assert len(cached_indices) == len(token_ids)
+        # Remove req slot release the cache lock
+        self.req_to_token_pool.free(req.req_pool_idx)
+        self.dec_lock_ref(req.last_node)

-            self.req_to_token_pool.req_to_token[
-                req_pool_idx, last_uncached_pos : len(cached_indices)
-            ] = cached_indices[last_uncached_pos:]
-            self.dec_lock_ref(old_last_node)
-            self.inc_lock_ref(new_last_node)
-            return cached_indices, new_last_node
+    def cache_unfinished_req(self, req: "Req", token_ids=None):
+        """Cache request when it is unfinished."""
+        if self.disable:
+            return
+
+        if token_ids is None:
+            token_ids = req.input_ids
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        # Radix Cache takes one ref in memory pool
+        new_prefix_len = self.insert(token_ids, kv_indices.clone())
+        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+
+        # The prefix indices could be updated, reuse it
+        new_indices, new_last_node = self.match_prefix(token_ids)
+        assert len(new_indices) == len(token_ids)
+        self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
+        ] = new_indices[len(req.prefix_indices) :]
+
+        self.dec_lock_ref(req.last_node)
+        self.inc_lock_ref(new_last_node)
+        req.prefix_indices = new_indices
+        req.last_node = new_last_node

     def pretty_print(self):
         self._print_helper(self.root_node, 0)
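The sketch below is illustrative only and is not part of the patch. It shows how the split BasePrefixCache interface introduced above is intended to be driven from the scheduler side: route a request to cache_unfinished_req while it is still generating, and to cache_finished_req once it finishes. The helper name route_reqs_to_cache and the reqs list are hypothetical, used only for this example; tree_cache may be either RadixCache or ChunkCache.

# Illustrative sketch only (assumed helper name, not part of the patch).
def route_reqs_to_cache(tree_cache, reqs):
    for req in reqs:
        if req.finished():
            # Finished request: insert the full sequence into the cache,
            # free the req slot, and drop the lock taken at prefix-match time.
            tree_cache.cache_finished_req(req)
        else:
            # Unfinished request (still decoding or in chunked prefill): let
            # the cache take a reference to the tokens computed so far;
            # RadixCache also refreshes req.prefix_indices and req.last_node.
            tree_cache.cache_unfinished_req(req)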