Fix oom error for large page size (#4913)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -814,11 +814,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
last_loc: torch.Tensor,
|
last_loc: torch.Tensor,
|
||||||
backup_state: bool = False,
|
backup_state: bool = False,
|
||||||
):
|
):
|
||||||
if (
|
if self.tree_cache is not None:
|
||||||
self.token_to_kv_pool_allocator.available_size()
|
if (
|
||||||
< len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
self.token_to_kv_pool_allocator.available_size()
|
||||||
):
|
< len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
||||||
if self.tree_cache is not None:
|
):
|
||||||
self.tree_cache.evict(
|
self.tree_cache.evict(
|
||||||
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
||||||
)
|
)
|
||||||
@@ -1116,17 +1116,25 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
# TODO (lianmin): Revisit this. It should be seq_len - 1
|
# TODO (lianmin): Revisit this. It should be seq_len - 1
|
||||||
self.extend_logprob_start_lens.extend([0] * running_bs)
|
self.extend_logprob_start_lens.extend([0] * running_bs)
|
||||||
|
|
||||||
|
def new_page_count_next_decode(self):
|
||||||
|
page_size = self.token_to_kv_pool_allocator.page_size
|
||||||
|
if page_size == 1:
|
||||||
|
return len(self.reqs)
|
||||||
|
return sum(1 for req in self.reqs if req.seqlen % page_size == 0)
|
||||||
|
|
||||||
def check_decode_mem(self, buf_multiplier=1):
|
def check_decode_mem(self, buf_multiplier=1):
|
||||||
bs = len(self.reqs) * buf_multiplier
|
tokens_required = (
|
||||||
if self.token_to_kv_pool_allocator.available_size() >= bs:
|
self.new_page_count_next_decode()
|
||||||
|
* buf_multiplier
|
||||||
|
* self.token_to_kv_pool_allocator.page_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
self.tree_cache.evict(bs)
|
self.tree_cache.evict(tokens_required)
|
||||||
|
|
||||||
if self.token_to_kv_pool_allocator.available_size() >= bs:
|
return self.token_to_kv_pool_allocator.available_size() >= tokens_required
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def retract_decode(self, server_args: ServerArgs):
|
def retract_decode(self, server_args: ServerArgs):
|
||||||
"""Retract the decoding requests when there is not enough memory."""
|
"""Retract the decoding requests when there is not enough memory."""
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ class TestEAGLEEngine(CustomTestCase):
|
|||||||
if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
|
if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
|
||||||
self.assertGreater(acc_length, 3.6)
|
self.assertGreater(acc_length, 3.6)
|
||||||
else:
|
else:
|
||||||
self.assertGreater(acc_length, 2.6)
|
self.assertGreater(acc_length, 2.5)
|
||||||
|
|
||||||
|
|
||||||
class TestEAGLEEngineTokenMap(TestEAGLEEngine):
|
class TestEAGLEEngineTokenMap(TestEAGLEEngine):
|
||||||
|
|||||||
Reference in New Issue
Block a user