From a2e0424abfc0d9f382331c813b1d96e0ef39d3e0 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 31 Oct 2024 14:51:51 -0700 Subject: [PATCH] Fix memory leak for chunked prefill 2 (#1858) Co-authored-by: Liangsheng Yin --- .github/workflows/pr-test.yml | 6 +- docs/hyperparameter_tuning.md | 1 - python/sglang/srt/managers/schedule_batch.py | 6 +- python/sglang/srt/managers/scheduler.py | 38 +++---- scripts/killall_sglang.sh | 4 +- test/srt/run_suite.py | 1 + test/srt/test_radix_attention.py | 112 +++++++++++++++++++ 7 files changed, 138 insertions(+), 30 deletions(-) create mode 100644 test/srt/test_radix_attention.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index c1bf8da5b..1d66a5c7b 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -50,7 +50,7 @@ jobs: timeout-minutes: 20 run: | cd test/srt - python3 run_suite.py --suite minimal --range-begin 0 --range-end 5 + python3 run_suite.py --suite minimal --range-begin 0 --range-end 4 unit-test-backend-part-2: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' @@ -67,7 +67,7 @@ jobs: timeout-minutes: 20 run: | cd test/srt - python3 run_suite.py --suite minimal --range-begin 5 --range-end 17 + python3 run_suite.py --suite minimal --range-begin 4 --range-end 14 unit-test-backend-part-3: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' @@ -84,7 +84,7 @@ jobs: timeout-minutes: 20 run: | cd test/srt - python3 run_suite.py --suite minimal --range-begin 17 --range-end 20 + python3 run_suite.py --suite minimal --range-begin 14 --range-end 20 unit-test-backend-part-4: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' diff --git a/docs/hyperparameter_tuning.md b/docs/hyperparameter_tuning.md index 5013a80ab..96eb5b4f8 100644 --- a/docs/hyperparameter_tuning.md +++ b/docs/hyperparameter_tuning.md @@ -1,7 +1,6 @@ # Guide on Hyperparameter Tuning ## Achieving Peak Throughput - Achieving a large batch size is the most important thing for attaining high throughput. When the server is running at full load, look for the following in the log: diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 85ca560a9..24e82993b 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -221,7 +221,7 @@ class Req: self.prefix_indices = [] self.extend_input_len = 0 self.last_node = None - self.is_inflight_req = 0 + self.is_being_chunked = 0 # Logprobs (arguments) self.return_logprob = False @@ -888,7 +888,7 @@ class ScheduleBatch: def filter_batch( self, - current_inflight_req: Optional[Req] = None, + being_chunked_req: Optional[Req] = None, keep_indices: Optional[List[int]] = None, ): if keep_indices is None: @@ -896,7 +896,7 @@ class ScheduleBatch: i for i in range(len(self.reqs)) if not self.reqs[i].finished() - and self.reqs[i] is not current_inflight_req + and self.reqs[i] is not being_chunked_req ] if keep_indices is None or len(keep_indices) == 0: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 7c7780a64..9682b7260 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -231,7 +231,7 @@ class Scheduler: # Init chunked prefill self.chunked_prefill_size = server_args.chunked_prefill_size - self.current_inflight_req = None + self.being_chunked_req = None self.is_mixed_chunk = ( self.chunked_prefill_size is not None and server_args.enable_mixed_chunk ) @@ -551,13 +551,13 @@ class Scheduler: and not self.last_batch.forward_mode.is_decode() and not self.last_batch.is_empty() ): - if self.current_inflight_req: + if self.being_chunked_req: self.last_batch.filter_batch( - current_inflight_req=self.current_inflight_req + being_chunked_req=self.being_chunked_req ) - self.tree_cache.cache_unfinished_req(self.current_inflight_req) + self.tree_cache.cache_unfinished_req(self.being_chunked_req) # Inflight request keeps its rid but will get a new req_pool_idx. - self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx) + self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx) self.batch_is_full = False if not self.last_batch.is_empty(): if self.running_batch is None: @@ -588,7 +588,7 @@ class Scheduler: # Handle the cases where prefill is not allowed if ( self.batch_is_full or len(self.waiting_queue) == 0 - ) and self.current_inflight_req is None: + ) and self.being_chunked_req is None: return None running_bs = len(self.running_batch.reqs) if self.running_batch else 0 @@ -611,13 +611,11 @@ class Scheduler: num_mixed_running, ) - has_inflight = self.current_inflight_req is not None + has_inflight = self.being_chunked_req is not None if has_inflight: - self.current_inflight_req.init_next_round_input( - None if prefix_computed else self.tree_cache - ) - self.current_inflight_req = adder.add_inflight_req( - self.current_inflight_req + self.being_chunked_req.init_next_round_input() + self.being_chunked_req = adder.add_inflight_req( + self.being_chunked_req ) if self.lora_paths: @@ -661,11 +659,11 @@ class Scheduler: ] if adder.new_inflight_req is not None: - assert self.current_inflight_req is None - self.current_inflight_req = adder.new_inflight_req + assert self.being_chunked_req is None + self.being_chunked_req = adder.new_inflight_req - if self.current_inflight_req: - self.current_inflight_req.is_inflight_req += 1 + if self.being_chunked_req: + self.being_chunked_req.is_being_chunked += 1 # Print stats if self.tp_rank == 0: @@ -833,8 +831,8 @@ class Scheduler: # Check finish conditions logprob_pt = 0 for i, req in enumerate(batch.reqs): - if req.is_inflight_req > 0: - req.is_inflight_req -= 1 + if req.is_being_chunked > 0: + req.is_being_chunked -= 1 else: # Inflight reqs' prefill is not finished req.completion_tokens_wo_jump_forward += 1 @@ -860,8 +858,8 @@ class Scheduler: # Check finish conditions for i, req in enumerate(batch.reqs): req.embedding = embeddings[i] - if req.is_inflight_req > 0: - req.is_inflight_req -= 1 + if req.is_being_chunked > 0: + req.is_being_chunked -= 1 else: # Inflight reqs' prefill is not finished # dummy output token for embedding models diff --git a/scripts/killall_sglang.sh b/scripts/killall_sglang.sh index 9a57e3b1d..203da6040 100644 --- a/scripts/killall_sglang.sh +++ b/scripts/killall_sglang.sh @@ -1,6 +1,4 @@ -""" -Kill all SGLang processes and free the GPU memory. -""" +# Kill all SGLang processes and free the GPU memory. kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}') kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 1237df709..f7277f03d 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -19,6 +19,7 @@ suites = { "test_openai_server.py", "test_overlap_schedule.py", "test_pytorch_sampling_backend.py", + "test_radix_attention.py", "test_retract_decode.py", "test_server_args.py", "test_skip_tokenizer_init.py", diff --git a/test/srt/test_radix_attention.py b/test/srt/test_radix_attention.py new file mode 100644 index 000000000..1c20d95eb --- /dev/null +++ b/test/srt/test_radix_attention.py @@ -0,0 +1,112 @@ +import os +import random +import unittest + +import requests + +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + kill_child_process, + popen_launch_server, +) + + +def gen_radix_tree(num_nodes=400, chunk_len=256): + num0 = num_nodes // 2 + num1 = num_nodes - num0 + nodes = [{"input_ids": [37] * 117, "decode_len": 217}] + for _ in range(num0): + parent = random.choice(nodes) + unique_len = random.randint(0, chunk_len) + decode_len = random.randint(0, chunk_len) + token_id = random.randint(0, 32000) + child = { + "input_ids": parent["input_ids"] + [token_id] * unique_len, + "decode_len": decode_len, + } + nodes.append(child) + + while num1 > 0: + num_branch = random.randint(1, min(num1, 10)) + parent = random.choice(nodes) + for _ in range(num_branch): + unique_len = random.randint(0, chunk_len) + decode_len = random.randint(0, chunk_len) + token_id = random.randint(0, 32000) + child = { + "input_ids": parent["input_ids"] + [token_id] * unique_len, + "decode_len": decode_len, + } + nodes.append(child) + + num1 -= num_branch + + random.shuffle(nodes) + return nodes + + +def run_test(base_url, nodes): + data = { + "input_ids": [node["input_ids"] for node in nodes], + "sampling_params": [ + {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes + ], + } + + res = requests.post(base_url + "/generate", json=data) + assert res.status_code == 200 + + +class TestRadixCacheFCFS(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--chunked-prefill-size", + "128", + "--max-total-tokens", + "20000", + "--schedule-policy", + "fcfs", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid, include_self=True) + + def test_radix_attention(self): + nodes = gen_radix_tree() + run_test(self.base_url, nodes) + + +class TestRadixCacheLPM(TestRadixCacheFCFS): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--chunked-prefill-size", + "128", + "--max-total-tokens", + "20000", + "--schedule-policy", + "lpm", + ], + ) + + +if __name__ == "__main__": + os.environ["SGLANG_TEST_RETRACT"] = "true" + unittest.main()