[Fix] memory leak by overlap + retract (#11981)

Co-authored-by: Liangsheng Yin <lsyincs@gmail.com>
Author: cctry
Date: 2025-10-23 07:59:23 -07:00
Committed by: GitHub
Parent: 6c18addb6f
Commit: b0b4f71679
9 changed files with 132 additions and 25 deletions

View File

@@ -114,7 +114,6 @@ class Envs:
     # Test & Debug
     SGLANG_IS_IN_CI = EnvBool(False)
     SGLANG_IS_IN_CI_AMD = EnvBool(False)
-    SGLANG_TEST_RETRACT = EnvBool(False)
     SGLANG_SET_CPU_AFFINITY = EnvBool(False)
     SGLANG_PROFILE_WITH_STACK = EnvBool(True)
     SGLANG_RECORD_STEP_TIME = EnvBool(False)
@@ -128,6 +127,11 @@ class Envs:
     SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
     SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp")

+    # Scheduler: memory leak test
+    SGLANG_TEST_RETRACT = EnvBool(False)
+    SGLANG_TEST_RETRACT_INTERVAL = EnvInt(3)
+    SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK = EnvBool(False)
+
     # Scheduler: new token ratio hyperparameters
     SGLANG_INIT_NEW_TOKEN_RATIO = EnvFloat(0.7)
     SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR = EnvFloat(0.14)
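
These three toggles are debugging aids meant to be used together. A minimal sketch, not part of the commit, assuming they are read at process startup like the other Envs entries:

    import os

    # Force a retraction every 3rd decode step and re-check the KV-token
    # accounting after every scheduler iteration (both are off by default).
    os.environ["SGLANG_TEST_RETRACT"] = "1"
    os.environ["SGLANG_TEST_RETRACT_INTERVAL"] = "3"
    os.environ["SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK"] = "1"
    # ...then launch the server as usual.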

View File

@@ -885,7 +885,6 @@ class Req:
         self.temp_input_top_logprobs_idx = None
         self.extend_logprob_start_len = 0
         self.is_chunked = 0
-        self.req_pool_idx = None
         self.mamba_pool_idx = None
         self.already_computed = 0
@@ -1482,7 +1481,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         new_estimate_ratio = (
             total_decoded_tokens
             + envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
-        ) / total_max_new_tokens
+        ) / (
+            total_max_new_tokens + 1
+        )  # avoid zero division
         new_estimate_ratio = min(1.0, new_estimate_ratio)
         return retracted_reqs, new_estimate_ratio, []
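
The widened denominator guards retract_decode against the degenerate case where no request has any decode budget left. A toy reproduction with invented values (the real default of SGLANG_RETRACT_DECODE_STEPS is not shown in this diff):

    total_decoded_tokens, retract_decode_steps, num_reqs = 0, 20, 0
    total_max_new_tokens = 0  # previously raised ZeroDivisionError

    new_estimate_ratio = (
        total_decoded_tokens + retract_decode_steps * num_reqs
    ) / (total_max_new_tokens + 1)  # the +1 keeps the division well defined
    new_estimate_ratio = min(1.0, new_estimate_ratio)
    assert new_estimate_ratio == 0.0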
@@ -1780,6 +1781,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
+            req_to_token_pool=self.req_to_token_pool,
+            req_pool_indices=self.req_pool_indices,
             model_config=self.model_config,
             forward_mode=self.forward_mode,
             out_cache_loc=self.out_cache_loc,
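
The result-only copy now carries the request-to-token mapping, which is what lets the output processor (later in this commit) locate the delayed KV slot of a retracted request. A stand-in sketch of that indirection, using plain lists instead of the real 2-D tensor and invented sizes:

    req_to_token = [[0] * 16 for _ in range(8)]  # one row of KV slots per request slot
    req_pool_indices = [3, 5]                    # rows owned by the reqs in this batch

    def last_token_kv_slot(i: int, seq_len: int) -> int:
        """KV-cache slot holding request i's most recent token."""
        return req_to_token[req_pool_indices[i]][seq_len - 1]

    req_to_token[3][9] = 112
    assert last_token_kv_slot(0, seq_len=10) == 112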

View File

@@ -569,7 +569,8 @@ class PrefillAdder:
             return self.add_one_req_ignore_eos(req, has_chunked_req)

         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+            max(req.sampling_params.max_new_tokens - len(req.output_ids), 0),
+            CLIP_MAX_NEW_TOKENS,
         )

         # adjusting the input_tokens based on host_hit_length and page_size
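
A retracted request re-enters the waiting queue with output_ids already populated, so its remaining decode budget is what is left of max_new_tokens, floored at zero; the old formula re-counted the full budget and over-reserved memory. A toy check with invented values:

    extend_input_len = 512
    max_new_tokens = 128
    already_decoded = 100       # tokens produced before the retract
    CLIP_MAX_NEW_TOKENS = 4096  # illustrative; the real constant is defined elsewhere

    total_tokens = extend_input_len + min(
        max(max_new_tokens - already_decoded, 0),
        CLIP_MAX_NEW_TOKENS,
    )
    assert total_tokens == 512 + 28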

View File

@@ -194,7 +194,8 @@ from sglang.utils import TypeBasedDispatcher, get_exception_traceback

 logger = logging.getLogger(__name__)

 # Test retract decode for debugging purposes
-TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
+TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
+TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
@@ -1017,6 +1018,9 @@ class Scheduler(
             self.launch_batch_sample_if_needed(batch_result)

         self.last_batch = batch

+        if envs.SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK.get():
+            self._check_runtime_mem_leak()
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1833,7 +1837,7 @@ class Scheduler(
         # Check if decode out of memory
         if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
-            TEST_RETRACT and batch.batch_size() > 10
+            TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
         ):
             old_ratio = self.new_token_ratio

             retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
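
The old test trigger only fired once the batch held more than 10 requests; the new one fires on a fixed cadence of forward steps, exercising retraction deterministically at any batch size. A toy illustration:

    TEST_RETRACT_INTERVAL = 3  # matches the new env default

    for forward_ct in range(1, 10):
        force_retract = forward_ct % TEST_RETRACT_INTERVAL == 0
        print(forward_ct, force_retract)  # True on steps 3, 6, 9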

View File

@@ -77,15 +77,28 @@ class SchedulerOutputProcessorMixin:
             logprob_pt = 0
             for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-                if req.is_retracted:
+                if self.enable_overlap and req.is_retracted and len(req.output_ids) > 0:
+                    req_idx = batch.req_pool_indices[i]
+                    seq_len = len(req.origin_input_ids) + len(req.output_ids)
+                    pos = batch.req_to_token_pool.req_to_token[req_idx][
+                        seq_len - 1 : seq_len
+                    ]
+                    self.token_to_kv_pool_allocator.free(pos)
                     continue

-                if self.is_mixed_chunk and self.enable_overlap and req.finished():
+                if (
+                    self.is_mixed_chunk
+                    and self.enable_overlap
+                    and (req.finished() or req.is_retracted)
+                ):
                     # Free the one delayed token for the mixed decode batch
                     j = len(batch.out_cache_loc) - len(batch.reqs) + i
                     self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
                     continue

+                if req.is_retracted:
+                    continue
+
                 if req.is_chunked <= 0:
                     # req output_ids are set here
                     req.output_ids.append(next_token_id)
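
Under overlap scheduling, the KV slot for step t's next token is allocated before step t's result is processed, so a request retracted in between still holds exactly one delayed slot; that single position is what the new branch frees. A sketch with invented values:

    origin_input_ids = list(range(10))  # 10 prompt tokens
    output_ids = [101, 102, 103]        # 3 tokens decoded before the retract
    seq_len = len(origin_input_ids) + len(output_ids)

    req_to_token_row = list(range(100, 116))       # invented KV slots for this row
    pos = req_to_token_row[seq_len - 1 : seq_len]  # the delayed token's slot
    assert pos == [112]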
@@ -269,10 +282,8 @@ class SchedulerOutputProcessorMixin:
             # We should ignore using next_token_ids for spec decoding cases.
             for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                 req: Req
-                if req.is_retracted:
-                    continue
-
-                if self.enable_overlap and req.finished():
+                if self.enable_overlap and (req.finished() or req.is_retracted):
                     indices_to_free = None
                     if batch.spec_algorithm.is_eagle():
                         from sglang.srt.speculative.eagle_info import EagleDraftInput
@@ -301,6 +312,9 @@ class SchedulerOutputProcessorMixin:
                     self.token_to_kv_pool_allocator.free(indices_to_free)
                     continue

+                if req.is_retracted:
+                    continue
+
                 new_accepted_len = 1
                 if batch.spec_algorithm.is_none():
                     req.output_ids.append(next_token_id)
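
Both hunks in this file make the same reordering: the early is_retracted skip used to run first, so retracted requests never reached the overlap free-path and their delayed or speculative KV indices leaked. A toy model of the fixed control flow (hypothetical helper, not from the diff):

    def process(finished: bool, retracted: bool, enable_overlap: bool) -> str:
        # Free-path first, early skip second, mirroring the reordered code.
        if enable_overlap and (finished or retracted):
            return "free delayed KV, then continue"
        if retracted:
            return "skip"
        return "append next token"

    assert process(False, True, enable_overlap=True).startswith("free")
    assert process(False, True, enable_overlap=False) == "skip"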

View File

@@ -4,6 +4,7 @@ import time
 from typing import TYPE_CHECKING

 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
@@ -65,6 +66,58 @@ class SchedulerRuntimeCheckerMixin:
         token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
         return memory_leak, token_msg

+    def _check_runtime_mem_leak(self: Scheduler):
+        current_batch: ScheduleBatch = self.last_batch
+        if current_batch is None:
+            return
+
+        _, _, available_size, evictable_size = self._get_token_info()
+        protected_size = self.tree_cache.protected_size()
+
+        extend_size = 0
+        for i, req in enumerate(current_batch.reqs):
+            seq_len = len(req.origin_input_ids) + len(req.output_ids)
+            fill_len = len(req.fill_ids) if req.fill_ids is not None else 0
+            prefix_len = (
+                len(req.prefix_indices) if req.prefix_indices is not None else 0
+            )
+            if current_batch.forward_mode.is_decode():
+                if req.finished():
+                    unreleased_len = 1
+                else:
+                    unreleased_len = seq_len - prefix_len
+            else:
+                unreleased_len = fill_len - prefix_len
+            extend_size += unreleased_len
+
+        if (
+            current_batch.forward_mode.is_extend()
+            and self.running_batch is not None
+            and not self.running_batch.is_empty()
+            and self.running_batch.forward_mode.is_decode()
+        ):
+            for i, req in enumerate(self.running_batch.reqs):
+                seq_len = len(req.origin_input_ids) + len(req.output_ids)
+                prefix_len = (
+                    len(req.prefix_indices) if req.prefix_indices is not None else 0
+                )
+                if req.finished():
+                    unreleased_len = 0
+                else:
+                    unreleased_len = seq_len - prefix_len - 1
+                extend_size += unreleased_len
+
+        total_tokens = available_size + evictable_size + protected_size + extend_size
+        assert (
+            total_tokens == self.max_total_num_tokens
+        ), f"Mem Leak Detected! {total_tokens=} vs {self.max_total_num_tokens=}"
+
     def _check_req_pool(self: Scheduler):
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             req_total_size = (
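
The new checker asserts a conservation law: every KV slot is in exactly one of four states, free (available), cached and evictable, protected (locked prefixes), or held by an in-flight request, so the four counts must sum to the pool capacity. With invented numbers:

    max_total_num_tokens = 1000
    available_size, evictable_size, protected_size, extend_size = 700, 180, 70, 50

    total_tokens = available_size + evictable_size + protected_size + extend_size
    assert total_tokens == max_total_num_tokens, "Mem Leak Detected!"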

View File

@@ -32,6 +32,8 @@ class ChunkCache(BasePrefixCache):
         else:
             self.device = torch.device("cpu")

+        self.protected_size_ = 0
+
         # NOTE (csy): this is to determine if a cache has prefix matching feature.
         # Chunk cache always return True to indicate no prefix matching.
         # TODO (csy): Using a prefix cache trait to replace this
@@ -57,11 +59,13 @@ class ChunkCache(BasePrefixCache):
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)
+        self.protected_size_ -= len(req.prefix_indices)

     def cache_unfinished_req(self, req: Req, chunked=False):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(req.fill_ids)
         ]
+        self.protected_size_ += len(kv_indices) - len(req.prefix_indices)

         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
         req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)
@@ -75,6 +79,9 @@ class ChunkCache(BasePrefixCache):
     def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
         return 0

+    def protected_size(self):
+        return self.protected_size_
+
     def pretty_print(self):
         return ""