Fix memory leak during abort (#2238)
This commit is contained in:
@@ -231,6 +231,7 @@ class Req:
|
||||
self.tokenizer = None
|
||||
self.finished_reason = None
|
||||
self.stream = False
|
||||
self.to_abort = False
|
||||
|
||||
# For incremental decoding
|
||||
# ----- | --------- read_ids -------|
|
||||
@@ -368,6 +369,10 @@ class Req:
|
||||
if self.finished():
|
||||
return
|
||||
|
||||
if self.to_abort:
|
||||
self.finished_reason = FINISH_ABORT()
|
||||
return
|
||||
|
||||
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
|
||||
self.finished_reason = FINISH_LENGTH(
|
||||
length=self.sampling_params.max_new_tokens
|
||||
|
||||
@@ -579,6 +579,8 @@ class Scheduler:
|
||||
"Image request length is longer than the KV cache pool size or "
|
||||
"the max context length aborting because you cannot truncate the image embeds"
|
||||
)
|
||||
req.image_inputs = None
|
||||
req.origin_input_ids = [0]
|
||||
req.sampling_params.max_new_tokens = 0
|
||||
self.waiting_queue.append(req)
|
||||
return
|
||||
@@ -1350,13 +1352,15 @@ class Scheduler:
|
||||
|
||||
if to_del is not None:
|
||||
del self.waiting_queue[to_del]
|
||||
logger.debug(f"Abort queued request. {req.rid=}")
|
||||
return
|
||||
|
||||
# Delete requests in the running batch
|
||||
if self.running_batch:
|
||||
for req in self.running_batch.reqs:
|
||||
if req.rid == recv_req.rid and not req.finished():
|
||||
req.finished_reason = FINISH_ABORT()
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
logger.debug(f"Abort running request. {req.rid=}")
|
||||
req.to_abort = True
|
||||
break
|
||||
|
||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
||||
|
||||
@@ -677,8 +677,14 @@ def run_and_check_memory_leak(
|
||||
enable_mixed_chunk,
|
||||
disable_overlap,
|
||||
chunked_prefill_size,
|
||||
assert_has_abort,
|
||||
):
|
||||
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
|
||||
other_args = [
|
||||
"--chunked-prefill-size",
|
||||
str(chunked_prefill_size),
|
||||
"--log-level",
|
||||
"debug",
|
||||
]
|
||||
if disable_radix_cache:
|
||||
other_args += ["--disable-radix-cache"]
|
||||
if enable_mixed_chunk:
|
||||
@@ -723,14 +729,19 @@ def run_and_check_memory_leak(
|
||||
# Assert success
|
||||
has_new_server = False
|
||||
has_leak = False
|
||||
has_abort = False
|
||||
for line in output_lines:
|
||||
if "The server is fired" in line:
|
||||
has_new_server = True
|
||||
if "leak" in line:
|
||||
has_leak = True
|
||||
if "Abort" in line:
|
||||
has_abort = True
|
||||
|
||||
assert has_new_server
|
||||
assert not has_leak
|
||||
if assert_has_abort:
|
||||
assert has_abort
|
||||
|
||||
|
||||
def run_mmlu_test(
|
||||
@@ -761,6 +772,7 @@ def run_mmlu_test(
|
||||
enable_mixed_chunk,
|
||||
disable_overlap,
|
||||
chunked_prefill_size,
|
||||
assert_has_abort=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -800,4 +812,5 @@ def run_mulit_request_test(
|
||||
enable_mixed_chunk,
|
||||
enable_overlap,
|
||||
chunked_prefill_size,
|
||||
assert_has_abort=False,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user