From c555ce2ca20cd8a2fc87a0e048c39c181614388e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 25 Oct 2024 10:24:44 -0700 Subject: [PATCH] Revert "Fix memory leak when doing chunked prefill" (#1797) --- python/sglang/global_config.py | 12 +- python/sglang/srt/managers/schedule_batch.py | 7 +- python/sglang/srt/managers/schedule_policy.py | 25 ++-- python/sglang/srt/managers/scheduler.py | 95 +++++++++------ python/sglang/test/test_utils.py | 1 - test/srt/test_radix_attention.py | 112 ------------------ 6 files changed, 69 insertions(+), 183 deletions(-) delete mode 100644 test/srt/test_radix_attention.py diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index e84526698..5e7290edc 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -15,7 +15,7 @@ class GlobalConfig: # Runtime constants: New generation token ratio estimation self.init_new_token_ratio = 0.7 - self.min_new_token_ratio = 0.1 + self.base_min_new_token_ratio = 0.1 self.new_token_ratio_decay = 0.001 # Runtime constants: others @@ -32,15 +32,5 @@ class GlobalConfig: self.enable_precache_with_tracing = True self.enable_parallel_encoding = True - def adjust_new_token_ratio(self, schedule_conservativeness=1): - assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness" - min_new_token_ratio = min( - self.min_new_token_ratio * schedule_conservativeness, - 1.0, - ) - init_new_token_ratio = max(self.init_new_token_ratio, min_new_token_ratio) - - return min_new_token_ratio, init_new_token_ratio - global_config = GlobalConfig() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 39fc1e558..fcd06d8cc 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -222,7 +222,7 @@ class Req: self.prefix_indices = [] self.extend_input_len = 0 self.last_node = None - self.is_being_chunked = False + self.is_inflight_req = 0 # Logprobs (arguments) self.return_logprob = False @@ -906,14 +906,15 @@ class ScheduleBatch: def filter_batch( self, - being_chunked_req: Optional[Req] = None, + current_inflight_req: Optional[Req] = None, keep_indices: Optional[List[int]] = None, ): if keep_indices is None: keep_indices = [ i for i in range(len(self.reqs)) - if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req + if not self.reqs[i].finished() + and self.reqs[i] is not current_inflight_req ] if keep_indices is None or len(keep_indices) == 0: diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index a5362ff7c..45c9be37a 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -136,7 +136,7 @@ class PrefillAdder: self.req_states = None self.can_run_list = [] - self.new_chunked_req = None + self.new_inflight_req = None self.log_hit_tokens = 0 self.log_input_tokens = 0 @@ -176,7 +176,7 @@ class PrefillAdder: self.log_hit_tokens += prefix_len self.log_input_tokens += extend_input_len - def add_being_chunked_req(self, req: Req): + def add_inflight_req(self, req: Req): truncated = req.extend_input_len > self.rem_chunk_tokens req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len] @@ -192,13 +192,8 @@ class PrefillAdder: ), ) - if truncated: - # Continue to chunk the request - assert req.is_being_chunked - self.new_chunked_req = req - else: - # Release the 
being chunked status - req.is_being_chunked = False + # Return if chunked prefill not finished + return req if truncated else None @contextmanager def _lock_node(self, last_node: TreeNode): @@ -267,14 +262,11 @@ class PrefillAdder: ) else: # Chunked prefill - assert self.new_chunked_req is None - trunc_len = self.rem_chunk_tokens req.extend_input_len = trunc_len - req.is_being_chunked = True req.fill_ids = req.fill_ids[:trunc_len] self.can_run_list.append(req) - self.new_chunked_req = req + self.new_inflight_req = req self._prefill_one_req(0, trunc_len, 0) return self.budget_state() @@ -313,18 +305,15 @@ class PrefillAdder: min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS), ) else: + # Chunked prefill trunc_len = self.rem_chunk_tokens if trunc_len == 0: return AddReqResult.OTHER - # Chunked prefill - assert self.new_chunked_req is None - req.extend_input_len = trunc_len req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] - req.is_being_chunked = True self.can_run_list.append(req) - self.new_chunked_req = req + self.new_inflight_req = req self.tree_cache.inc_lock_ref(req.last_node) self._prefill_one_req(prefix_len, trunc_len, 0) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 76e3be073..e9bf7be8e 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -219,28 +219,35 @@ class Scheduler: # Init chunked prefill self.chunked_prefill_size = server_args.chunked_prefill_size - self.being_chunked_req = None + self.current_inflight_req = None self.is_mixed_chunk = ( self.chunked_prefill_size is not None and server_args.enable_mixed_chunk ) # Init the FSM cache for constrained generation - self.regex_fsm_cache = FSMCache( - server_args.tokenizer_path, - { - "tokenizer_mode": server_args.tokenizer_mode, - "trust_remote_code": server_args.trust_remote_code, - }, - skip_tokenizer_init=server_args.skip_tokenizer_init, - constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern, - ) + if not server_args.skip_tokenizer_init: + self.regex_fsm_cache = FSMCache( + server_args.tokenizer_path, + { + "tokenizer_mode": server_args.tokenizer_mode, + "trust_remote_code": server_args.trust_remote_code, + }, + skip_tokenizer_init=server_args.skip_tokenizer_init, + constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern, + ) self.jump_forward_cache = JumpForwardCache() # Init new token estimation - self.min_new_token_ratio, self.init_new_token_ratio = ( - global_config.adjust_new_token_ratio(server_args.schedule_conservativeness) + assert ( + server_args.schedule_conservativeness >= 0 + ), "Invalid schedule_conservativeness" + self.min_new_token_ratio = min( + global_config.base_min_new_token_ratio + * server_args.schedule_conservativeness, + 1.0, ) - self.new_token_ratio = self.init_new_token_ratio + self.new_token_ratio = self.min_new_token_ratio + self.new_token_ratio_decay = global_config.new_token_ratio_decay self.batch_is_full = False # Init profiler @@ -287,7 +294,7 @@ class Scheduler: self.process_batch_result(batch, result) else: self.check_memory() - self.new_token_ratio = self.init_new_token_ratio + self.new_token_ratio = global_config.init_new_token_ratio self.last_batch = batch @@ -314,7 +321,7 @@ class Scheduler: self.process_batch_result(tmp_batch, tmp_result) elif batch is None: self.check_memory() - self.new_token_ratio = self.init_new_token_ratio + self.new_token_ratio = global_config.init_new_token_ratio self.last_batch = 
batch @@ -492,18 +499,20 @@ class Scheduler: ) exit(1) if crash_on_warning else None - def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: + def get_next_batch_to_run(self): # Merge the prefill batch into the running batch if ( self.last_batch and not self.last_batch.forward_mode.is_decode() and not self.last_batch.is_empty() ): - if self.being_chunked_req: - self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req) - self.tree_cache.cache_unfinished_req(self.being_chunked_req) - # Being chunked request keeps its rid but will get a new req_pool_idx. - self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx) + if self.current_inflight_req: + self.last_batch.filter_batch( + current_inflight_req=self.current_inflight_req + ) + self.tree_cache.cache_unfinished_req(self.current_inflight_req) + # Inflight request keeps its rid but will get a new req_pool_idx. + self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx) self.batch_is_full = False if not self.last_batch.is_empty(): if self.running_batch is None: @@ -534,7 +543,7 @@ class Scheduler: # Handle the cases where prefill is not allowed if ( self.batch_is_full or len(self.waiting_queue) == 0 - ) and self.being_chunked_req is None: + ) and self.current_inflight_req is None: return None running_bs = len(self.running_batch.reqs) if self.running_batch else 0 @@ -557,6 +566,15 @@ class Scheduler: num_mixed_running, ) + has_inflight = self.current_inflight_req is not None + if has_inflight: + self.current_inflight_req.init_next_round_input( + None if prefix_computed else self.tree_cache + ) + self.current_inflight_req = adder.add_inflight_req( + self.current_inflight_req + ) + if self.lora_paths: lora_set = ( set([req.lora_path for req in self.running_batch.reqs]) @@ -564,13 +582,6 @@ class Scheduler: else set([]) ) - # NOTE: if there is request being chunked, we always add it first - has_being_chunked = self.being_chunked_req is not None - if has_being_chunked: - # NOTE: the prefix_indices of being-chunked prefill should align with the last prefill result - self.being_chunked_req.init_next_round_input() - adder.add_being_chunked_req(self.being_chunked_req) - # Get requests from the waiting queue to a new prefill batch for req in self.waiting_queue: if ( @@ -604,8 +615,12 @@ class Scheduler: x for x in self.waiting_queue if x not in set(can_run_list) ] - # Update new round being chunked request - self.being_chunked_req = adder.new_chunked_req + if adder.new_inflight_req is not None: + assert self.current_inflight_req is None + self.current_inflight_req = adder.new_inflight_req + + if self.current_inflight_req: + self.current_inflight_req.is_inflight_req += 1 # Print stats if self.tp_rank == 0: @@ -634,7 +649,7 @@ class Scheduler: f"#cached-token: {adder.log_hit_tokens}, " f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"#queue-req: {len(self.waiting_queue) + has_being_chunked}" + f"#queue-req: {len(self.waiting_queue) + has_inflight}" ) else: logger.info( @@ -645,7 +660,7 @@ class Scheduler: f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"#running-req: {running_bs}, " - f"#queue-req: {len(self.waiting_queue) + has_being_chunked}" + f"#queue-req: {len(self.waiting_queue) + has_inflight}" ) # Create a new batch @@ -694,7 +709,7 @@ class Scheduler: self.waiting_queue.extend(retracted_reqs) else: self.new_token_ratio = max( - self.new_token_ratio - 
global_config.new_token_ratio_decay, + self.new_token_ratio - self.new_token_ratio_decay, self.min_new_token_ratio, ) @@ -768,8 +783,10 @@ class Scheduler: # Check finish conditions logprob_pt = 0 for i, req in enumerate(batch.reqs): - if not req.is_being_chunked: - # Being chunked reqs' prefill is not finished + if req.is_inflight_req > 0: + req.is_inflight_req -= 1 + else: + # Inflight reqs' prefill is not finished req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_ids[i]) req.check_finished() @@ -795,8 +812,10 @@ class Scheduler: # Check finish conditions for i, req in enumerate(batch.reqs): req.embedding = embeddings[i] - if not req.is_being_chunked: - # Being chunked reqs' prefill is not finished + if req.is_inflight_req > 0: + req.is_inflight_req -= 1 + else: + # Inflight reqs' prefill is not finished # dummy output token for embedding models req.output_ids.append(0) req.check_finished() diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index d6b5d41dc..4a5a894c0 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -663,7 +663,6 @@ def run_mmlu_test( chunked_prefill_size=32, ): other_args = ["--chunked-prefill-size", str(chunked_prefill_size)] - other_args += ["--mem-fraction-static", "0.85"] if disable_radix_cache: other_args += ["--disable-radix-cache"] if enable_mixed_chunk: diff --git a/test/srt/test_radix_attention.py b/test/srt/test_radix_attention.py deleted file mode 100644 index 292a7b454..000000000 --- a/test/srt/test_radix_attention.py +++ /dev/null @@ -1,112 +0,0 @@ -import os -import random -import unittest - -import requests - -from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - kill_child_process, - popen_launch_server, -) - - -def gen_radix_tree(num_nodes=400, chunk_len=256): - num0 = num_nodes // 2 - num1 = num_nodes - num0 - nodes = [{"input_ids": [37] * 117, "decode_len": 217}] - for _ in range(num0): - parent = random.choice(nodes) - unique_len = random.randint(0, chunk_len) - decode_len = random.randint(0, chunk_len) - token_id = random.randint(0, 32000) - child = { - "input_ids": parent["input_ids"] + [token_id] * unique_len, - "decode_len": decode_len, - } - nodes.append(child) - - while num1 > 0: - num_branch = random.randint(1, min(num1, 10)) - parent = random.choice(nodes) - for _ in range(num_branch): - unique_len = random.randint(0, chunk_len) - decode_len = random.randint(0, chunk_len) - token_id = random.randint(0, 32000) - child = { - "input_ids": parent["input_ids"] + [token_id] * unique_len, - "decode_len": decode_len, - } - nodes.append(child) - - num1 -= num_branch - - random.shuffle(nodes) - return nodes - - -def run_test(base_url, nodes): - data = { - "input_ids": [node["input_ids"] for node in nodes], - "sampling_params": [ - {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes - ], - } - - res = requests.post(base_url + "/generate", json=data) - assert res.status_code == 200 - - -class TestRadixCacheFCFS(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--chunked-prefill-size", - "128", - "--max-total-tokens", - "20000", - "--schedule-policy", - "fcfs", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_child_process(cls.process.pid) - 
- def test_radix_attention(self): - nodes = gen_radix_tree() - run_test(self.base_url, nodes) - - -class TestRadixCacheLPM(TestRadixCacheFCFS): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--chunked-prefill-size", - "128", - "--max-total-tokens", - "20000", - "--schedule-policy", - "lpm", - ], - ) - - -if __name__ == "__main__": - os.environ["SGLANG_TEST_RETRACT"] = "true" - unittest.main()
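
The control flow this revert restores is easiest to see end to end. Below is a minimal, self-contained sketch of the inflight chunked-prefill lifecycle; only the names add_inflight_req and is_inflight_req mirror the patch, while the Req class and driver loop are simplified stand-ins for sglang's real scheduler, not its actual implementation:

    from typing import List, Optional

    class Req:
        def __init__(self, fill_ids: List[int]) -> None:
            self.fill_ids = fill_ids          # full prompt token ids
            self.prefix_len = 0               # tokens prefilled in earlier chunks
            self.extend_input_len = 0         # tokens scheduled this round
            self.is_inflight_req = 0          # > 0 while a chunk's prefill is partial
            self.output_ids: List[int] = []

    def add_inflight_req(req: Req, rem_chunk_tokens: int) -> Optional[Req]:
        """Schedule the next chunk; return req if more chunks remain, else None."""
        remaining = len(req.fill_ids) - req.prefix_len
        req.extend_input_len = min(remaining, rem_chunk_tokens)
        return req if remaining > rem_chunk_tokens else None

    def process_prefill_result(req: Req, next_token_id: int) -> None:
        """Mirrors process_batch_result_prefill: only the final chunk emits a token."""
        if req.is_inflight_req > 0:
            req.is_inflight_req -= 1          # this chunk finished; prefill did not
        else:
            req.output_ids.append(next_token_id)

    # Drive a 1000-token prompt through 256-token chunks.
    req = Req(fill_ids=list(range(1000)))
    inflight: Optional[Req] = req
    while inflight is not None:
        inflight = add_inflight_req(req, rem_chunk_tokens=256)
        if inflight is not None:
            req.is_inflight_req += 1          # scheduler marks the pending chunk
        req.prefix_len += req.extend_input_len
        process_prefill_result(req, next_token_id=42)  # dummy model output

    print(req.prefix_len, len(req.output_ids))  # 1000 1

Four rounds run (256 + 256 + 256 + 232 tokens); only the last round, where add_inflight_req returns None, appends an output token, matching the is_inflight_req counter check in process_batch_result_prefill.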
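
The new-token-ratio logic that moves back into Scheduler.__init__ is a small schedule worth stating on its own: it estimates what fraction of each running request's remaining max_new_tokens still needs KV-cache room when admitting prefills. The sketch below uses the constants from global_config and an example schedule_conservativeness of 1.0 (a server argument, asserted non-negative in the patch):

    init_new_token_ratio = 0.7        # global_config.init_new_token_ratio
    base_min_new_token_ratio = 0.1    # global_config.base_min_new_token_ratio
    new_token_ratio_decay = 0.001     # global_config.new_token_ratio_decay
    schedule_conservativeness = 1.0   # example value; must be >= 0

    min_new_token_ratio = min(
        base_min_new_token_ratio * schedule_conservativeness, 1.0
    )

    # The ratio resets to the optimistic init value whenever the scheduler
    # goes idle, then decays toward the floor on each step without retraction.
    new_token_ratio = init_new_token_ratio
    for _ in range(700):
        new_token_ratio = max(
            new_token_ratio - new_token_ratio_decay, min_new_token_ratio
        )
    print(round(new_token_ratio, 3))  # 0.1 -- the floor, reached after 600 steps

Raising schedule_conservativeness raises the floor, so the scheduler reserves more decode headroom per request and retracts less often.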
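
In get_next_batch_to_run, the restored merge step filters the inflight request out of the last prefill batch before that batch joins the running batch, caches its partial KV via cache_unfinished_req, and frees its req_pool_idx (it receives a new one next round). A hedged sketch of just the filtering semantics, with a FakeReq standing in for the real Req:

    from typing import List, Optional

    class FakeReq:
        def __init__(self, rid: str, finished: bool = False) -> None:
            self.rid = rid
            self._finished = finished

        def finished(self) -> bool:       # the real Req exposes finished() too
            return self._finished

    def filter_batch(reqs: List[FakeReq],
                     current_inflight_req: Optional[FakeReq] = None) -> List[FakeReq]:
        # Keep requests that are unfinished and are not the inflight request;
        # the inflight request re-enters the next prefill round on its own.
        return [
            r for r in reqs
            if not r.finished() and r is not current_inflight_req
        ]

    inflight = FakeReq("chunked")
    reqs = [FakeReq("a", finished=True), FakeReq("b"), inflight]
    print([r.rid for r in filter_batch(reqs, current_inflight_req=inflight)])
    # ['b']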
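
The deleted test exercised this path by posting batches of token-id prompts with shared prefixes to /generate. A minimal reproduction, assuming a server is already running at DEFAULT_URL_FOR_TEST (e.g., launched with popen_launch_server and --chunked-prefill-size 128, as in the removed file):

    import requests

    from sglang.test.test_utils import DEFAULT_URL_FOR_TEST

    payload = {
        "input_ids": [
            [37] * 117,                    # root prompt from gen_radix_tree
            [37] * 117 + [1234] * 50,      # child sharing the 117-token prefix
        ],
        "sampling_params": [
            {"max_new_tokens": 32, "temperature": 0},
            {"max_new_tokens": 32, "temperature": 0},
        ],
    }
    res = requests.post(DEFAULT_URL_FOR_TEST + "/generate", json=payload)
    assert res.status_code == 200

The shared prefix forces radix-cache hits, and the 167-token child prompt exceeds the 128-token chunk size, so its prefill is split across rounds exactly as in the scheduler changes above.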