Fix memory leak for chunked prefill 2 (#1858)

Author: Lianmin Zheng
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Date: 2024-10-31 14:51:51 -07:00 (committed via GitHub)
Commit: a2e0424abf (parent: 8ce202a493)
7 changed files with 138 additions and 30 deletions

@@ -221,7 +221,7 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
-        self.is_inflight_req = 0
+        self.is_being_chunked = 0

         # Logprobs (arguments)
         self.return_logprob = False
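The rename also makes the field's role clearer: is_being_chunked is a counter rather than a flag, since a long prompt is prefilled across several consecutive batches and every scheduling round must be matched by one result-processing pass. A minimal sketch of that lifecycle, using a simplified stand-in for Req (only the counter mirrors the real field):

```python
class Req:
    """Simplified stand-in for the scheduler's Req (only the counter is real)."""

    def __init__(self, rid: str):
        self.rid = rid
        # Counter, not a bool: each scheduling round that puts a chunk of
        # this request into a prefill batch adds a tick, and each processed
        # batch result consumes one. 0 means "not being chunked".
        self.is_being_chunked = 0


def schedule_chunk(req: Req) -> None:
    # Scheduler side: one more in-flight chunk for this request.
    req.is_being_chunked += 1


def process_batch_result(req: Req) -> None:
    # Result side: a positive counter marks a partial prefill chunk, so
    # skip the normal finish/streaming handling and consume the tick.
    if req.is_being_chunked > 0:
        req.is_being_chunked -= 1
    else:
        print(f"{req.rid}: prefill complete, handle output normally")


req = Req("example")
schedule_chunk(req)
process_batch_result(req)  # consumes the chunk tick, no output handling
process_batch_result(req)  # now treated as finished prefill
```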
@@ -888,7 +888,7 @@ class ScheduleBatch:
     def filter_batch(
         self,
-        current_inflight_req: Optional[Req] = None,
+        being_chunked_req: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
@@ -896,7 +896,7 @@ class ScheduleBatch:
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i] is not current_inflight_req
+                and self.reqs[i] is not being_chunked_req
             ]

         if keep_indices is None or len(keep_indices) == 0:
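With this change, filter_batch excludes not only finished requests but also the request whose prefill is still being chunked; that request is re-added by the next scheduling round, so keeping it in the batch would process it twice. A condensed, self-contained sketch of the keep-index computation (MiniReq is a toy stand-in for Req):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class MiniReq:
    # Toy stand-in for Req: just enough state for the filter logic.
    rid: str
    done: bool = False

    def finished(self) -> bool:
        return self.done


def compute_keep_indices(
    reqs: List[MiniReq], being_chunked_req: Optional[MiniReq] = None
) -> List[int]:
    # Keep requests that are neither finished nor currently being
    # chunked; the chunked one is re-added by the next prefill round,
    # so leaving it in the batch would run its tokens twice.
    return [
        i
        for i in range(len(reqs))
        if not reqs[i].finished() and reqs[i] is not being_chunked_req
    ]


a, b, c = MiniReq("a"), MiniReq("b", done=True), MiniReq("c")
print(compute_keep_indices([a, b, c], being_chunked_req=c))  # [0]
```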

@@ -231,7 +231,7 @@ class Scheduler:
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
+        self.being_chunked_req = None
         self.is_mixed_chunk = (
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )
@@ -551,13 +551,13 @@ class Scheduler:
             and not self.last_batch.forward_mode.is_decode()
             and not self.last_batch.is_empty()
         ):
-            if self.current_inflight_req:
+            if self.being_chunked_req:
                 self.last_batch.filter_batch(
-                    current_inflight_req=self.current_inflight_req
+                    being_chunked_req=self.being_chunked_req
                 )
-                self.tree_cache.cache_unfinished_req(self.current_inflight_req)
+                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                 # Inflight request keeps its rid but will get a new req_pool_idx.
-                self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
+                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False

         if not self.last_batch.is_empty():
             if self.running_batch is None:
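This hunk shows where the chunked request's resources are reclaimed between rounds: its unfinished tokens go into the radix cache, and its req_to_token_pool slot is freed because the request gets a fresh req_pool_idx on the next round (per the inline comment). If such a slot were never freed, each chunked round would strand one pool entry, which is the shape of the leak the commit title refers to. A toy free-list pool illustrating the idea (ToySlotPool is illustrative, not sglang's implementation):

```python
class ToySlotPool:
    """Toy free-list pool standing in for req_to_token_pool."""

    def __init__(self, size: int):
        self.free_slots = list(range(size))

    def alloc(self) -> int:
        if not self.free_slots:
            raise RuntimeError("pool exhausted -- this is what the leak looks like")
        return self.free_slots.pop()

    def free(self, idx: int) -> None:
        self.free_slots.append(idx)


pool = ToySlotPool(size=2)
# Each chunked-prefill round allocates a fresh slot for the request.
# Comment out pool.free(idx) below and the third iteration raises.
for _ in range(4):
    idx = pool.alloc()
    # ... run one prefill chunk for the request ...
    pool.free(idx)  # release the slot before the next round re-allocates
```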
@@ -588,7 +588,7 @@ class Scheduler:
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
-        ) and self.current_inflight_req is None:
+        ) and self.being_chunked_req is None:
             return None

         running_bs = len(self.running_batch.reqs) if self.running_batch else 0
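The early-return guard is re-keyed as well: scheduling gives up only when prefill is blocked (batch full or queue empty) and there is no half-finished chunked request, so a pending chunk can always be scheduled. The same predicate in isolation (names are illustrative):

```python
from typing import Optional


def should_try_prefill(
    batch_is_full: bool,
    waiting_queue_len: int,
    being_chunked_req: Optional[object],
) -> bool:
    # Give up only when prefill is blocked (full batch or empty queue)
    # AND no half-finished chunked request is pending; a pending chunk
    # must still get scheduled even if nothing else is waiting.
    if (batch_is_full or waiting_queue_len == 0) and being_chunked_req is None:
        return False
    return True


print(should_try_prefill(True, 0, None))      # False: nothing to do
print(should_try_prefill(True, 0, object()))  # True: finish the chunked prefill
```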
@@ -611,13 +611,11 @@ class Scheduler:
             num_mixed_running,
         )

-        has_inflight = self.current_inflight_req is not None
+        has_inflight = self.being_chunked_req is not None
         if has_inflight:
-            self.current_inflight_req.init_next_round_input(
-                None if prefix_computed else self.tree_cache
-            )
-            self.current_inflight_req = adder.add_inflight_req(
-                self.current_inflight_req
+            self.being_chunked_req.init_next_round_input()
+            self.being_chunked_req = adder.add_inflight_req(
+                self.being_chunked_req
             )

         if self.lora_paths:
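Each scheduling round re-initializes the chunked request's next input and hands it back to the batch adder, which either consumes one more chunk or signals that the prefill is done; note that the tree-cache argument to init_next_round_input is dropped here. A rough sketch of the loop shape, with hypothetical chunk bookkeeping (CHUNK, prompt_len, fill_len, and this add_inflight_req are all toys, not sglang's API):

```python
from typing import Optional

CHUNK = 512  # assumed chunked_prefill_size for the example


class ChunkedReq:
    # Hypothetical bookkeeping: total prompt length and how much of it
    # has already been prefilled.
    def __init__(self, prompt_len: int):
        self.prompt_len = prompt_len
        self.fill_len = 0
        self.next_len = 0

    def init_next_round_input(self) -> None:
        # Decide how much of the remaining prompt the next pass prefills.
        self.next_len = min(CHUNK, self.prompt_len - self.fill_len)


def add_inflight_req(req: ChunkedReq) -> Optional[ChunkedReq]:
    # Toy version of the adder's contract: consume one chunk, return the
    # request if more remains, or None once the prompt is fully prefilled.
    req.fill_len += req.next_len
    return req if req.fill_len < req.prompt_len else None


being_chunked_req: Optional[ChunkedReq] = ChunkedReq(prompt_len=1300)
rounds = 0
while being_chunked_req is not None:
    being_chunked_req.init_next_round_input()
    being_chunked_req = add_inflight_req(being_chunked_req)
    rounds += 1
print(rounds)  # 3 rounds: 512 + 512 + 276 tokens
```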
@@ -661,11 +659,11 @@ class Scheduler:
         ]

         if adder.new_inflight_req is not None:
-            assert self.current_inflight_req is None
-            self.current_inflight_req = adder.new_inflight_req
+            assert self.being_chunked_req is None
+            self.being_chunked_req = adder.new_inflight_req

-        if self.current_inflight_req:
-            self.current_inflight_req.is_inflight_req += 1
+        if self.being_chunked_req:
+            self.being_chunked_req.is_being_chunked += 1

         # Print stats
         if self.tp_rank == 0:
@@ -833,8 +831,8 @@ class Scheduler:
             # Check finish conditions
             logprob_pt = 0
             for i, req in enumerate(batch.reqs):
-                if req.is_inflight_req > 0:
-                    req.is_inflight_req -= 1
+                if req.is_being_chunked > 0:
+                    req.is_being_chunked -= 1
                 else:
                     # Inflight reqs' prefill is not finished
                     req.completion_tokens_wo_jump_forward += 1
@@ -860,8 +858,8 @@ class Scheduler:
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
                 req.embedding = embeddings[i]
-                if req.is_inflight_req > 0:
-                    req.is_inflight_req -= 1
+                if req.is_being_chunked > 0:
+                    req.is_being_chunked -= 1
                 else:
                     # Inflight reqs' prefill is not finished
                     # dummy output token for embedding models
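The generation path above (@@ -833) and this embedding path apply the same guard when a batch result comes back: a positive counter marks the entry as a partial prefill chunk whose output must not be treated as final. A compact, self-contained sketch of the embedding-side handling (SimpleNamespace stands in for Req):

```python
from types import SimpleNamespace


def process_embedding_results(reqs, embeddings):
    # Store each result, but only requests with no outstanding chunk
    # ticks are considered finished; a chunked request just consumes
    # one tick and keeps waiting for its remaining prefill rounds.
    finished = []
    for req, emb in zip(reqs, embeddings):
        req.embedding = emb
        if req.is_being_chunked > 0:
            req.is_being_chunked -= 1
        else:
            finished.append(req)
    return finished


reqs = [
    SimpleNamespace(is_being_chunked=0, embedding=None),
    SimpleNamespace(is_being_chunked=1, embedding=None),
]
done = process_embedding_results(reqs, [[0.1], [0.2]])
print(len(done))  # 1 -- the chunked request is still mid-prefill
```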