[Fix] memory leak by overlap + retract (#11981)
Co-authored-by: Liangsheng Yin <lsyincs@gmail.com>
This commit is contained in:
@@ -114,7 +114,6 @@ class Envs:
|
|||||||
# Test & Debug
|
# Test & Debug
|
||||||
SGLANG_IS_IN_CI = EnvBool(False)
|
SGLANG_IS_IN_CI = EnvBool(False)
|
||||||
SGLANG_IS_IN_CI_AMD = EnvBool(False)
|
SGLANG_IS_IN_CI_AMD = EnvBool(False)
|
||||||
SGLANG_TEST_RETRACT = EnvBool(False)
|
|
||||||
SGLANG_SET_CPU_AFFINITY = EnvBool(False)
|
SGLANG_SET_CPU_AFFINITY = EnvBool(False)
|
||||||
SGLANG_PROFILE_WITH_STACK = EnvBool(True)
|
SGLANG_PROFILE_WITH_STACK = EnvBool(True)
|
||||||
SGLANG_RECORD_STEP_TIME = EnvBool(False)
|
SGLANG_RECORD_STEP_TIME = EnvBool(False)
|
||||||
@@ -128,6 +127,11 @@ class Envs:
|
|||||||
SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
|
SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
|
||||||
SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp")
|
SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp")
|
||||||
|
|
||||||
|
# Scheduler: memory leak test
|
||||||
|
SGLANG_TEST_RETRACT = EnvBool(False)
|
||||||
|
SGLANG_TEST_RETRACT_INTERVAL = EnvInt(3)
|
||||||
|
SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK = EnvBool(False)
|
||||||
|
|
||||||
# Scheduler: new token ratio hyperparameters
|
# Scheduler: new token ratio hyperparameters
|
||||||
SGLANG_INIT_NEW_TOKEN_RATIO = EnvFloat(0.7)
|
SGLANG_INIT_NEW_TOKEN_RATIO = EnvFloat(0.7)
|
||||||
SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR = EnvFloat(0.14)
|
SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR = EnvFloat(0.14)
|
||||||
|
|||||||
@@ -885,7 +885,6 @@ class Req:
|
|||||||
self.temp_input_top_logprobs_idx = None
|
self.temp_input_top_logprobs_idx = None
|
||||||
self.extend_logprob_start_len = 0
|
self.extend_logprob_start_len = 0
|
||||||
self.is_chunked = 0
|
self.is_chunked = 0
|
||||||
self.req_pool_idx = None
|
|
||||||
self.mamba_pool_idx = None
|
self.mamba_pool_idx = None
|
||||||
self.already_computed = 0
|
self.already_computed = 0
|
||||||
|
|
||||||
@@ -1482,7 +1481,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
new_estimate_ratio = (
|
new_estimate_ratio = (
|
||||||
total_decoded_tokens
|
total_decoded_tokens
|
||||||
+ envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
|
+ envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
|
||||||
) / total_max_new_tokens
|
) / (
|
||||||
|
total_max_new_tokens + 1
|
||||||
|
) # avoid zero division
|
||||||
new_estimate_ratio = min(1.0, new_estimate_ratio)
|
new_estimate_ratio = min(1.0, new_estimate_ratio)
|
||||||
|
|
||||||
return retracted_reqs, new_estimate_ratio, []
|
return retracted_reqs, new_estimate_ratio, []
|
||||||
@@ -1780,6 +1781,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
# Only contain fields that will be used by process_batch_result
|
# Only contain fields that will be used by process_batch_result
|
||||||
return ScheduleBatch(
|
return ScheduleBatch(
|
||||||
reqs=self.reqs,
|
reqs=self.reqs,
|
||||||
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
|
req_pool_indices=self.req_pool_indices,
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
forward_mode=self.forward_mode,
|
forward_mode=self.forward_mode,
|
||||||
out_cache_loc=self.out_cache_loc,
|
out_cache_loc=self.out_cache_loc,
|
||||||
|
|||||||
@@ -569,7 +569,8 @@ class PrefillAdder:
|
|||||||
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
||||||
|
|
||||||
total_tokens = req.extend_input_len + min(
|
total_tokens = req.extend_input_len + min(
|
||||||
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
|
max(req.sampling_params.max_new_tokens - len(req.output_ids), 0),
|
||||||
|
CLIP_MAX_NEW_TOKENS,
|
||||||
)
|
)
|
||||||
|
|
||||||
# adjusting the input_tokens based on host_hit_length and page_size
|
# adjusting the input_tokens based on host_hit_length and page_size
|
||||||
|
|||||||
@@ -194,7 +194,8 @@ from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Test retract decode for debugging purposes
|
# Test retract decode for debugging purposes
|
||||||
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
|
TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
|
||||||
|
TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
|
||||||
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
|
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
|
||||||
|
|
||||||
|
|
||||||
@@ -1017,6 +1018,9 @@ class Scheduler(
|
|||||||
self.launch_batch_sample_if_needed(batch_result)
|
self.launch_batch_sample_if_needed(batch_result)
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
|
if envs.SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK.get():
|
||||||
|
self._check_runtime_mem_leak()
|
||||||
|
|
||||||
def recv_requests(self) -> List[Req]:
|
def recv_requests(self) -> List[Req]:
|
||||||
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
||||||
|
|
||||||
@@ -1833,7 +1837,7 @@ class Scheduler(
|
|||||||
|
|
||||||
# Check if decode out of memory
|
# Check if decode out of memory
|
||||||
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
|
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
|
||||||
TEST_RETRACT and batch.batch_size() > 10
|
TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
|
||||||
):
|
):
|
||||||
old_ratio = self.new_token_ratio
|
old_ratio = self.new_token_ratio
|
||||||
retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
|
retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
|
||||||
|
|||||||
@@ -77,15 +77,28 @@ class SchedulerOutputProcessorMixin:
|
|||||||
logprob_pt = 0
|
logprob_pt = 0
|
||||||
|
|
||||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||||
if req.is_retracted:
|
if self.enable_overlap and req.is_retracted and len(req.output_ids) > 0:
|
||||||
|
req_idx = batch.req_pool_indices[i]
|
||||||
|
seq_len = len(req.origin_input_ids) + len(req.output_ids)
|
||||||
|
pos = batch.req_to_token_pool.req_to_token[req_idx][
|
||||||
|
seq_len - 1 : seq_len
|
||||||
|
]
|
||||||
|
self.token_to_kv_pool_allocator.free(pos)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self.is_mixed_chunk and self.enable_overlap and req.finished():
|
if (
|
||||||
|
self.is_mixed_chunk
|
||||||
|
and self.enable_overlap
|
||||||
|
and (req.finished() or req.is_retracted)
|
||||||
|
):
|
||||||
# Free the one delayed token for the mixed decode batch
|
# Free the one delayed token for the mixed decode batch
|
||||||
j = len(batch.out_cache_loc) - len(batch.reqs) + i
|
j = len(batch.out_cache_loc) - len(batch.reqs) + i
|
||||||
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
|
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if req.is_retracted:
|
||||||
|
continue
|
||||||
|
|
||||||
if req.is_chunked <= 0:
|
if req.is_chunked <= 0:
|
||||||
# req output_ids are set here
|
# req output_ids are set here
|
||||||
req.output_ids.append(next_token_id)
|
req.output_ids.append(next_token_id)
|
||||||
@@ -269,10 +282,8 @@ class SchedulerOutputProcessorMixin:
|
|||||||
# We should ignore using next_token_ids for spec decoding cases.
|
# We should ignore using next_token_ids for spec decoding cases.
|
||||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||||
req: Req
|
req: Req
|
||||||
if req.is_retracted:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if self.enable_overlap and req.finished():
|
if self.enable_overlap and (req.finished() or req.is_retracted):
|
||||||
indices_to_free = None
|
indices_to_free = None
|
||||||
if batch.spec_algorithm.is_eagle():
|
if batch.spec_algorithm.is_eagle():
|
||||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||||
@@ -301,6 +312,9 @@ class SchedulerOutputProcessorMixin:
|
|||||||
self.token_to_kv_pool_allocator.free(indices_to_free)
|
self.token_to_kv_pool_allocator.free(indices_to_free)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if req.is_retracted:
|
||||||
|
continue
|
||||||
|
|
||||||
new_accepted_len = 1
|
new_accepted_len = 1
|
||||||
if batch.spec_algorithm.is_none():
|
if batch.spec_algorithm.is_none():
|
||||||
req.output_ids.append(next_token_id)
|
req.output_ids.append(next_token_id)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import time
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||||
|
|
||||||
@@ -65,6 +66,58 @@ class SchedulerRuntimeCheckerMixin:
|
|||||||
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
|
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
|
||||||
return memory_leak, token_msg
|
return memory_leak, token_msg
|
||||||
|
|
||||||
|
def _check_runtime_mem_leak(self: Scheduler):
|
||||||
|
current_batch: ScheduleBatch = self.last_batch
|
||||||
|
|
||||||
|
if current_batch is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
_, _, available_size, evictable_size = self._get_token_info()
|
||||||
|
protected_size = self.tree_cache.protected_size()
|
||||||
|
|
||||||
|
extend_size = 0
|
||||||
|
for i, req in enumerate(current_batch.reqs):
|
||||||
|
seq_len = len(req.origin_input_ids) + len(req.output_ids)
|
||||||
|
fill_len = len(req.fill_ids) if req.fill_ids is not None else 0
|
||||||
|
prefix_len = (
|
||||||
|
len(req.prefix_indices) if req.prefix_indices is not None else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_batch.forward_mode.is_decode():
|
||||||
|
if req.finished():
|
||||||
|
unreleased_len = 1
|
||||||
|
else:
|
||||||
|
unreleased_len = seq_len - prefix_len
|
||||||
|
else:
|
||||||
|
unreleased_len = fill_len - prefix_len
|
||||||
|
|
||||||
|
extend_size += unreleased_len
|
||||||
|
|
||||||
|
if (
|
||||||
|
current_batch.forward_mode.is_extend()
|
||||||
|
and self.running_batch is not None
|
||||||
|
and not self.running_batch.is_empty()
|
||||||
|
and self.running_batch.forward_mode.is_decode()
|
||||||
|
):
|
||||||
|
for i, req in enumerate(self.running_batch.reqs):
|
||||||
|
seq_len = len(req.origin_input_ids) + len(req.output_ids)
|
||||||
|
prefix_len = (
|
||||||
|
len(req.prefix_indices) if req.prefix_indices is not None else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
if req.finished():
|
||||||
|
unreleased_len = 0
|
||||||
|
else:
|
||||||
|
unreleased_len = seq_len - prefix_len - 1
|
||||||
|
|
||||||
|
extend_size += unreleased_len
|
||||||
|
|
||||||
|
total_tokens = available_size + evictable_size + protected_size + extend_size
|
||||||
|
|
||||||
|
assert (
|
||||||
|
total_tokens == self.max_total_num_tokens
|
||||||
|
), f"Mem Leak Detected! {total_tokens=} vs {self.max_total_num_tokens=}"
|
||||||
|
|
||||||
def _check_req_pool(self: Scheduler):
|
def _check_req_pool(self: Scheduler):
|
||||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
req_total_size = (
|
req_total_size = (
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ class ChunkCache(BasePrefixCache):
|
|||||||
else:
|
else:
|
||||||
self.device = torch.device("cpu")
|
self.device = torch.device("cpu")
|
||||||
|
|
||||||
|
self.protected_size_ = 0
|
||||||
|
|
||||||
# NOTE (csy): this is to determine if a cache has prefix matching feature.
|
# NOTE (csy): this is to determine if a cache has prefix matching feature.
|
||||||
# Chunk cache always return True to indicate no prefix matching.
|
# Chunk cache always return True to indicate no prefix matching.
|
||||||
# TODO (csy): Using a prefix cache trait to replace this
|
# TODO (csy): Using a prefix cache trait to replace this
|
||||||
@@ -57,11 +59,13 @@ class ChunkCache(BasePrefixCache):
|
|||||||
]
|
]
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||||
|
self.protected_size_ -= len(req.prefix_indices)
|
||||||
|
|
||||||
def cache_unfinished_req(self, req: Req, chunked=False):
|
def cache_unfinished_req(self, req: Req, chunked=False):
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx, : len(req.fill_ids)
|
req.req_pool_idx, : len(req.fill_ids)
|
||||||
]
|
]
|
||||||
|
self.protected_size_ += len(kv_indices) - len(req.prefix_indices)
|
||||||
|
|
||||||
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
||||||
req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||||
@@ -75,6 +79,9 @@ class ChunkCache(BasePrefixCache):
|
|||||||
def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
|
def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def protected_size(self):
|
||||||
|
return self.protected_size_
|
||||||
|
|
||||||
def pretty_print(self):
|
def pretty_print(self):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ suites = {
|
|||||||
TestFile("test_reasoning_parser.py", 5),
|
TestFile("test_reasoning_parser.py", 5),
|
||||||
TestFile("test_regex_constrained.py", 64),
|
TestFile("test_regex_constrained.py", 64),
|
||||||
TestFile("test_request_queue_validation.py", 30),
|
TestFile("test_request_queue_validation.py", 30),
|
||||||
TestFile("test_retract_decode.py", 54),
|
TestFile("test_retract_decode.py", 90),
|
||||||
TestFile("test_score_api.py", 310),
|
TestFile("test_score_api.py", 310),
|
||||||
TestFile("test_server_args.py", 1),
|
TestFile("test_server_args.py", 1),
|
||||||
TestFile("test_skip_tokenizer_init.py", 117),
|
TestFile("test_skip_tokenizer_init.py", 117),
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import os
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from sglang.srt.environ import envs
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
from sglang.test.run_eval import run_eval
|
from sglang.test.run_eval import run_eval
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
@@ -16,10 +17,9 @@ from sglang.test.test_utils import (
|
|||||||
class TestRetractDecode(CustomTestCase):
|
class TestRetractDecode(CustomTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
os.environ["SGLANG_TEST_RETRACT"] = "1"
|
|
||||||
|
|
||||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
with envs.SGLANG_TEST_RETRACT.override(True):
|
||||||
cls.process = popen_launch_server(
|
cls.process = popen_launch_server(
|
||||||
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
||||||
)
|
)
|
||||||
@@ -39,15 +39,17 @@ class TestRetractDecode(CustomTestCase):
|
|||||||
|
|
||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
self.assertGreaterEqual(metrics["score"], 0.65)
|
self.assertGreaterEqual(metrics["score"], 0.65)
|
||||||
|
time.sleep(1) # wait for mem check
|
||||||
|
|
||||||
|
assert self.process.poll() is None, "Server crashed during test"
|
||||||
|
|
||||||
|
|
||||||
class TestRetractDecodeChunkCache(CustomTestCase):
|
class TestRetractDecodeChunkCache(CustomTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
os.environ["SGLANG_TEST_RETRACT"] = "1"
|
|
||||||
|
|
||||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
with envs.SGLANG_TEST_RETRACT.override(True):
|
||||||
cls.process = popen_launch_server(
|
cls.process = popen_launch_server(
|
||||||
cls.model,
|
cls.model,
|
||||||
cls.base_url,
|
cls.base_url,
|
||||||
@@ -55,6 +57,25 @@ class TestRetractDecodeChunkCache(CustomTestCase):
|
|||||||
other_args=["--disable-radix-cache", "--chunked-prefill-size", 128],
|
other_args=["--disable-radix-cache", "--chunked-prefill-size", 128],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_mmlu(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
base_url=self.base_url,
|
||||||
|
model=self.model,
|
||||||
|
eval_name="mmlu",
|
||||||
|
num_examples=64,
|
||||||
|
num_threads=32,
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = run_eval(args)
|
||||||
|
self.assertGreaterEqual(metrics["score"], 0.65)
|
||||||
|
time.sleep(1) # wait for mem check
|
||||||
|
|
||||||
|
assert self.process.poll() is None, "Server crashed during test"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user