[Bugfix] Retract not releasing enough memory when page size > 1 (#9989)
This commit is contained in:
@@ -1371,21 +1371,28 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# TODO (lianmin): Revisit this. It should be seq_len - 1
|
||||
self.extend_logprob_start_lens.extend([0] * running_bs)
|
||||
|
||||
def new_page_count_next_decode(self):
|
||||
def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
|
||||
page_size = self.token_to_kv_pool_allocator.page_size
|
||||
requests = (
|
||||
self.reqs
|
||||
if selected_indices is None
|
||||
else [self.reqs[i] for i in selected_indices]
|
||||
)
|
||||
if page_size == 1:
|
||||
return len(self.reqs)
|
||||
return len(requests)
|
||||
# In the decoding phase, the length of a request's KV cache should be
|
||||
# the total length of the request minus 1
|
||||
return (
|
||||
sum(1 for req in self.reqs if req.seqlen % page_size == 0)
|
||||
sum(1 for req in requests if req.seqlen % page_size == 0)
|
||||
if self.enable_overlap
|
||||
else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
|
||||
else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
|
||||
)
|
||||
|
||||
def check_decode_mem(self, buf_multiplier=1):
|
||||
def check_decode_mem(
|
||||
self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
|
||||
):
|
||||
num_tokens = (
|
||||
self.new_page_count_next_decode()
|
||||
self.new_page_count_next_decode(selected_indices)
|
||||
* buf_multiplier
|
||||
* self.token_to_kv_pool_allocator.page_size
|
||||
)
|
||||
@@ -1411,34 +1418,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
def get_required_tokens(num_reqs: int):
|
||||
headroom_for_spec_decode = 0
|
||||
if server_args.speculative_algorithm:
|
||||
headroom_for_spec_decode += (
|
||||
num_reqs
|
||||
* server_args.speculative_eagle_topk
|
||||
* server_args.speculative_num_steps
|
||||
+ num_reqs * server_args.speculative_num_draft_tokens
|
||||
)
|
||||
return (
|
||||
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
||||
)
|
||||
|
||||
def _get_available_size():
|
||||
if self.is_hybrid:
|
||||
return min(
|
||||
self.token_to_kv_pool_allocator.full_available_size(),
|
||||
self.token_to_kv_pool_allocator.swa_available_size(),
|
||||
)
|
||||
else:
|
||||
return self.token_to_kv_pool_allocator.available_size()
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
while (
|
||||
_get_available_size() < get_required_tokens(len(sorted_indices))
|
||||
or first_iter
|
||||
while first_iter or (
|
||||
not self.check_decode_mem(selected_indices=sorted_indices)
|
||||
):
|
||||
if len(sorted_indices) == 1:
|
||||
# Corner case: only one request left
|
||||
@@ -1492,10 +1476,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
else:
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
|
||||
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
||||
num_tokens = len(sorted_indices) * global_config.retract_decode_steps
|
||||
self._evict_tree_cache_if_needed(num_tokens)
|
||||
|
||||
req.reset_for_retract()
|
||||
|
||||
if len(retracted_reqs) == 0:
|
||||
|
||||
Reference in New Issue
Block a user