Fix memory leak for chunked prefill 2 (#1858)
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
This commit is contained in:
@@ -221,7 +221,7 @@ class Req:
|
||||
self.prefix_indices = []
|
||||
self.extend_input_len = 0
|
||||
self.last_node = None
|
||||
self.is_inflight_req = 0
|
||||
self.is_being_chunked = 0
|
||||
|
||||
# Logprobs (arguments)
|
||||
self.return_logprob = False
|
||||
@@ -888,7 +888,7 @@ class ScheduleBatch:
|
||||
|
||||
def filter_batch(
|
||||
self,
|
||||
current_inflight_req: Optional[Req] = None,
|
||||
being_chunked_req: Optional[Req] = None,
|
||||
keep_indices: Optional[List[int]] = None,
|
||||
):
|
||||
if keep_indices is None:
|
||||
@@ -896,7 +896,7 @@ class ScheduleBatch:
|
||||
i
|
||||
for i in range(len(self.reqs))
|
||||
if not self.reqs[i].finished()
|
||||
and self.reqs[i] is not current_inflight_req
|
||||
and self.reqs[i] is not being_chunked_req
|
||||
]
|
||||
|
||||
if keep_indices is None or len(keep_indices) == 0:
|
||||
|
||||
@@ -231,7 +231,7 @@ class Scheduler:
|
||||
|
||||
# Init chunked prefill
|
||||
self.chunked_prefill_size = server_args.chunked_prefill_size
|
||||
self.current_inflight_req = None
|
||||
self.being_chunked_req = None
|
||||
self.is_mixed_chunk = (
|
||||
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
|
||||
)
|
||||
@@ -551,13 +551,13 @@ class Scheduler:
|
||||
and not self.last_batch.forward_mode.is_decode()
|
||||
and not self.last_batch.is_empty()
|
||||
):
|
||||
if self.current_inflight_req:
|
||||
if self.being_chunked_req:
|
||||
self.last_batch.filter_batch(
|
||||
current_inflight_req=self.current_inflight_req
|
||||
being_chunked_req=self.being_chunked_req
|
||||
)
|
||||
self.tree_cache.cache_unfinished_req(self.current_inflight_req)
|
||||
self.tree_cache.cache_unfinished_req(self.being_chunked_req)
|
||||
# Inflight request keeps its rid but will get a new req_pool_idx.
|
||||
self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
|
||||
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
|
||||
self.batch_is_full = False
|
||||
if not self.last_batch.is_empty():
|
||||
if self.running_batch is None:
|
||||
@@ -588,7 +588,7 @@ class Scheduler:
|
||||
# Handle the cases where prefill is not allowed
|
||||
if (
|
||||
self.batch_is_full or len(self.waiting_queue) == 0
|
||||
) and self.current_inflight_req is None:
|
||||
) and self.being_chunked_req is None:
|
||||
return None
|
||||
|
||||
running_bs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||
@@ -611,13 +611,11 @@ class Scheduler:
|
||||
num_mixed_running,
|
||||
)
|
||||
|
||||
has_inflight = self.current_inflight_req is not None
|
||||
has_inflight = self.being_chunked_req is not None
|
||||
if has_inflight:
|
||||
self.current_inflight_req.init_next_round_input(
|
||||
None if prefix_computed else self.tree_cache
|
||||
)
|
||||
self.current_inflight_req = adder.add_inflight_req(
|
||||
self.current_inflight_req
|
||||
self.being_chunked_req.init_next_round_input()
|
||||
self.being_chunked_req = adder.add_inflight_req(
|
||||
self.being_chunked_req
|
||||
)
|
||||
|
||||
if self.lora_paths:
|
||||
@@ -661,11 +659,11 @@ class Scheduler:
|
||||
]
|
||||
|
||||
if adder.new_inflight_req is not None:
|
||||
assert self.current_inflight_req is None
|
||||
self.current_inflight_req = adder.new_inflight_req
|
||||
assert self.being_chunked_req is None
|
||||
self.being_chunked_req = adder.new_inflight_req
|
||||
|
||||
if self.current_inflight_req:
|
||||
self.current_inflight_req.is_inflight_req += 1
|
||||
if self.being_chunked_req:
|
||||
self.being_chunked_req.is_being_chunked += 1
|
||||
|
||||
# Print stats
|
||||
if self.tp_rank == 0:
|
||||
@@ -833,8 +831,8 @@ class Scheduler:
|
||||
# Check finish conditions
|
||||
logprob_pt = 0
|
||||
for i, req in enumerate(batch.reqs):
|
||||
if req.is_inflight_req > 0:
|
||||
req.is_inflight_req -= 1
|
||||
if req.is_being_chunked > 0:
|
||||
req.is_being_chunked -= 1
|
||||
else:
|
||||
# Inflight reqs' prefill is not finished
|
||||
req.completion_tokens_wo_jump_forward += 1
|
||||
@@ -860,8 +858,8 @@ class Scheduler:
|
||||
# Check finish conditions
|
||||
for i, req in enumerate(batch.reqs):
|
||||
req.embedding = embeddings[i]
|
||||
if req.is_inflight_req > 0:
|
||||
req.is_inflight_req -= 1
|
||||
if req.is_being_chunked > 0:
|
||||
req.is_being_chunked -= 1
|
||||
else:
|
||||
# Inflight reqs' prefill is not finished
|
||||
# dummy output token for embedding models
|
||||
|
||||
Reference in New Issue
Block a user