Revert "Fix memory leak when doing chunked prefill" (#1797)

This commit is contained in:
Lianmin Zheng
2024-10-25 10:24:44 -07:00
committed by GitHub
parent 40900baea7
commit c555ce2ca2
6 changed files with 69 additions and 183 deletions

View File

@@ -136,7 +136,7 @@ class PrefillAdder:
self.req_states = None
self.can_run_list = []
self.new_chunked_req = None
self.new_inflight_req = None
self.log_hit_tokens = 0
self.log_input_tokens = 0
@@ -176,7 +176,7 @@ class PrefillAdder:
self.log_hit_tokens += prefix_len
self.log_input_tokens += extend_input_len
def add_being_chunked_req(self, req: Req):
def add_inflight_req(self, req: Req):
truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -192,13 +192,8 @@ class PrefillAdder:
),
)
if truncated:
# Continue to chunk the request
assert req.is_being_chunked
self.new_chunked_req = req
else:
# Release the being chunked status
req.is_being_chunked = False
# Return if chunked prefill not finished
return req if truncated else None
@contextmanager
def _lock_node(self, last_node: TreeNode):
@@ -267,14 +262,11 @@ class PrefillAdder:
)
else:
# Chunked prefill
assert self.new_chunked_req is None
trunc_len = self.rem_chunk_tokens
req.extend_input_len = trunc_len
req.is_being_chunked = True
req.fill_ids = req.fill_ids[:trunc_len]
self.can_run_list.append(req)
self.new_chunked_req = req
self.new_inflight_req = req
self._prefill_one_req(0, trunc_len, 0)
return self.budget_state()
@@ -313,18 +305,15 @@ class PrefillAdder:
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
)
else:
# Chunked prefill
trunc_len = self.rem_chunk_tokens
if trunc_len == 0:
return AddReqResult.OTHER
# Chunked prefill
assert self.new_chunked_req is None
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
req.is_being_chunked = True
self.can_run_list.append(req)
self.new_chunked_req = req
self.new_inflight_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)