Fix OOM error for large page size (#4913)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
Zhiqiang Xie
2025-03-30 21:34:21 -07:00
committed by GitHub
parent 4a63bc32b7
commit a169b9f813
2 changed files with 21 additions and 13 deletions

View File

@@ -814,11 +814,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
last_loc: torch.Tensor,
backup_state: bool = False,
):
if (
self.token_to_kv_pool_allocator.available_size()
< len(seq_lens) * self.token_to_kv_pool_allocator.page_size
):
if self.tree_cache is not None:
if self.tree_cache is not None:
if (
self.token_to_kv_pool_allocator.available_size()
< len(seq_lens) * self.token_to_kv_pool_allocator.page_size
):
self.tree_cache.evict(
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
)
@@ -1116,17 +1116,25 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# TODO (lianmin): Revisit this. It should be seq_len - 1
self.extend_logprob_start_lens.extend([0] * running_bs)
def new_page_count_next_decode(self):
    """Return how many requests will need a fresh KV-cache page for their
    next decode token.

    With one-token pages every request consumes a new page each step, so
    the answer is simply the number of requests.  For larger pages, a
    request only needs a new page when its current sequence length exactly
    fills its allocated pages (length is a multiple of the page size).
    """
    psize = self.token_to_kv_pool_allocator.page_size
    if psize == 1:
        # Every decode step starts a new single-token page.
        return len(self.reqs)
    # Requests whose pages are exactly full will spill into a new page.
    needs_page = [req for req in self.reqs if req.seqlen % psize == 0]
    return len(needs_page)
def check_decode_mem(self, buf_multiplier=1):
    """Check whether the KV-cache allocator has room for the next decode step.

    Instead of reserving one token slot per request (the pre-fix behavior,
    which under-counted with page_size > 1 and could OOM), this reserves a
    full page for every request that will cross a page boundary on its next
    token, as reported by ``new_page_count_next_decode``.

    Args:
        buf_multiplier: scale factor on the reservation (e.g. for
            speculative decoding that appends several tokens per step).

    Returns:
        True if enough tokens are free (possibly after evicting entries
        from the radix tree cache), False otherwise.
    """
    tokens_required = (
        self.new_page_count_next_decode()
        * buf_multiplier
        * self.token_to_kv_pool_allocator.page_size
    )

    # Fast path: enough free slots already.
    if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
        return True

    # Try to reclaim space from the prefix cache, then re-check.
    self.tree_cache.evict(tokens_required)
    return self.token_to_kv_pool_allocator.available_size() >= tokens_required
def retract_decode(self, server_args: ServerArgs):
"""Retract the decoding requests when there is not enough memory."""