[Fix] memory leak by overlap + retract (#11981)
Co-authored-by: Liangsheng Yin <lsyincs@gmail.com>
This commit is contained in:
@@ -114,7 +114,6 @@ class Envs:
|
||||
# Test & Debug
|
||||
SGLANG_IS_IN_CI = EnvBool(False)
|
||||
SGLANG_IS_IN_CI_AMD = EnvBool(False)
|
||||
SGLANG_TEST_RETRACT = EnvBool(False)
|
||||
SGLANG_SET_CPU_AFFINITY = EnvBool(False)
|
||||
SGLANG_PROFILE_WITH_STACK = EnvBool(True)
|
||||
SGLANG_RECORD_STEP_TIME = EnvBool(False)
|
||||
@@ -128,6 +127,11 @@ class Envs:
|
||||
SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
|
||||
SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp")
|
||||
|
||||
# Scheduler: memory leak test
|
||||
SGLANG_TEST_RETRACT = EnvBool(False)
|
||||
SGLANG_TEST_RETRACT_INTERVAL = EnvInt(3)
|
||||
SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK = EnvBool(False)
|
||||
|
||||
# Scheduler: new token ratio hyperparameters
|
||||
SGLANG_INIT_NEW_TOKEN_RATIO = EnvFloat(0.7)
|
||||
SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR = EnvFloat(0.14)
|
||||
|
||||
@@ -885,7 +885,6 @@ class Req:
|
||||
self.temp_input_top_logprobs_idx = None
|
||||
self.extend_logprob_start_len = 0
|
||||
self.is_chunked = 0
|
||||
self.req_pool_idx = None
|
||||
self.mamba_pool_idx = None
|
||||
self.already_computed = 0
|
||||
|
||||
@@ -1482,7 +1481,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
new_estimate_ratio = (
|
||||
total_decoded_tokens
|
||||
+ envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
|
||||
) / total_max_new_tokens
|
||||
) / (
|
||||
total_max_new_tokens + 1
|
||||
) # avoid zero division
|
||||
new_estimate_ratio = min(1.0, new_estimate_ratio)
|
||||
|
||||
return retracted_reqs, new_estimate_ratio, []
|
||||
@@ -1780,6 +1781,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# Only contain fields that will be used by process_batch_result
|
||||
return ScheduleBatch(
|
||||
reqs=self.reqs,
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
model_config=self.model_config,
|
||||
forward_mode=self.forward_mode,
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
|
||||
@@ -569,7 +569,8 @@ class PrefillAdder:
|
||||
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
||||
|
||||
total_tokens = req.extend_input_len + min(
|
||||
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
|
||||
max(req.sampling_params.max_new_tokens - len(req.output_ids), 0),
|
||||
CLIP_MAX_NEW_TOKENS,
|
||||
)
|
||||
|
||||
# adjusting the input_tokens based on host_hit_length and page_size
|
||||
|
||||
@@ -194,7 +194,8 @@ from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Test retract decode for debugging purposes
|
||||
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
|
||||
TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
|
||||
TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
|
||||
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
|
||||
|
||||
|
||||
@@ -1017,6 +1018,9 @@ class Scheduler(
|
||||
self.launch_batch_sample_if_needed(batch_result)
|
||||
self.last_batch = batch
|
||||
|
||||
if envs.SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK.get():
|
||||
self._check_runtime_mem_leak()
|
||||
|
||||
def recv_requests(self) -> List[Req]:
|
||||
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
||||
|
||||
@@ -1833,7 +1837,7 @@ class Scheduler(
|
||||
|
||||
# Check if decode out of memory
|
||||
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
|
||||
TEST_RETRACT and batch.batch_size() > 10
|
||||
TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
|
||||
):
|
||||
old_ratio = self.new_token_ratio
|
||||
retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
|
||||
|
||||
@@ -77,15 +77,28 @@ class SchedulerOutputProcessorMixin:
|
||||
logprob_pt = 0
|
||||
|
||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||
if req.is_retracted:
|
||||
if self.enable_overlap and req.is_retracted and len(req.output_ids) > 0:
|
||||
req_idx = batch.req_pool_indices[i]
|
||||
seq_len = len(req.origin_input_ids) + len(req.output_ids)
|
||||
pos = batch.req_to_token_pool.req_to_token[req_idx][
|
||||
seq_len - 1 : seq_len
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free(pos)
|
||||
continue
|
||||
|
||||
if self.is_mixed_chunk and self.enable_overlap and req.finished():
|
||||
if (
|
||||
self.is_mixed_chunk
|
||||
and self.enable_overlap
|
||||
and (req.finished() or req.is_retracted)
|
||||
):
|
||||
# Free the one delayed token for the mixed decode batch
|
||||
j = len(batch.out_cache_loc) - len(batch.reqs) + i
|
||||
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
|
||||
continue
|
||||
|
||||
if req.is_retracted:
|
||||
continue
|
||||
|
||||
if req.is_chunked <= 0:
|
||||
# req output_ids are set here
|
||||
req.output_ids.append(next_token_id)
|
||||
@@ -269,10 +282,8 @@ class SchedulerOutputProcessorMixin:
|
||||
# We should ignore using next_token_ids for spec decoding cases.
|
||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||
req: Req
|
||||
if req.is_retracted:
|
||||
continue
|
||||
|
||||
if self.enable_overlap and req.finished():
|
||||
if self.enable_overlap and (req.finished() or req.is_retracted):
|
||||
indices_to_free = None
|
||||
if batch.spec_algorithm.is_eagle():
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
@@ -301,6 +312,9 @@ class SchedulerOutputProcessorMixin:
|
||||
self.token_to_kv_pool_allocator.free(indices_to_free)
|
||||
continue
|
||||
|
||||
if req.is_retracted:
|
||||
continue
|
||||
|
||||
new_accepted_len = 1
|
||||
if batch.spec_algorithm.is_none():
|
||||
req.output_ids.append(next_token_id)
|
||||
|
||||
@@ -4,6 +4,7 @@ import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
|
||||
@@ -65,6 +66,58 @@ class SchedulerRuntimeCheckerMixin:
|
||||
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
|
||||
return memory_leak, token_msg
|
||||
|
||||
def _check_runtime_mem_leak(self: Scheduler):
|
||||
current_batch: ScheduleBatch = self.last_batch
|
||||
|
||||
if current_batch is None:
|
||||
return
|
||||
|
||||
_, _, available_size, evictable_size = self._get_token_info()
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
|
||||
extend_size = 0
|
||||
for i, req in enumerate(current_batch.reqs):
|
||||
seq_len = len(req.origin_input_ids) + len(req.output_ids)
|
||||
fill_len = len(req.fill_ids) if req.fill_ids is not None else 0
|
||||
prefix_len = (
|
||||
len(req.prefix_indices) if req.prefix_indices is not None else 0
|
||||
)
|
||||
|
||||
if current_batch.forward_mode.is_decode():
|
||||
if req.finished():
|
||||
unreleased_len = 1
|
||||
else:
|
||||
unreleased_len = seq_len - prefix_len
|
||||
else:
|
||||
unreleased_len = fill_len - prefix_len
|
||||
|
||||
extend_size += unreleased_len
|
||||
|
||||
if (
|
||||
current_batch.forward_mode.is_extend()
|
||||
and self.running_batch is not None
|
||||
and not self.running_batch.is_empty()
|
||||
and self.running_batch.forward_mode.is_decode()
|
||||
):
|
||||
for i, req in enumerate(self.running_batch.reqs):
|
||||
seq_len = len(req.origin_input_ids) + len(req.output_ids)
|
||||
prefix_len = (
|
||||
len(req.prefix_indices) if req.prefix_indices is not None else 0
|
||||
)
|
||||
|
||||
if req.finished():
|
||||
unreleased_len = 0
|
||||
else:
|
||||
unreleased_len = seq_len - prefix_len - 1
|
||||
|
||||
extend_size += unreleased_len
|
||||
|
||||
total_tokens = available_size + evictable_size + protected_size + extend_size
|
||||
|
||||
assert (
|
||||
total_tokens == self.max_total_num_tokens
|
||||
), f"Mem Leak Detected! {total_tokens=} vs {self.max_total_num_tokens=}"
|
||||
|
||||
def _check_req_pool(self: Scheduler):
|
||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
req_total_size = (
|
||||
|
||||
@@ -32,6 +32,8 @@ class ChunkCache(BasePrefixCache):
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
self.protected_size_ = 0
|
||||
|
||||
# NOTE (csy): this is to determine if a cache has prefix matching feature.
|
||||
# Chunk cache always return True to indicate no prefix matching.
|
||||
# TODO (csy): Using a prefix cache trait to replace this
|
||||
@@ -57,11 +59,13 @@ class ChunkCache(BasePrefixCache):
|
||||
]
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||
self.protected_size_ -= len(req.prefix_indices)
|
||||
|
||||
def cache_unfinished_req(self, req: Req, chunked=False):
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(req.fill_ids)
|
||||
]
|
||||
self.protected_size_ += len(kv_indices) - len(req.prefix_indices)
|
||||
|
||||
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
||||
req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||
@@ -75,6 +79,9 @@ class ChunkCache(BasePrefixCache):
|
||||
def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
|
||||
return 0
|
||||
|
||||
def protected_size(self):
|
||||
return self.protected_size_
|
||||
|
||||
def pretty_print(self):
|
||||
return ""
|
||||
|
||||
|
||||
@@ -112,7 +112,7 @@ suites = {
|
||||
TestFile("test_reasoning_parser.py", 5),
|
||||
TestFile("test_regex_constrained.py", 64),
|
||||
TestFile("test_request_queue_validation.py", 30),
|
||||
TestFile("test_retract_decode.py", 54),
|
||||
TestFile("test_retract_decode.py", 90),
|
||||
TestFile("test_score_api.py", 310),
|
||||
TestFile("test_server_args.py", 1),
|
||||
TestFile("test_skip_tokenizer_init.py", 117),
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
@@ -16,13 +17,12 @@ from sglang.test.test_utils import (
|
||||
class TestRetractDecode(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
os.environ["SGLANG_TEST_RETRACT"] = "1"
|
||||
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
||||
)
|
||||
with envs.SGLANG_TEST_RETRACT.override(True):
|
||||
cls.process = popen_launch_server(
|
||||
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
@@ -39,22 +39,43 @@ class TestRetractDecode(CustomTestCase):
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreaterEqual(metrics["score"], 0.65)
|
||||
time.sleep(1) # wait for mem check
|
||||
|
||||
assert self.process.poll() is None, "Server crashed during test"
|
||||
|
||||
|
||||
class TestRetractDecodeChunkCache(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
os.environ["SGLANG_TEST_RETRACT"] = "1"
|
||||
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=["--disable-radix-cache", "--chunked-prefill-size", 128],
|
||||
with envs.SGLANG_TEST_RETRACT.override(True):
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=["--disable-radix-cache", "--chunked-prefill-size", 128],
|
||||
)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreaterEqual(metrics["score"], 0.65)
|
||||
time.sleep(1) # wait for mem check
|
||||
|
||||
assert self.process.poll() is None, "Server crashed during test"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user