diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index ce78f87b6..f320933f7 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -114,7 +114,6 @@ class Envs: # Test & Debug SGLANG_IS_IN_CI = EnvBool(False) SGLANG_IS_IN_CI_AMD = EnvBool(False) - SGLANG_TEST_RETRACT = EnvBool(False) SGLANG_SET_CPU_AFFINITY = EnvBool(False) SGLANG_PROFILE_WITH_STACK = EnvBool(True) SGLANG_RECORD_STEP_TIME = EnvBool(False) @@ -128,6 +127,11 @@ class Envs: SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial") SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp") + # Scheduler: memory leak test + SGLANG_TEST_RETRACT = EnvBool(False) + SGLANG_TEST_RETRACT_INTERVAL = EnvInt(3) + SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK = EnvBool(False) + # Scheduler: new token ratio hyperparameters SGLANG_INIT_NEW_TOKEN_RATIO = EnvFloat(0.7) SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR = EnvFloat(0.14) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 682213e68..93e11424d 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -885,7 +885,6 @@ class Req: self.temp_input_top_logprobs_idx = None self.extend_logprob_start_len = 0 self.is_chunked = 0 - self.req_pool_idx = None self.mamba_pool_idx = None self.already_computed = 0 @@ -1482,7 +1481,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): new_estimate_ratio = ( total_decoded_tokens + envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs) - ) / total_max_new_tokens + ) / ( + total_max_new_tokens + 1 + ) # avoid zero division new_estimate_ratio = min(1.0, new_estimate_ratio) return retracted_reqs, new_estimate_ratio, [] @@ -1780,6 +1781,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Only contain fields that will be used by process_batch_result return ScheduleBatch( reqs=self.reqs, + req_to_token_pool=self.req_to_token_pool, + req_pool_indices=self.req_pool_indices, model_config=self.model_config, forward_mode=self.forward_mode, out_cache_loc=self.out_cache_loc, diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 288984bb8..9a43121e7 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -569,7 +569,8 @@ class PrefillAdder: return self.add_one_req_ignore_eos(req, has_chunked_req) total_tokens = req.extend_input_len + min( - req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS + max(req.sampling_params.max_new_tokens - len(req.output_ids), 0), + CLIP_MAX_NEW_TOKENS, ) # adjusting the input_tokens based on host_hit_length and page_size diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 917d33e87..ed1fe91e9 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -194,7 +194,8 @@ from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) # Test retract decode for debugging purposes -TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT") +TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get() +TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get() GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300)) @@ -1017,6 +1018,9 @@ class Scheduler( self.launch_batch_sample_if_needed(batch_result) self.last_batch = batch + if envs.SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK.get(): + self._check_runtime_mem_leak() + def recv_requests(self) -> List[Req]: """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.""" @@ -1833,7 +1837,7 @@ class Scheduler( # Check if decode out of memory if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or ( - TEST_RETRACT and batch.batch_size() > 10 + TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0 ): old_ratio = self.new_token_ratio retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode( diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index f63fa8179..e06fac95a 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -77,15 +77,28 @@ class SchedulerOutputProcessorMixin: logprob_pt = 0 for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): - if req.is_retracted: + if self.enable_overlap and req.is_retracted and len(req.output_ids) > 0: + req_idx = batch.req_pool_indices[i] + seq_len = len(req.origin_input_ids) + len(req.output_ids) + pos = batch.req_to_token_pool.req_to_token[req_idx][ + seq_len - 1 : seq_len + ] + self.token_to_kv_pool_allocator.free(pos) continue - if self.is_mixed_chunk and self.enable_overlap and req.finished(): + if ( + self.is_mixed_chunk + and self.enable_overlap + and (req.finished() or req.is_retracted) + ): # Free the one delayed token for the mixed decode batch j = len(batch.out_cache_loc) - len(batch.reqs) + i self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1]) continue + if req.is_retracted: + continue + if req.is_chunked <= 0: # req output_ids are set here req.output_ids.append(next_token_id) @@ -269,10 +282,8 @@ class SchedulerOutputProcessorMixin: # We should ignore using next_token_ids for spec decoding cases. for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): req: Req - if req.is_retracted: - continue - if self.enable_overlap and req.finished(): + if self.enable_overlap and (req.finished() or req.is_retracted): indices_to_free = None if batch.spec_algorithm.is_eagle(): from sglang.srt.speculative.eagle_info import EagleDraftInput @@ -301,6 +312,9 @@ class SchedulerOutputProcessorMixin: self.token_to_kv_pool_allocator.free(indices_to_free) continue + if req.is_retracted: + continue + new_accepted_len = 1 if batch.spec_algorithm.is_none(): req.output_ids.append(next_token_id) diff --git a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py index 20a57aa83..ae6f84574 100644 --- a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py +++ b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py @@ -4,6 +4,7 @@ import time from typing import TYPE_CHECKING from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache @@ -65,6 +66,58 @@ class SchedulerRuntimeCheckerMixin: token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n" return memory_leak, token_msg + def _check_runtime_mem_leak(self: Scheduler): + current_batch: ScheduleBatch = self.last_batch + + if current_batch is None: + return + + _, _, available_size, evictable_size = self._get_token_info() + protected_size = self.tree_cache.protected_size() + + extend_size = 0 + for i, req in enumerate(current_batch.reqs): + seq_len = len(req.origin_input_ids) + len(req.output_ids) + fill_len = len(req.fill_ids) if req.fill_ids is not None else 0 + prefix_len = ( + len(req.prefix_indices) if req.prefix_indices is not None else 0 + ) + + if current_batch.forward_mode.is_decode(): + if req.finished(): + unreleased_len = 1 + else: + unreleased_len = seq_len - prefix_len + else: + unreleased_len = fill_len - prefix_len + + extend_size += unreleased_len + + if ( + current_batch.forward_mode.is_extend() + and self.running_batch is not None + and not self.running_batch.is_empty() + and self.running_batch.forward_mode.is_decode() + ): + for i, req in enumerate(self.running_batch.reqs): + seq_len = len(req.origin_input_ids) + len(req.output_ids) + prefix_len = ( + len(req.prefix_indices) if req.prefix_indices is not None else 0 + ) + + if req.finished(): + unreleased_len = 0 + else: + unreleased_len = seq_len - prefix_len - 1 + + extend_size += unreleased_len + + total_tokens = available_size + evictable_size + protected_size + extend_size + + assert ( + total_tokens == self.max_total_num_tokens + ), f"Mem Leak Detected! {total_tokens=} vs {self.max_total_num_tokens=}" + def _check_req_pool(self: Scheduler): if self.disaggregation_mode == DisaggregationMode.DECODE: req_total_size = ( diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index bb308f077..533b5b9cf 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -32,6 +32,8 @@ class ChunkCache(BasePrefixCache): else: self.device = torch.device("cpu") + self.protected_size_ = 0 + # NOTE (csy): this is to determine if a cache has prefix matching feature. # Chunk cache always return True to indicate no prefix matching. # TODO (csy): Using a prefix cache trait to replace this @@ -57,11 +59,13 @@ class ChunkCache(BasePrefixCache): ] self.req_to_token_pool.free(req.req_pool_idx) self.token_to_kv_pool_allocator.free(kv_indices) + self.protected_size_ -= len(req.prefix_indices) def cache_unfinished_req(self, req: Req, chunked=False): kv_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, : len(req.fill_ids) ] + self.protected_size_ += len(kv_indices) - len(req.prefix_indices) # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True) @@ -75,6 +79,9 @@ class ChunkCache(BasePrefixCache): def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None): return 0 + def protected_size(self): + return self.protected_size_ + def pretty_print(self): return "" diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 845e22ee6..f1b30e6ae 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -112,7 +112,7 @@ suites = { TestFile("test_reasoning_parser.py", 5), TestFile("test_regex_constrained.py", 64), TestFile("test_request_queue_validation.py", 30), - TestFile("test_retract_decode.py", 54), + TestFile("test_retract_decode.py", 90), TestFile("test_score_api.py", 310), TestFile("test_server_args.py", 1), TestFile("test_skip_tokenizer_init.py", 117), diff --git a/test/srt/test_retract_decode.py b/test/srt/test_retract_decode.py index 92f5ab915..8dc22ccac 100644 --- a/test/srt/test_retract_decode.py +++ b/test/srt/test_retract_decode.py @@ -1,7 +1,8 @@ -import os +import time import unittest from types import SimpleNamespace +from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( @@ -16,13 +17,12 @@ from sglang.test.test_utils import ( class TestRetractDecode(CustomTestCase): @classmethod def setUpClass(cls): - os.environ["SGLANG_TEST_RETRACT"] = "1" - cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH - ) + with envs.SGLANG_TEST_RETRACT.override(True): + cls.process = popen_launch_server( + cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) @classmethod def tearDownClass(cls): @@ -39,22 +39,43 @@ class TestRetractDecode(CustomTestCase): metrics = run_eval(args) self.assertGreaterEqual(metrics["score"], 0.65) + time.sleep(1) # wait for mem check + + assert self.process.poll() is None, "Server crashed during test" class TestRetractDecodeChunkCache(CustomTestCase): @classmethod def setUpClass(cls): - os.environ["SGLANG_TEST_RETRACT"] = "1" - cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--disable-radix-cache", "--chunked-prefill-size", 128], + with envs.SGLANG_TEST_RETRACT.override(True): + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--disable-radix-cache", "--chunked-prefill-size", 128], + ) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, ) + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], 0.65) + time.sleep(1) # wait for mem check + + assert self.process.poll() is None, "Server crashed during test" + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + if __name__ == "__main__": unittest.main()