diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index aff5eacc1..5dc5dce39 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1371,21 +1371,28 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # TODO (lianmin): Revisit this. It should be seq_len - 1 self.extend_logprob_start_lens.extend([0] * running_bs) - def new_page_count_next_decode(self): + def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None): page_size = self.token_to_kv_pool_allocator.page_size + requests = ( + self.reqs + if selected_indices is None + else [self.reqs[i] for i in selected_indices] + ) if page_size == 1: - return len(self.reqs) + return len(requests) # In the decoding phase, the length of a request's KV cache should be # the total length of the request minus 1 return ( - sum(1 for req in self.reqs if req.seqlen % page_size == 0) + sum(1 for req in requests if req.seqlen % page_size == 0) if self.enable_overlap - else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0) + else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0) ) - def check_decode_mem(self, buf_multiplier=1): + def check_decode_mem( + self, buf_multiplier=1, selected_indices: Optional[List[int]] = None + ): num_tokens = ( - self.new_page_count_next_decode() + self.new_page_count_next_decode(selected_indices) * buf_multiplier * self.token_to_kv_pool_allocator.page_size ) @@ -1411,34 +1418,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): reverse=True, ) - def get_required_tokens(num_reqs: int): - headroom_for_spec_decode = 0 - if server_args.speculative_algorithm: - headroom_for_spec_decode += ( - num_reqs - * server_args.speculative_eagle_topk - * server_args.speculative_num_steps - + num_reqs * server_args.speculative_num_draft_tokens - ) - return ( - num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode - ) - - def _get_available_size(): - if self.is_hybrid: - return min( - self.token_to_kv_pool_allocator.full_available_size(), - self.token_to_kv_pool_allocator.swa_available_size(), - ) - else: - return self.token_to_kv_pool_allocator.available_size() - retracted_reqs = [] seq_lens_cpu = self.seq_lens.cpu().numpy() first_iter = True - while ( - _get_available_size() < get_required_tokens(len(sorted_indices)) - or first_iter + while first_iter or ( + not self.check_decode_mem(selected_indices=sorted_indices) ): if len(sorted_indices) == 1: # Corner case: only one request left @@ -1492,10 +1476,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): else: self.tree_cache.dec_lock_ref(req.last_node) - # NOTE(lsyin): we should use the newly evictable memory instantly. - num_tokens = len(sorted_indices) * global_config.retract_decode_steps - self._evict_tree_cache_if_needed(num_tokens) - req.reset_for_retract() if len(retracted_reqs) == 0: