diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index a327f60dc..cafb69724 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -814,11 +814,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         last_loc: torch.Tensor,
         backup_state: bool = False,
     ):
-        if (
-            self.token_to_kv_pool_allocator.available_size()
-            < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-        ):
-            if self.tree_cache is not None:
+        if self.tree_cache is not None:
+            if (
+                self.token_to_kv_pool_allocator.available_size()
+                < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+            ):
                 self.tree_cache.evict(
                     len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                 )
@@ -1116,17 +1116,25 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
+    def new_page_count_next_decode(self):
+        page_size = self.token_to_kv_pool_allocator.page_size
+        if page_size == 1:
+            return len(self.reqs)
+        return sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+
     def check_decode_mem(self, buf_multiplier=1):
-        bs = len(self.reqs) * buf_multiplier
-        if self.token_to_kv_pool_allocator.available_size() >= bs:
+        tokens_required = (
+            self.new_page_count_next_decode()
+            * buf_multiplier
+            * self.token_to_kv_pool_allocator.page_size
+        )
+
+        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
             return True
 
-        self.tree_cache.evict(bs)
+        self.tree_cache.evict(tokens_required)
 
-        if self.token_to_kv_pool_allocator.available_size() >= bs:
-            return True
-
-        return False
+        return self.token_to_kv_pool_allocator.available_size() >= tokens_required
 
     def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index cc490df25..8bd0b2633 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -144,7 +144,7 @@ class TestEAGLEEngine(CustomTestCase):
         if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
             self.assertGreater(acc_length, 3.6)
         else:
-            self.assertGreater(acc_length, 2.6)
+            self.assertGreater(acc_length, 2.5)
 
 
 class TestEAGLEEngineTokenMap(TestEAGLEEngine):