[Bugfix] Retract not releasing enough memory when page size > 1 (#9989)
This commit is contained in:
@@ -1371,21 +1371,28 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
# TODO (lianmin): Revisit this. It should be seq_len - 1
|
# TODO (lianmin): Revisit this. It should be seq_len - 1
|
||||||
self.extend_logprob_start_lens.extend([0] * running_bs)
|
self.extend_logprob_start_lens.extend([0] * running_bs)
|
||||||
|
|
||||||
def new_page_count_next_decode(self):
|
def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
|
||||||
page_size = self.token_to_kv_pool_allocator.page_size
|
page_size = self.token_to_kv_pool_allocator.page_size
|
||||||
|
requests = (
|
||||||
|
self.reqs
|
||||||
|
if selected_indices is None
|
||||||
|
else [self.reqs[i] for i in selected_indices]
|
||||||
|
)
|
||||||
if page_size == 1:
|
if page_size == 1:
|
||||||
return len(self.reqs)
|
return len(requests)
|
||||||
# In the decoding phase, the length of a request's KV cache should be
|
# In the decoding phase, the length of a request's KV cache should be
|
||||||
# the total length of the request minus 1
|
# the total length of the request minus 1
|
||||||
return (
|
return (
|
||||||
sum(1 for req in self.reqs if req.seqlen % page_size == 0)
|
sum(1 for req in requests if req.seqlen % page_size == 0)
|
||||||
if self.enable_overlap
|
if self.enable_overlap
|
||||||
else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
|
else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_decode_mem(self, buf_multiplier=1):
|
def check_decode_mem(
|
||||||
|
self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
|
||||||
|
):
|
||||||
num_tokens = (
|
num_tokens = (
|
||||||
self.new_page_count_next_decode()
|
self.new_page_count_next_decode(selected_indices)
|
||||||
* buf_multiplier
|
* buf_multiplier
|
||||||
* self.token_to_kv_pool_allocator.page_size
|
* self.token_to_kv_pool_allocator.page_size
|
||||||
)
|
)
|
||||||
@@ -1411,34 +1418,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_required_tokens(num_reqs: int):
|
|
||||||
headroom_for_spec_decode = 0
|
|
||||||
if server_args.speculative_algorithm:
|
|
||||||
headroom_for_spec_decode += (
|
|
||||||
num_reqs
|
|
||||||
* server_args.speculative_eagle_topk
|
|
||||||
* server_args.speculative_num_steps
|
|
||||||
+ num_reqs * server_args.speculative_num_draft_tokens
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_available_size():
|
|
||||||
if self.is_hybrid:
|
|
||||||
return min(
|
|
||||||
self.token_to_kv_pool_allocator.full_available_size(),
|
|
||||||
self.token_to_kv_pool_allocator.swa_available_size(),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return self.token_to_kv_pool_allocator.available_size()
|
|
||||||
|
|
||||||
retracted_reqs = []
|
retracted_reqs = []
|
||||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||||
first_iter = True
|
first_iter = True
|
||||||
while (
|
while first_iter or (
|
||||||
_get_available_size() < get_required_tokens(len(sorted_indices))
|
not self.check_decode_mem(selected_indices=sorted_indices)
|
||||||
or first_iter
|
|
||||||
):
|
):
|
||||||
if len(sorted_indices) == 1:
|
if len(sorted_indices) == 1:
|
||||||
# Corner case: only one request left
|
# Corner case: only one request left
|
||||||
@@ -1492,10 +1476,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
else:
|
else:
|
||||||
self.tree_cache.dec_lock_ref(req.last_node)
|
self.tree_cache.dec_lock_ref(req.last_node)
|
||||||
|
|
||||||
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
|
||||||
num_tokens = len(sorted_indices) * global_config.retract_decode_steps
|
|
||||||
self._evict_tree_cache_if_needed(num_tokens)
|
|
||||||
|
|
||||||
req.reset_for_retract()
|
req.reset_for_retract()
|
||||||
|
|
||||||
if len(retracted_reqs) == 0:
|
if len(retracted_reqs) == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user