Fix memory leak during abort (#2238)
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -50,7 +50,7 @@ jobs:
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 0 --range-end 5
+          python3 run_suite.py --suite minimal --range-begin 0 --range-end 6

   unit-test-backend-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -67,7 +67,7 @@ jobs:
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 5 --range-end 14
+          python3 run_suite.py --suite minimal --range-begin 6 --range-end 15

   unit-test-backend-part-3:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -84,7 +84,7 @@ jobs:
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 14 --range-end 23
+          python3 run_suite.py --suite minimal --range-begin 15 --range-end 24

   unit-test-backend-part-4:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -101,7 +101,7 @@ jobs:
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 23
+          python3 run_suite.py --suite minimal --range-begin 24

   unit-test-backend-2-gpu-part-1:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
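
Note: all four shard boundaries shift by one because test_abort.py joins the minimal suite (see the run_suite.py hunk below) and every test after it moves up one index. A rough sketch of what range-based sharding amounts to; the function and variable names here are illustrative, not run_suite.py internals:

    # Simplified sketch of range-based suite sharding; names are illustrative.
    def select_shard(suite_files, range_begin=0, range_end=None):
        # Each CI job runs one contiguous slice of the ordered suite.
        end = len(suite_files) if range_end is None else range_end
        return suite_files[range_begin:end]

    suite = [f"test_{i}.py" for i in range(30)]
    part_1 = select_shard(suite, 0, 6)  # --range-begin 0 --range-end 6
    part_4 = select_shard(suite, 24)    # --range-begin 24 (runs to the end)
    assert part_1 == suite[:6] and part_4 == suite[24:]
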
@@ -231,6 +231,7 @@ class Req:
         self.tokenizer = None
         self.finished_reason = None
         self.stream = False
+        self.to_abort = False

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -368,6 +369,10 @@ class Req:
         if self.finished():
             return

+        if self.to_abort:
+            self.finished_reason = FINISH_ABORT()
+            return
+
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
             self.finished_reason = FINISH_LENGTH(
                 length=self.sampling_params.max_new_tokens
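
Note: this is the heart of the fix, a deferred-abort flag. The scheduler no longer finishes an aborted request itself; it merely sets to_abort, and check_finished, which already runs for every request on every step, converts the flag into FINISH_ABORT so the request leaves through the same exit path as natural completion. A minimal self-contained sketch of the pattern, with toy stand-ins rather than the real sglang classes:

    # Toy stand-ins for sglang's Req / FINISH_ABORT, just to show the pattern.
    class FinishAbort:
        pass

    class MiniReq:
        def __init__(self):
            self.finished_reason = None
            self.to_abort = False
            self.output_ids = []
            self.max_new_tokens = 8

        def finished(self):
            return self.finished_reason is not None

        def check_finished(self):
            if self.finished():
                return
            # Abort is honored at the same point where natural termination
            # is detected, so every request leaves through one exit path.
            if self.to_abort:
                self.finished_reason = FinishAbort()
                return
            if len(self.output_ids) >= self.max_new_tokens:
                self.finished_reason = "length"

    req = MiniReq()
    req.to_abort = True   # what Scheduler.abort_request now does
    req.check_finished()  # what the scheduler already does every step
    assert isinstance(req.finished_reason, FinishAbort)
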
@@ -579,6 +579,8 @@ class Scheduler:
                 "Image request length is longer than the KV cache pool size or "
                 "the max context length aborting because you cannot truncate the image embeds"
             )
+            req.image_inputs = None
+            req.origin_input_ids = [0]
             req.sampling_params.max_new_tokens = 0
             self.waiting_queue.append(req)
             return
@@ -1350,13 +1352,15 @@ class Scheduler:

         if to_del is not None:
             del self.waiting_queue[to_del]
+            logger.debug(f"Abort queued request. {req.rid=}")
+            return

         # Delete requests in the running batch
         if self.running_batch:
             for req in self.running_batch.reqs:
                 if req.rid == recv_req.rid and not req.finished():
-                    req.finished_reason = FINISH_ABORT()
-                    self.tree_cache.cache_finished_req(req)
+                    logger.debug(f"Abort running request. {req.rid=}")
+                    req.to_abort = True
                     break

     def update_weights(self, recv_req: UpdateWeightReqInput):
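
Note: before this change, abort_request finished a running request in place, setting FINISH_ABORT and calling self.tree_cache.cache_finished_req(req) while the request might still be part of an in-flight batch; finalizing outside the scheduler's normal completion path is what leaked memory. With the to_abort flag, cache_finished_req runs exactly once, from the usual completion path. A toy illustration of the exactly-once invariant (illustrative code, not sglang's cache):

    # Toy cache where cache_finished_req must run exactly once per request:
    # releasing twice corrupts accounting, releasing zero times leaks.
    class ToyTreeCache:
        def __init__(self):
            self.live_pages = {}

        def alloc(self, rid, n_pages):
            self.live_pages[rid] = n_pages

        def cache_finished_req(self, rid):
            assert rid in self.live_pages, "double finalization"
            del self.live_pages[rid]

    cache = ToyTreeCache()
    cache.alloc("r1", 4)

    to_abort = True              # the abort handler only marks the request ...
    if to_abort:                 # ... and the one completion path frees it
        cache.cache_finished_req("r1")

    assert not cache.live_pages  # nothing leaked, nothing double-freed
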
@@ -677,8 +677,14 @@ def run_and_check_memory_leak(
     enable_mixed_chunk,
     disable_overlap,
     chunked_prefill_size,
+    assert_has_abort,
 ):
-    other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
+    other_args = [
+        "--chunked-prefill-size",
+        str(chunked_prefill_size),
+        "--log-level",
+        "debug",
+    ]
     if disable_radix_cache:
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
@@ -723,14 +729,19 @@ def run_and_check_memory_leak(
     # Assert success
     has_new_server = False
     has_leak = False
+    has_abort = False
     for line in output_lines:
         if "The server is fired" in line:
             has_new_server = True
         if "leak" in line:
             has_leak = True
+        if "Abort" in line:
+            has_abort = True

     assert has_new_server
     assert not has_leak
+    if assert_has_abort:
+        assert has_abort


 def run_mmlu_test(
@@ -761,6 +772,7 @@ def run_mmlu_test(
         enable_mixed_chunk,
         disable_overlap,
         chunked_prefill_size,
+        assert_has_abort=False,
     )


@@ -800,4 +812,5 @@ def run_mulit_request_test(
         enable_mixed_chunk,
         enable_overlap,
         chunked_prefill_size,
+        assert_has_abort=False,
     )
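
Note: run_and_check_memory_leak decides pass/fail by scanning the server's captured output for marker substrings, which is why --log-level debug is now always passed: the new logger.debug("Abort ...") lines must appear in the output for the "Abort" scan to match. A condensed sketch of the checking logic above (the real helper also launches the server and drives the workload); the startup banner shown is assumed, the diff confirms only its "The server is fired" prefix:

    # Condensed from the checking loop in run_and_check_memory_leak.
    def check_server_output(output_lines, assert_has_abort=False):
        has_new_server = has_leak = has_abort = False
        for line in output_lines:
            if "The server is fired" in line:
                has_new_server = True
            if "leak" in line:
                has_leak = True
            if "Abort" in line:
                has_abort = True
        assert has_new_server, "server never started"
        assert not has_leak, "memory leak detected"
        if assert_has_abort:
            assert has_abort, "expected at least one aborted request"

    check_server_output(
        [
            "The server is fired up and ready to roll!",
            "Abort running request. req.rid='abc123'",
        ],
        assert_has_abort=True,
    )
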
@@ -10,6 +10,7 @@ suites = {
         "models/test_lora.py",
         "models/test_reward_models.py",
         "sampling/penaltylib",
+        "test_abort.py",
         "test_chunked_prefill.py",
         "test_double_sparsity.py",
         "test_embedding_openai_server.py",
--- /dev/null
+++ b/test/srt/test_abort.py
@@ -0,0 +1,54 @@
+import multiprocessing
+import time
+import unittest
+from concurrent.futures import ThreadPoolExecutor
+
+import requests
+
+from sglang.test.test_utils import run_and_check_memory_leak
+
+
+class TestAbort(unittest.TestCase):
+    def workload_func(self, base_url, model):
+        def process_func():
+            def run_one(_):
+                prompt = """
+                System: You are a helpful assistant.
+                User: What is the capital of France?
+                Assistant: The capital of France is
+                """
+
+                response = requests.post(
+                    f"{base_url}/generate",
+                    json={
+                        "text": prompt,
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 2048,
+                        },
+                    },
+                )
+                ret = response.json()
+
+            with ThreadPoolExecutor(16) as executor:
+                list(executor.map(run_one, list(range(16))))
+
+        p = multiprocessing.Process(target=process_func)
+        p.start()
+        time.sleep(0.5)
+        p.terminate()
+        time.sleep(10)
+
+    def test_memory_leak(self):
+        run_and_check_memory_leak(
+            self.workload_func,
+            disable_radix_cache=False,
+            enable_mixed_chunk=False,
+            disable_overlap=False,
+            chunked_prefill_size=8192,
+            assert_has_abort=True,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
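
Note on the test's design: the workload runs in a child process that fires 16 concurrent /generate calls with max_new_tokens=2048, and p.terminate() kills that process about half a second in, so the server sees all 16 connections drop mid-generation and must abort the in-flight requests. The trailing time.sleep(10) gives the scheduler time to process those aborts before run_and_check_memory_leak scans the output for the "leak" and "Abort" markers.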