Support GC Freezing to improve latency & throughput (#9241)
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
This commit is contained in:
@@ -536,6 +536,22 @@ class Engine(EngineBase):
|
|||||||
self.tokenizer_manager.resume_memory_occupation(obj, None)
|
self.tokenizer_manager.resume_memory_occupation(obj, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def freeze_gc(self):
|
||||||
|
"""
|
||||||
|
To maintain a high performance server with low latency, we want to reduce the
|
||||||
|
stalls caused by the garbage collector scanning through a large number of objects.
|
||||||
|
|
||||||
|
It is usually helpful to start the server and warm it up with real requests to
|
||||||
|
initialize many of the long-lived objects that do not need to be garbage collected.
|
||||||
|
|
||||||
|
After sufficient warmup, we can call this function to freeze the garbage collector
|
||||||
|
so that all objects created before this point are considered out of scope for garbage
|
||||||
|
collection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(self.tokenizer_manager.freeze_gc())
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Execute an RPC call on all scheduler processes.
|
Execute an RPC call on all scheduler processes.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -511,6 +511,18 @@ async def stop_profile_async():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/freeze_gc", methods=["GET", "POST"])
|
||||||
|
async def freeze_gc_async():
|
||||||
|
"""
|
||||||
|
See engine.freeze_gc for more details.
|
||||||
|
"""
|
||||||
|
await _global_state.tokenizer_manager.freeze_gc()
|
||||||
|
return Response(
|
||||||
|
content="Garbage collection frozen.\n",
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
|
@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
|
||||||
async def start_expert_distribution_record_async():
|
async def start_expert_distribution_record_async():
|
||||||
"""Start recording the expert distribution. Clear the previous record if any."""
|
"""Start recording the expert distribution. Clear the previous record if any."""
|
||||||
|
|||||||
@@ -31,10 +31,12 @@ from sglang.srt.managers.io_struct import (
|
|||||||
BatchMultimodalOut,
|
BatchMultimodalOut,
|
||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
|
FreezeGCReq,
|
||||||
)
|
)
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
configure_logger,
|
configure_logger,
|
||||||
|
freeze_gc,
|
||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
kill_itself_when_parent_died,
|
kill_itself_when_parent_died,
|
||||||
)
|
)
|
||||||
@@ -100,6 +102,7 @@ class DetokenizerManager:
|
|||||||
(BatchEmbeddingOut, self.handle_batch_embedding_out),
|
(BatchEmbeddingOut, self.handle_batch_embedding_out),
|
||||||
(BatchTokenIDOut, self.handle_batch_token_id_out),
|
(BatchTokenIDOut, self.handle_batch_token_id_out),
|
||||||
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
||||||
|
(FreezeGCReq, self.handle_freeze_gc_req),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -108,7 +111,8 @@ class DetokenizerManager:
|
|||||||
while True:
|
while True:
|
||||||
recv_obj = self.recv_from_scheduler.recv_pyobj()
|
recv_obj = self.recv_from_scheduler.recv_pyobj()
|
||||||
output = self._request_dispatcher(recv_obj)
|
output = self._request_dispatcher(recv_obj)
|
||||||
self.send_to_tokenizer.send_pyobj(output)
|
if output is not None:
|
||||||
|
self.send_to_tokenizer.send_pyobj(output)
|
||||||
|
|
||||||
def trim_matched_stop(
|
def trim_matched_stop(
|
||||||
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
|
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
|
||||||
@@ -247,6 +251,10 @@ class DetokenizerManager:
|
|||||||
cached_tokens=recv_obj.cached_tokens,
|
cached_tokens=recv_obj.cached_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
|
||||||
|
freeze_gc("Detokenizer Manager")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class LimitedCapacityDict(OrderedDict):
|
class LimitedCapacityDict(OrderedDict):
|
||||||
def __init__(self, capacity: int, *args, **kwargs):
|
def __init__(self, capacity: int, *args, **kwargs):
|
||||||
|
|||||||
@@ -1005,6 +1005,11 @@ class ProfileReqOutput:
|
|||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FreezeGCReq:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConfigureLoggingReq:
|
class ConfigureLoggingReq:
|
||||||
log_requests: Optional[bool] = None
|
log_requests: Optional[bool] = None
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ExpertDistributionReqOutput,
|
ExpertDistributionReqOutput,
|
||||||
FlushCacheReqInput,
|
FlushCacheReqInput,
|
||||||
FlushCacheReqOutput,
|
FlushCacheReqOutput,
|
||||||
|
FreezeGCReq,
|
||||||
GetInternalStateReq,
|
GetInternalStateReq,
|
||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
@@ -145,6 +146,7 @@ from sglang.srt.utils import (
|
|||||||
configure_gc_logger,
|
configure_gc_logger,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
disable_request_logging,
|
disable_request_logging,
|
||||||
|
freeze_gc,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
@@ -524,6 +526,7 @@ class Scheduler(
|
|||||||
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
|
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
|
||||||
(SlowDownReqInput, self.slow_down),
|
(SlowDownReqInput, self.slow_down),
|
||||||
(ProfileReq, self.profile),
|
(ProfileReq, self.profile),
|
||||||
|
(FreezeGCReq, self.handle_freeze_gc),
|
||||||
(GetInternalStateReq, self.get_internal_state),
|
(GetInternalStateReq, self.get_internal_state),
|
||||||
(SetInternalStateReq, self.set_internal_state),
|
(SetInternalStateReq, self.set_internal_state),
|
||||||
(RpcReqInput, self.handle_rpc_request),
|
(RpcReqInput, self.handle_rpc_request),
|
||||||
@@ -2469,6 +2472,12 @@ class Scheduler(
|
|||||||
if self.idle_sleeper is not None:
|
if self.idle_sleeper is not None:
|
||||||
self.idle_sleeper.maybe_sleep()
|
self.idle_sleeper.maybe_sleep()
|
||||||
|
|
||||||
|
def handle_freeze_gc(self, recv_req: FreezeGCReq):
|
||||||
|
"""Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
|
||||||
|
freeze_gc("Scheduler")
|
||||||
|
self.send_to_detokenizer.send_pyobj(recv_req)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class IdleSleeper:
|
class IdleSleeper:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ExpertDistributionReqOutput,
|
ExpertDistributionReqOutput,
|
||||||
FlushCacheReqInput,
|
FlushCacheReqInput,
|
||||||
FlushCacheReqOutput,
|
FlushCacheReqOutput,
|
||||||
|
FreezeGCReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetInternalStateReq,
|
GetInternalStateReq,
|
||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
@@ -122,7 +123,9 @@ from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
|||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
|
configure_gc_warning,
|
||||||
dataclass_to_string_truncated,
|
dataclass_to_string_truncated,
|
||||||
|
freeze_gc,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_zmq_socket,
|
get_zmq_socket,
|
||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
@@ -352,6 +355,10 @@ class TokenizerManager:
|
|||||||
collect_tokens_histogram=self.server_args.collect_tokens_histogram,
|
collect_tokens_histogram=self.server_args.collect_tokens_histogram,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Configure GC warning
|
||||||
|
if self.server_args.gc_warning_threshold_secs > 0.0:
|
||||||
|
configure_gc_warning(self.server_args.gc_warning_threshold_secs)
|
||||||
|
|
||||||
# Communicators
|
# Communicators
|
||||||
self.init_weights_update_group_communicator = _Communicator(
|
self.init_weights_update_group_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
@@ -446,6 +453,10 @@ class TokenizerManager:
|
|||||||
ProfileReqOutput,
|
ProfileReqOutput,
|
||||||
self.profile_communicator.handle_recv,
|
self.profile_communicator.handle_recv,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
FreezeGCReq,
|
||||||
|
lambda x: None,
|
||||||
|
), # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
|
||||||
(
|
(
|
||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
self.get_internal_state_communicator.handle_recv,
|
self.get_internal_state_communicator.handle_recv,
|
||||||
@@ -1359,6 +1370,12 @@ class TokenizerManager:
|
|||||||
logging.info(f"Config logging: {obj=}")
|
logging.info(f"Config logging: {obj=}")
|
||||||
self.log_request_metadata = self.get_log_request_metadata()
|
self.log_request_metadata = self.get_log_request_metadata()
|
||||||
|
|
||||||
|
async def freeze_gc(self):
|
||||||
|
"""Send a freeze_gc message to the scheduler first, then freeze locally."""
|
||||||
|
self.send_to_scheduler.send_pyobj(FreezeGCReq())
|
||||||
|
freeze_gc("Tokenizer Manager")
|
||||||
|
return None
|
||||||
|
|
||||||
def create_abort_task(self, obj: GenerateReqInput):
|
def create_abort_task(self, obj: GenerateReqInput):
|
||||||
# Abort the request if the client is disconnected.
|
# Abort the request if the client is disconnected.
|
||||||
async def abort_request():
|
async def abort_request():
|
||||||
|
|||||||
@@ -123,6 +123,7 @@ class ServerArgs:
|
|||||||
decode_log_interval: int = 40
|
decode_log_interval: int = 40
|
||||||
enable_request_time_stats_logging: bool = False
|
enable_request_time_stats_logging: bool = False
|
||||||
kv_events_config: Optional[str] = None
|
kv_events_config: Optional[str] = None
|
||||||
|
gc_warning_threshold_secs: float = 0.0
|
||||||
|
|
||||||
# API related
|
# API related
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
@@ -1172,6 +1173,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.collect_tokens_histogram,
|
default=ServerArgs.collect_tokens_histogram,
|
||||||
help="Collect prompt/generation tokens histogram.",
|
help="Collect prompt/generation tokens histogram.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gc-warning-threshold-secs",
|
||||||
|
type=float,
|
||||||
|
default=ServerArgs.gc_warning_threshold_secs,
|
||||||
|
help="The threshold for long GC warning. If a GC takes longer than this, a warning will be logged. Set to 0 to disable.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decode-log-interval",
|
"--decode-log-interval",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
@@ -2541,6 +2541,50 @@ def dynamic_import(func_path: str):
|
|||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def gc_object_counts():
|
||||||
|
import gc
|
||||||
|
|
||||||
|
g0 = len(gc.get_objects(0))
|
||||||
|
g1 = len(gc.get_objects(1))
|
||||||
|
g2 = len(gc.get_objects(2))
|
||||||
|
return g0, g1, g2
|
||||||
|
|
||||||
|
|
||||||
|
def configure_gc_warning(warn_threshold_secs):
|
||||||
|
import gc
|
||||||
|
|
||||||
|
gc_start_time = {}
|
||||||
|
|
||||||
|
def gc_callback(phase, info):
|
||||||
|
gen = info.get("generation", "?")
|
||||||
|
if phase == "start":
|
||||||
|
gc_start_time[gen] = time.time()
|
||||||
|
elif phase == "stop":
|
||||||
|
duration = time.time() - gc_start_time.get(gen, time.time())
|
||||||
|
if duration > warn_threshold_secs:
|
||||||
|
g0, g1, g2 = gc_object_counts()
|
||||||
|
logger.warn(
|
||||||
|
f"LONG GARBAGE COLLECTION DETECTED | Generation {gen} | Duration: {duration:.4f}s | # Objects: gen0={g0}, gen1={g1}, gen2={g2} | "
|
||||||
|
f"This may cause latency jitter. Consider calling the freeze_gc API after sending a few warmup requests."
|
||||||
|
)
|
||||||
|
|
||||||
|
gc.callbacks.append(gc_callback)
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_gc(context: str):
|
||||||
|
import gc
|
||||||
|
|
||||||
|
g0_before, g1_before, g2_before = gc_object_counts()
|
||||||
|
gc.freeze()
|
||||||
|
g0_after, g1_after, g2_after = gc_object_counts()
|
||||||
|
logger.info(
|
||||||
|
f"Freezing GC in {context} process. "
|
||||||
|
f"gen0: {g0_before}->{g0_after}, "
|
||||||
|
f"gen1: {g1_before}->{g1_after}, "
|
||||||
|
f"gen2: {g2_before}->{g2_after}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def configure_gc_logger():
|
def configure_gc_logger():
|
||||||
logger.info("Enable GC Logger")
|
logger.info("Enable GC Logger")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user