Split the scheduler into multiple mixin classes to reduce the file size (#8483)
This commit is contained in:
@@ -694,10 +694,7 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||||
== 0
|
== 0
|
||||||
):
|
):
|
||||||
# When the server is idle, do self-check and re-init some states
|
self.self_check_during_idle()
|
||||||
self.check_memory()
|
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
|
||||||
self.maybe_sleep_on_idle()
|
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
@@ -771,10 +768,7 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||||
== 0
|
== 0
|
||||||
):
|
):
|
||||||
# When the server is idle, do self-check and re-init some states
|
self.self_check_during_idle()
|
||||||
self.check_memory()
|
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
|
||||||
self.maybe_sleep_on_idle()
|
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
self.last_batch_in_queue = last_batch_in_queue
|
self.last_batch_in_queue = last_batch_in_queue
|
||||||
|
|||||||
@@ -287,9 +287,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
self.process_disagg_prefill_inflight_queue()
|
self.process_disagg_prefill_inflight_queue()
|
||||||
|
|
||||||
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
|
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
|
||||||
self.check_memory()
|
self.self_check_during_idle()
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
|
||||||
self.maybe_sleep_on_idle()
|
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
|
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
|
||||||
@@ -337,9 +335,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
self.process_disagg_prefill_inflight_queue()
|
self.process_disagg_prefill_inflight_queue()
|
||||||
|
|
||||||
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
|
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
|
||||||
self.check_memory()
|
self.self_check_during_idle()
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
|
||||||
self.maybe_sleep_on_idle()
|
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
|
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
|
||||||
|
|||||||
@@ -652,25 +652,19 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
||||||
)
|
)
|
||||||
|
|
||||||
def sigchld_handler(signum, frame):
|
if True: # Keep this check for internal code compatibility
|
||||||
pid, exitcode = os.waitpid(0, os.WNOHANG)
|
|
||||||
if exitcode != 0:
|
|
||||||
logger.warning(
|
|
||||||
f"Child process unexpectedly failed with {exitcode=}. {pid=}"
|
|
||||||
)
|
|
||||||
|
|
||||||
signal.signal(signal.SIGCHLD, sigchld_handler)
|
|
||||||
|
|
||||||
# Register the signal handler.
|
# Register the signal handler.
|
||||||
# The child processes will send SIGQUIT to this process when any error happens
|
# The child processes will send SIGQUIT to this process when any error happens
|
||||||
# This process then cleans up the whole process tree
|
# This process then cleans up the whole process tree
|
||||||
def sigquit_handler(signum, frame):
|
# Note: This sigquit handler is used in the launch phase, and may be replaced by
|
||||||
|
# the running_phase_sigquit_handler in the tokenizer manager after the grpc server is launched.
|
||||||
|
def launch_phase_sigquit_handler(signum, frame):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Received sigquit from a child process. It usually means the child failed."
|
"Received sigquit from a child process. It usually means the child failed."
|
||||||
)
|
)
|
||||||
kill_process_tree(os.getpid())
|
kill_process_tree(os.getpid())
|
||||||
|
|
||||||
signal.signal(signal.SIGQUIT, sigquit_handler)
|
signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)
|
||||||
|
|
||||||
# Set mp start method
|
# Set mp start method
|
||||||
mp.set_start_method("spawn", force=True)
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|||||||
@@ -238,6 +238,9 @@ async def health() -> Response:
|
|||||||
@app.get("/health_generate")
|
@app.get("/health_generate")
|
||||||
async def health_generate(request: Request) -> Response:
|
async def health_generate(request: Request) -> Response:
|
||||||
"""Check the health of the inference server by generating one token."""
|
"""Check the health of the inference server by generating one token."""
|
||||||
|
if _global_state.tokenizer_manager.gracefully_exit:
|
||||||
|
logger.info("Health check request received during shutdown. Returning 503.")
|
||||||
|
return Response(status_code=503)
|
||||||
|
|
||||||
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
|
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
|
||||||
rid = f"HEALTH_CHECK_{time.time()}"
|
rid = f"HEALTH_CHECK_{time.time()}"
|
||||||
@@ -260,9 +263,14 @@ async def health_generate(request: Request) -> Response:
|
|||||||
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
|
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
|
||||||
break
|
break
|
||||||
|
|
||||||
tic = time.perf_counter()
|
# This request is a special request.
|
||||||
|
# If the server already has something running, this request will be ignored, so it creates zero overhead.
|
||||||
|
# If the server is not running anything (i.e., it is idle), this request will be run, so we know whether the server is healthy.
|
||||||
task = asyncio.create_task(gen())
|
task = asyncio.create_task(gen())
|
||||||
while time.perf_counter() < tic + HEALTH_CHECK_TIMEOUT:
|
|
||||||
|
# As long as we receive any response from the detokenizer/scheduler, we consider the server healthy.
|
||||||
|
tic = time.time()
|
||||||
|
while time.time() < tic + HEALTH_CHECK_TIMEOUT:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
|
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|||||||
@@ -152,8 +152,6 @@ class GenerateReqInput:
|
|||||||
else:
|
else:
|
||||||
self._normalize_batch_inputs()
|
self._normalize_batch_inputs()
|
||||||
|
|
||||||
self._validate_session_params()
|
|
||||||
|
|
||||||
def _validate_inputs(self):
|
def _validate_inputs(self):
|
||||||
"""Validate that the input configuration is valid."""
|
"""Validate that the input configuration is valid."""
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||||
|
|
||||||
import datetime
|
|
||||||
import faulthandler
|
import faulthandler
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -21,11 +20,10 @@ import signal
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict, deque
|
from collections import deque
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -37,7 +35,6 @@ from torch.distributed import barrier
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
|
|
||||||
from sglang.srt.constrained.base_grammar_backend import (
|
from sglang.srt.constrained.base_grammar_backend import (
|
||||||
INVALID_GRAMMAR_OBJ,
|
INVALID_GRAMMAR_OBJ,
|
||||||
create_grammar_backend,
|
create_grammar_backend,
|
||||||
@@ -47,7 +44,6 @@ from sglang.srt.disaggregation.decode import (
|
|||||||
DecodeTransferQueue,
|
DecodeTransferQueue,
|
||||||
SchedulerDisaggregationDecodeMixin,
|
SchedulerDisaggregationDecodeMixin,
|
||||||
)
|
)
|
||||||
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
|
|
||||||
from sglang.srt.disaggregation.prefill import (
|
from sglang.srt.disaggregation.prefill import (
|
||||||
PrefillBootstrapQueue,
|
PrefillBootstrapQueue,
|
||||||
SchedulerDisaggregationPrefillMixin,
|
SchedulerDisaggregationPrefillMixin,
|
||||||
@@ -78,21 +74,15 @@ from sglang.srt.managers.io_struct import (
|
|||||||
GetInternalStateReq,
|
GetInternalStateReq,
|
||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
GetWeightsByNameReqOutput,
|
|
||||||
HealthCheckOutput,
|
HealthCheckOutput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
InitWeightsUpdateGroupReqOutput,
|
|
||||||
LoadLoRAAdapterReqInput,
|
LoadLoRAAdapterReqInput,
|
||||||
LoadLoRAAdapterReqOutput,
|
LoadLoRAAdapterReqOutput,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
OpenSessionReqOutput,
|
OpenSessionReqOutput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
ProfileReqOutput,
|
|
||||||
ProfileReqType,
|
|
||||||
ReleaseMemoryOccupationReqInput,
|
ReleaseMemoryOccupationReqInput,
|
||||||
ReleaseMemoryOccupationReqOutput,
|
|
||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
ResumeMemoryOccupationReqOutput,
|
|
||||||
RpcReqInput,
|
RpcReqInput,
|
||||||
RpcReqOutput,
|
RpcReqOutput,
|
||||||
SetInternalStateReq,
|
SetInternalStateReq,
|
||||||
@@ -104,11 +94,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UnloadLoRAAdapterReqInput,
|
UnloadLoRAAdapterReqInput,
|
||||||
UnloadLoRAAdapterReqOutput,
|
UnloadLoRAAdapterReqOutput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightFromDiskReqOutput,
|
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
UpdateWeightsFromDistributedReqOutput,
|
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
UpdateWeightsFromTensorReqOutput,
|
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.mm_utils import init_embedding_cache
|
from sglang.srt.managers.mm_utils import init_embedding_cache
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
@@ -124,9 +111,17 @@ from sglang.srt.managers.schedule_policy import (
|
|||||||
SchedulePolicy,
|
SchedulePolicy,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
|
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
|
||||||
|
from sglang.srt.managers.scheduler_metrics_mixin import (
|
||||||
|
RECORD_STEP_TIME,
|
||||||
|
SchedulerMetricsMixin,
|
||||||
|
)
|
||||||
from sglang.srt.managers.scheduler_output_processor_mixin import (
|
from sglang.srt.managers.scheduler_output_processor_mixin import (
|
||||||
SchedulerOutputProcessorMixin,
|
SchedulerOutputProcessorMixin,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
|
||||||
|
from sglang.srt.managers.scheduler_update_weights_mixin import (
|
||||||
|
SchedulerUpdateWeightsMixin,
|
||||||
|
)
|
||||||
from sglang.srt.managers.session_controller import Session
|
from sglang.srt.managers.session_controller import Session
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||||
@@ -135,7 +130,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
|||||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||||
from sglang.srt.reasoning_parser import ReasoningParser
|
from sglang.srt.reasoning_parser import ReasoningParser
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
@@ -168,7 +162,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Test retract decode for debugging purposes
|
# Test retract decode for debugging purposes
|
||||||
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
|
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
|
||||||
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
|
|
||||||
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
|
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
|
||||||
|
|
||||||
_is_cpu = is_cpu()
|
_is_cpu = is_cpu()
|
||||||
@@ -191,41 +184,11 @@ class EmbeddingBatchResult:
|
|||||||
bid: int
|
bid: int
|
||||||
|
|
||||||
|
|
||||||
class KvMetrics:
|
|
||||||
def __init__(self):
|
|
||||||
self.request_active_slots = None
|
|
||||||
self.request_total_slots = None
|
|
||||||
self.kv_active_blocks = None
|
|
||||||
self.kv_total_blocks = None
|
|
||||||
self.num_requests_waiting = None
|
|
||||||
self.gpu_cache_usage_perc = None
|
|
||||||
self.gpu_prefix_cache_hit_rate = None
|
|
||||||
self.data_parallel_rank = None
|
|
||||||
|
|
||||||
|
|
||||||
class IdleSleeper:
|
|
||||||
"""
|
|
||||||
In setups which have long inactivity periods it is desirable to reduce
|
|
||||||
system power consumption when sglang does nothing. This would lead not only
|
|
||||||
to power savings, but also to more CPU thermal headroom when a request
|
|
||||||
eventually comes. This is important in cases when multiple GPUs are connected
|
|
||||||
as each GPU would otherwise pin one thread at 100% CPU usage.
|
|
||||||
|
|
||||||
The simplest solution is to use zmq.Poller on all sockets that may receive
|
|
||||||
data that needs handling immediately.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, sockets):
|
|
||||||
self.poller = zmq.Poller()
|
|
||||||
for s in sockets:
|
|
||||||
self.poller.register(s, zmq.POLLIN)
|
|
||||||
|
|
||||||
def maybe_sleep(self):
|
|
||||||
self.poller.poll(1000)
|
|
||||||
|
|
||||||
|
|
||||||
class Scheduler(
|
class Scheduler(
|
||||||
SchedulerOutputProcessorMixin,
|
SchedulerOutputProcessorMixin,
|
||||||
|
SchedulerUpdateWeightsMixin,
|
||||||
|
SchedulerProfilerMixin,
|
||||||
|
SchedulerMetricsMixin,
|
||||||
SchedulerDisaggregationDecodeMixin,
|
SchedulerDisaggregationDecodeMixin,
|
||||||
SchedulerDisaggregationPrefillMixin,
|
SchedulerDisaggregationPrefillMixin,
|
||||||
):
|
):
|
||||||
@@ -266,7 +229,7 @@ class Scheduler(
|
|||||||
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
|
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
|
||||||
self.enable_hicache_storage = server_args.hicache_storage_backend is not None
|
self.enable_hicache_storage = server_args.hicache_storage_backend is not None
|
||||||
self.page_size = server_args.page_size
|
self.page_size = server_args.page_size
|
||||||
self.dp_size = server_args.dp_size
|
|
||||||
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
|
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
|
||||||
compute_dp_attention_world_info(
|
compute_dp_attention_world_info(
|
||||||
server_args.enable_dp_attention,
|
server_args.enable_dp_attention,
|
||||||
@@ -284,10 +247,13 @@ class Scheduler(
|
|||||||
self.recv_from_tokenizer = get_zmq_socket(
|
self.recv_from_tokenizer = get_zmq_socket(
|
||||||
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
|
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
|
||||||
)
|
)
|
||||||
|
self.recv_from_rpc = get_zmq_socket(
|
||||||
|
context, zmq.DEALER, port_args.rpc_ipc_name, False
|
||||||
|
)
|
||||||
|
|
||||||
self.send_to_tokenizer = get_zmq_socket(
|
self.send_to_tokenizer = get_zmq_socket(
|
||||||
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
||||||
)
|
)
|
||||||
|
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
# Directly send to the TokenizerManager
|
# Directly send to the TokenizerManager
|
||||||
self.send_to_detokenizer = get_zmq_socket(
|
self.send_to_detokenizer = get_zmq_socket(
|
||||||
@@ -299,9 +265,6 @@ class Scheduler(
|
|||||||
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
|
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
|
||||||
)
|
)
|
||||||
|
|
||||||
self.recv_from_rpc = get_zmq_socket(
|
|
||||||
context, zmq.DEALER, port_args.rpc_ipc_name, False
|
|
||||||
)
|
|
||||||
if self.server_args.sleep_on_idle:
|
if self.server_args.sleep_on_idle:
|
||||||
self.idle_sleeper = IdleSleeper(
|
self.idle_sleeper = IdleSleeper(
|
||||||
[
|
[
|
||||||
@@ -398,7 +361,7 @@ class Scheduler(
|
|||||||
global_server_args_dict.update(worker_global_server_args_dict)
|
global_server_args_dict.update(worker_global_server_args_dict)
|
||||||
set_random_seed(self.random_seed)
|
set_random_seed(self.random_seed)
|
||||||
|
|
||||||
# Hybrid
|
# Hybrid memory pool
|
||||||
self.is_hybrid = self.tp_worker.is_hybrid
|
self.is_hybrid = self.tp_worker.is_hybrid
|
||||||
if self.is_hybrid:
|
if self.is_hybrid:
|
||||||
self.sliding_window_size = self.tp_worker.sliding_window_size
|
self.sliding_window_size = self.tp_worker.sliding_window_size
|
||||||
@@ -515,6 +478,15 @@ class Scheduler(
|
|||||||
self.init_metrics(tp_rank, pp_rank, dp_rank)
|
self.init_metrics(tp_rank, pp_rank, dp_rank)
|
||||||
self.init_kv_events(server_args.kv_events_config)
|
self.init_kv_events(server_args.kv_events_config)
|
||||||
|
|
||||||
|
# Init disaggregation
|
||||||
|
self.disaggregation_mode = DisaggregationMode(
|
||||||
|
self.server_args.disaggregation_mode
|
||||||
|
)
|
||||||
|
self.init_disaggregation()
|
||||||
|
|
||||||
|
if get_bool_env_var("SGLANG_GC_LOG"):
|
||||||
|
configure_gc_logger()
|
||||||
|
|
||||||
# Init request dispatcher
|
# Init request dispatcher
|
||||||
self._request_dispatcher = TypeBasedDispatcher(
|
self._request_dispatcher = TypeBasedDispatcher(
|
||||||
[
|
[
|
||||||
@@ -545,22 +517,6 @@ class Scheduler(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init disaggregation
|
|
||||||
self.disaggregation_mode = DisaggregationMode(
|
|
||||||
self.server_args.disaggregation_mode
|
|
||||||
)
|
|
||||||
self.init_disaggregation()
|
|
||||||
|
|
||||||
if get_bool_env_var("SGLANG_GC_LOG"):
|
|
||||||
configure_gc_logger()
|
|
||||||
|
|
||||||
def current_scheduler_metrics_enabled(self):
|
|
||||||
return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
|
|
||||||
|
|
||||||
def maybe_sleep_on_idle(self):
|
|
||||||
if self.idle_sleeper is not None:
|
|
||||||
self.idle_sleeper.maybe_sleep()
|
|
||||||
|
|
||||||
def init_tokenizer(self):
|
def init_tokenizer(self):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
|
|
||||||
@@ -668,50 +624,6 @@ class Scheduler(
|
|||||||
embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
|
embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
|
||||||
init_embedding_cache(embedding_cache_size * 1024 * 1024)
|
init_embedding_cache(embedding_cache_size * 1024 * 1024)
|
||||||
|
|
||||||
def init_profier(self):
|
|
||||||
self.torch_profiler = None
|
|
||||||
self.torch_profiler_output_dir: Optional[str] = None
|
|
||||||
self.profiler_activities: Optional[List[str]] = None
|
|
||||||
self.profile_id: Optional[str] = None
|
|
||||||
self.profiler_start_forward_ct: Optional[int] = None
|
|
||||||
self.profiler_target_forward_ct: Optional[int] = None
|
|
||||||
self.profiler_target_prefill_ct: Optional[int] = None
|
|
||||||
self.profiler_target_decode_ct: Optional[int] = None
|
|
||||||
self.profiler_prefill_ct: Optional[int] = None
|
|
||||||
self.profiler_decode_ct: Optional[int] = None
|
|
||||||
self.profile_by_stage: bool = False
|
|
||||||
self.profile_steps: Optional[int] = None
|
|
||||||
self.profile_in_progress: bool = False
|
|
||||||
self.rpd_profiler = None
|
|
||||||
|
|
||||||
def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
|
|
||||||
self.last_gen_throughput: float = 0.0
|
|
||||||
self.last_input_throughput: float = 0.0
|
|
||||||
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
|
|
||||||
self.spec_num_total_accepted_tokens = 0
|
|
||||||
self.spec_num_total_forward_ct = 0
|
|
||||||
self.cum_spec_accept_length = 0
|
|
||||||
self.cum_spec_accept_count = 0
|
|
||||||
self.total_retracted_reqs = 0
|
|
||||||
self.stats = SchedulerStats()
|
|
||||||
if self.enable_metrics:
|
|
||||||
engine_type = "unified"
|
|
||||||
labels = {
|
|
||||||
"model_name": self.server_args.served_model_name,
|
|
||||||
"engine_type": engine_type,
|
|
||||||
"tp_rank": tp_rank,
|
|
||||||
"pp_rank": pp_rank,
|
|
||||||
}
|
|
||||||
if dp_rank is not None:
|
|
||||||
labels["dp_rank"] = dp_rank
|
|
||||||
self.metrics_collector = SchedulerMetricsCollector(labels=labels)
|
|
||||||
|
|
||||||
def init_kv_events(self, kv_events_config: Optional[str]):
|
|
||||||
if self.enable_kv_cache_events:
|
|
||||||
self.kv_event_publisher = EventPublisherFactory.create(
|
|
||||||
kv_events_config, self.attn_dp_rank
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_disaggregation(self):
|
def init_disaggregation(self):
|
||||||
self.transfer_backend = TransferBackend(
|
self.transfer_backend = TransferBackend(
|
||||||
self.server_args.disaggregation_transfer_backend
|
self.server_args.disaggregation_transfer_backend
|
||||||
@@ -820,10 +732,7 @@ class Scheduler(
|
|||||||
self.process_batch_result(batch, result)
|
self.process_batch_result(batch, result)
|
||||||
else:
|
else:
|
||||||
# When the server is idle, do self-check and re-init some states
|
# When the server is idle, do self-check and re-init some states
|
||||||
self.check_memory()
|
self.self_check_during_idle()
|
||||||
self.check_tree_cache()
|
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
|
||||||
self.maybe_sleep_on_idle()
|
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
@@ -866,10 +775,7 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
elif batch is None:
|
elif batch is None:
|
||||||
# When the server is idle, do self-check and re-init some states
|
# When the server is idle, do self-check and re-init some states
|
||||||
self.check_memory()
|
self.self_check_during_idle()
|
||||||
self.check_tree_cache()
|
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
|
||||||
self.maybe_sleep_on_idle()
|
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
@@ -1003,10 +909,8 @@ class Scheduler(
|
|||||||
|
|
||||||
# When the server is idle, self-check and re-init some states
|
# When the server is idle, self-check and re-init some states
|
||||||
if server_is_idle:
|
if server_is_idle:
|
||||||
self.check_memory()
|
# When the server is idle, do self-check and re-init some states
|
||||||
self.check_tree_cache()
|
self.self_check_during_idle()
|
||||||
self.new_token_ratio = self.init_new_token_ratio
|
|
||||||
self.maybe_sleep_on_idle()
|
|
||||||
|
|
||||||
def recv_requests(self) -> List[Req]:
|
def recv_requests(self) -> List[Req]:
|
||||||
"""Receive requests at tp_rank = 0 and broadcast them to all other TP ranks."""
|
"""Receive requests at tp_rank = 0 and broadcast them to all other TP ranks."""
|
||||||
@@ -1355,170 +1259,11 @@ class Scheduler(
|
|||||||
req.logprob_start_len = len(req.origin_input_ids) - 1
|
req.logprob_start_len = len(req.origin_input_ids) - 1
|
||||||
self._add_request_to_queue(req)
|
self._add_request_to_queue(req)
|
||||||
|
|
||||||
def _emit_kv_metrics(self):
|
def self_check_during_idle(self):
|
||||||
kv_metrics = KvMetrics()
|
self.check_memory()
|
||||||
kv_metrics.request_active_slots = self.stats.num_running_reqs
|
self.check_tree_cache()
|
||||||
kv_metrics.request_total_slots = self.max_running_requests
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
kv_metrics.kv_active_blocks = int(
|
self.maybe_sleep_on_idle()
|
||||||
self.stats.token_usage * self.max_total_num_tokens
|
|
||||||
)
|
|
||||||
kv_metrics.kv_total_blocks = self.max_total_num_tokens
|
|
||||||
kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
|
|
||||||
kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
|
|
||||||
kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
|
|
||||||
kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
|
|
||||||
|
|
||||||
if not self.send_metrics_from_scheduler.closed:
|
|
||||||
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
|
|
||||||
|
|
||||||
def log_prefill_stats(
|
|
||||||
self,
|
|
||||||
adder: PrefillAdder,
|
|
||||||
can_run_list: List[Req],
|
|
||||||
running_bs: int,
|
|
||||||
):
|
|
||||||
gap_latency = time.perf_counter() - self.last_prefill_stats_tic
|
|
||||||
self.last_prefill_stats_tic = time.perf_counter()
|
|
||||||
self.last_input_throughput = self.last_prefill_tokens / gap_latency
|
|
||||||
self.last_prefill_tokens = adder.log_input_tokens
|
|
||||||
|
|
||||||
if self.is_hybrid:
|
|
||||||
(
|
|
||||||
full_num_used,
|
|
||||||
swa_num_used,
|
|
||||||
full_token_usage,
|
|
||||||
swa_token_usage,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
) = self._get_swa_token_info()
|
|
||||||
num_used = max(full_num_used, swa_num_used)
|
|
||||||
token_usage = max(full_token_usage, swa_token_usage)
|
|
||||||
token_msg = (
|
|
||||||
f"full token usage: {full_token_usage:.2f}, "
|
|
||||||
f"swa token usage: {swa_token_usage:.2f}, "
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
num_used, token_usage, _, _ = self._get_token_info()
|
|
||||||
token_msg = f"token usage: {token_usage:.2f}, "
|
|
||||||
|
|
||||||
num_new_seq = len(can_run_list)
|
|
||||||
f = (
|
|
||||||
f"Prefill batch. "
|
|
||||||
f"#new-seq: {num_new_seq}, "
|
|
||||||
f"#new-token: {adder.log_input_tokens}, "
|
|
||||||
f"#cached-token: {adder.log_hit_tokens}, "
|
|
||||||
f"{token_msg}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
|
||||||
f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
|
|
||||||
f += f"#queue-req: {len(self.waiting_queue)}, "
|
|
||||||
f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
|
|
||||||
f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
|
|
||||||
else:
|
|
||||||
f += f"#running-req: {running_bs}, "
|
|
||||||
f += f"#queue-req: {len(self.waiting_queue)}, "
|
|
||||||
|
|
||||||
logger.info(f)
|
|
||||||
|
|
||||||
if self.enable_metrics:
|
|
||||||
total_tokens = adder.log_input_tokens + adder.log_hit_tokens
|
|
||||||
|
|
||||||
cache_hit_rate = (
|
|
||||||
adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
|
|
||||||
)
|
|
||||||
self.stats.num_running_reqs = running_bs
|
|
||||||
self.stats.num_used_tokens = num_used
|
|
||||||
self.stats.token_usage = round(token_usage, 2)
|
|
||||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
|
||||||
self.stats.cache_hit_rate = cache_hit_rate
|
|
||||||
|
|
||||||
total_queue_latency = 0
|
|
||||||
for req in can_run_list:
|
|
||||||
total_queue_latency += req.queue_time_end - req.queue_time_start
|
|
||||||
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
|
|
||||||
|
|
||||||
self.metrics_collector.log_stats(self.stats)
|
|
||||||
self._emit_kv_metrics()
|
|
||||||
self._publish_kv_events()
|
|
||||||
|
|
||||||
def log_decode_stats(
|
|
||||||
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
|
|
||||||
):
|
|
||||||
batch = running_batch or self.running_batch
|
|
||||||
|
|
||||||
gap_latency = time.perf_counter() - self.last_decode_stats_tic
|
|
||||||
self.last_decode_stats_tic = time.perf_counter()
|
|
||||||
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
|
||||||
self.num_generated_tokens = 0
|
|
||||||
num_running_reqs = len(batch.reqs)
|
|
||||||
if self.is_hybrid:
|
|
||||||
(
|
|
||||||
full_num_used,
|
|
||||||
swa_num_used,
|
|
||||||
full_token_usage,
|
|
||||||
swa_token_usage,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
) = self._get_swa_token_info()
|
|
||||||
num_used = max(full_num_used, swa_num_used)
|
|
||||||
token_usage = max(full_token_usage, swa_token_usage)
|
|
||||||
token_msg = (
|
|
||||||
f"#full token: {full_num_used}, "
|
|
||||||
f"full token usage: {full_token_usage:.2f}, "
|
|
||||||
f"#swa token: {swa_num_used}, "
|
|
||||||
f"swa token usage: {swa_token_usage:.2f}, "
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
num_used, token_usage, _, _ = self._get_token_info()
|
|
||||||
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
|
|
||||||
|
|
||||||
if RECORD_STEP_TIME:
|
|
||||||
self.step_time_dict[num_running_reqs].append(
|
|
||||||
gap_latency / self.server_args.decode_log_interval
|
|
||||||
)
|
|
||||||
|
|
||||||
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
|
|
||||||
|
|
||||||
if self.spec_algorithm.is_none():
|
|
||||||
spec_accept_length = 0
|
|
||||||
else:
|
|
||||||
spec_accept_length = (
|
|
||||||
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
|
|
||||||
)
|
|
||||||
self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
|
|
||||||
self.cum_spec_accept_count += self.spec_num_total_forward_ct
|
|
||||||
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
|
|
||||||
msg += f"accept len: {spec_accept_length:.2f}, "
|
|
||||||
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
|
||||||
msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
|
|
||||||
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
|
|
||||||
|
|
||||||
msg += (
|
|
||||||
f"cuda graph: {can_run_cuda_graph}, "
|
|
||||||
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
|
|
||||||
f"#queue-req: {len(self.waiting_queue)}, "
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(msg)
|
|
||||||
if self.enable_metrics:
|
|
||||||
self.stats.num_running_reqs = num_running_reqs
|
|
||||||
self.stats.num_used_tokens = num_used
|
|
||||||
self.stats.token_usage = round(token_usage, 2)
|
|
||||||
self.stats.cache_hit_rate = 0.0
|
|
||||||
self.stats.gen_throughput = self.last_gen_throughput
|
|
||||||
self.stats.num_queue_reqs = len(self.waiting_queue)
|
|
||||||
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
|
|
||||||
self.stats.spec_accept_length = spec_accept_length
|
|
||||||
self.stats.total_retracted_reqs = self.total_retracted_reqs
|
|
||||||
self.metrics_collector.log_stats(self.stats)
|
|
||||||
self._emit_kv_metrics()
|
|
||||||
self._publish_kv_events()
|
|
||||||
|
|
||||||
def check_memory(self):
|
def check_memory(self):
|
||||||
if self.is_hybrid:
|
if self.is_hybrid:
|
||||||
@@ -2422,22 +2167,6 @@ class Scheduler(
|
|||||||
barrier()
|
barrier()
|
||||||
return RpcReqOutput(success, "" if not exec else str(exec))
|
return RpcReqOutput(success, "" if not exec else str(exec))
|
||||||
|
|
||||||
def save_remote_model(self, params):
    """Ask the TP worker's model runner to save the model to a remote URL.

    Args:
        params: Mapping that must contain the key ``"url"``; its value is
            forwarded unchanged to ``model_runner.save_remote_model``.
    """
    target_url = params["url"]
    model_runner = self.tp_worker.worker.model_runner
    model_runner.save_remote_model(target_url)
|
|
||||||
|
|
||||||
def save_sharded_model(self, params):
    """Ask the TP worker's model runner to save the model as shards.

    Args:
        params: Mapping that must contain ``"path"``, ``"pattern"`` and
            ``"max_size"``; each is forwarded as the keyword argument of
            the same name. Other keys are ignored.
    """
    save_shards = self.tp_worker.worker.model_runner.save_sharded_model
    shard_options = {key: params[key] for key in ("path", "pattern", "max_size")}
    save_shards(**shard_options)
|
|
||||||
|
|
||||||
def abort_request(self, recv_req: AbortReq):
|
def abort_request(self, recv_req: AbortReq):
|
||||||
# Delete requests in the waiting queue
|
# Delete requests in the waiting queue
|
||||||
to_del = []
|
to_del = []
|
||||||
@@ -2515,16 +2244,6 @@ class Scheduler(
|
|||||||
def _pause_engine(self) -> Tuple[List[Req], int]:
    """Placeholder for pausing the engine; not implemented here.

    Raises:
        NotImplementedError: Unconditionally.
    """
    raise NotImplementedError
|
||||||
|
|
||||||
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
    """In-place update of the weights from disk.

    Delegates the request to the TP worker. On success the cache is
    flushed (and the flush is asserted to have succeeded); on failure the
    worker's message is logged as an error. The outcome is always returned
    to the caller.
    """
    ok, detail = self.tp_worker.update_weights_from_disk(recv_req)
    if not ok:
        logger.error(detail)
    else:
        # Keep the call outside the assert so it still runs under ``-O``.
        cache_flushed = self.flush_cache()
        assert cache_flushed, "Cache flush failed after updating weights"
    return UpdateWeightFromDiskReqOutput(ok, detail, 0)
|
|
||||||
|
|
||||||
def load_lora_adapter(
|
def load_lora_adapter(
|
||||||
self, recv_req: LoadLoRAAdapterReqInput
|
self, recv_req: LoadLoRAAdapterReqInput
|
||||||
) -> LoadLoRAAdapterReqOutput:
|
) -> LoadLoRAAdapterReqOutput:
|
||||||
@@ -2541,81 +2260,6 @@ class Scheduler(
|
|||||||
result = self.tp_worker.unload_lora_adapter(recv_req)
|
result = self.tp_worker.unload_lora_adapter(recv_req)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
    """Initialize the online model parameter update group.

    The request is handed to the TP worker and its ``(success, message)``
    pair is wrapped in an ``InitWeightsUpdateGroupReqOutput``.
    """
    ok, detail = self.tp_worker.init_weights_update_group(recv_req)
    return InitWeightsUpdateGroupReqOutput(ok, detail)
|
|
||||||
|
|
||||||
def update_weights_from_distributed(
    self,
    recv_req: UpdateWeightsFromDistributedReqInput,
) -> UpdateWeightsFromDistributedReqOutput:
    """Update the online model parameter.

    Delegates to the TP worker. On success the cache is flushed when
    ``recv_req.flush_cache`` is truthy; on failure the worker's message is
    logged as an error.

    Returns:
        An ``UpdateWeightsFromDistributedReqOutput`` built from the
        worker's ``(success, message)`` pair. (The previous annotation,
        ``Tuple[bool, str]``, did not match this return value.)
    """
    success, message = self.tp_worker.update_weights_from_distributed(recv_req)
    if success:
        if recv_req.flush_cache:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
    else:
        logger.error(message)
    return UpdateWeightsFromDistributedReqOutput(success, message)
|
|
||||||
|
|
||||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Update the online model parameter from tensors.

    Delegates to the TP worker, optionally flushes the cache on success
    (when ``recv_req.flush_cache`` is truthy), logs the message on failure,
    and then waits on a barrier over ``self.tp_cpu_group`` before returning
    the outcome.
    """
    ok, detail = self.tp_worker.update_weights_from_tensor(recv_req)
    # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
    if not ok:
        logger.error(detail)
    elif recv_req.flush_cache:
        # Keep the call outside the assert so it still runs under ``-O``.
        cache_flushed = self.flush_cache()
        assert cache_flushed, "Cache flush failed after updating weights"
    barrier(group=self.tp_cpu_group)
    return UpdateWeightsFromTensorReqOutput(ok, detail)
|
|
||||||
|
|
||||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Fetch a parameter through the TP worker and wrap it in a reply object."""
    fetched = self.tp_worker.get_weights_by_name(recv_req)
    reply = GetWeightsByNameReqOutput(fetched)
    return reply
|
|
||||||
|
|
||||||
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
    """Pause the GPU memory regions named by ``recv_req.tags``.

    When no tags are given (``None`` or empty), both the weights and the
    KV cache are released. The KV cache is paused first and the cache is
    flushed; for weights, the model's buffers are stashed on ``self``
    before a barrier over ``self.tp_cpu_group`` and the pause itself.
    """
    requested = recv_req.tags or [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

    if GPU_MEMORY_TYPE_KV_CACHE in requested:
        self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
        self.flush_cache()

    if GPU_MEMORY_TYPE_WEIGHTS in requested:
        model = self.tp_worker.worker.model_runner.model
        # Stash a copy of the buffers so resume can write them back.
        self.stashed_model_static_state = _export_static_state(model)
        torch.distributed.barrier(self.tp_cpu_group)
        self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

    return ReleaseMemoryOccupationReqOutput()
|
|
||||||
|
|
||||||
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
    """Resume the GPU memory regions named by ``recv_req.tags``.

    When no tags are given (``None`` or empty), both the weights and the
    KV cache are resumed. For weights, the region is resumed, a barrier
    over ``self.tp_cpu_group`` is awaited, and the buffers stashed at
    release time are written back and then dropped from ``self``.
    """
    requested = recv_req.tags or [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

    if GPU_MEMORY_TYPE_WEIGHTS in requested:
        self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
        torch.distributed.barrier(self.tp_cpu_group)
        model = self.tp_worker.worker.model_runner.model
        _import_static_state(model, self.stashed_model_static_state)
        del self.stashed_model_static_state

    if GPU_MEMORY_TYPE_KV_CACHE in requested:
        self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)

    return ResumeMemoryOccupationReqOutput()
|
|
||||||
|
|
||||||
def slow_down(self, recv_req: SlowDownReqInput):
|
def slow_down(self, recv_req: SlowDownReqInput):
|
||||||
t = recv_req.forward_sleep_time
|
t = recv_req.forward_sleep_time
|
||||||
if t is not None and t <= 0:
|
if t is not None and t <= 0:
|
||||||
@@ -2623,254 +2267,6 @@ class Scheduler(
|
|||||||
self.forward_sleep_time = t
|
self.forward_sleep_time = t
|
||||||
return SlowDownReqOutput()
|
return SlowDownReqOutput()
|
||||||
|
|
||||||
def profile(self, recv_req: ProfileReq):
    """Handle a profiling request.

    For ``START_PROFILE`` the profiler is configured through
    ``init_profile``. If the request profiles by stage or names a start
    step, the actual start is deferred and the ``init_profile`` result is
    returned as is. Otherwise profiling starts immediately. Any other
    request type stops profiling.

    Returns:
        The ``ProfileReqOutput`` produced by ``init_profile``,
        ``start_profile`` or ``stop_profile``.
    """
    if recv_req.type != ProfileReqType.START_PROFILE:
        return self.stop_profile()

    init_result = self.init_profile(
        recv_req.output_dir,
        recv_req.start_step,
        recv_req.num_steps,
        recv_req.activities,
        recv_req.with_stack,
        recv_req.record_shapes,
        recv_req.profile_by_stage,
        recv_req.profile_id,
    )
    if recv_req.profile_by_stage or recv_req.start_step:
        # The start is triggered later from the forward loop.
        return init_result
    if not init_result.success:
        # init_profile refused (profiling already running): report that
        # instead of starting a second profiler on top of the active one.
        return init_result
    # No stage applies to an immediate start; passing one would only add a
    # misleading " for <stage>" to the log line.
    return self.start_profile()
|
|
||||||
|
|
||||||
def init_profile(
    self,
    output_dir: Optional[str],
    start_step: Optional[int],
    num_steps: Optional[int],
    activities: Optional[List[str]],
    with_stack: Optional[bool],
    record_shapes: Optional[bool],
    profile_by_stage: bool,
    profile_id: str,
) -> ProfileReqOutput:
    """Record the profiling configuration on ``self`` without starting it.

    Args:
        output_dir: Trace directory; when ``None`` the environment variable
            ``SGLANG_TORCH_PROFILER_DIR`` is used, falling back to ``/tmp``.
        start_step: If truthy, the forward count at which to start; it is
            clamped to be at least ``self.forward_ct + 1``.
        num_steps: If truthy, how many steps to profile; otherwise the
            forward-count target is cleared.
        activities: Activity names; defaults to ``["CPU", "GPU"]``.
        with_stack: Stored for the torch profiler.
        record_shapes: Stored for the torch profiler.
        profile_by_stage: Whether step targets apply per prefill/decode
            stage rather than to the overall forward count.
        profile_id: Identifier stored for naming the traces.

    Returns:
        A failed ``ProfileReqOutput`` if profiling is already in progress
        (no state is changed), otherwise a successful one.
    """
    if self.profile_in_progress:
        return ProfileReqOutput(
            success=False,
            message="Profiling is already in progress. Call /stop_profile first.",
        )

    self.profile_by_stage = profile_by_stage
    self.torch_profiler_output_dir = (
        os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if output_dir is None
        else output_dir
    )
    self.torch_profiler_with_stack = with_stack
    self.torch_profiler_record_shapes = record_shapes
    self.profiler_activities = ["CPU", "GPU"] if activities is None else activities
    self.profile_id = profile_id

    if start_step:
        self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)

    if not num_steps:
        self.profiler_target_forward_ct = None
        return ProfileReqOutput(success=True, message="Succeeded")

    self.profile_steps = num_steps
    if self.profile_by_stage:
        self.profiler_target_prefill_ct = self.profiler_target_decode_ct = num_steps
        self.profiler_prefill_ct = self.profiler_decode_ct = 0
    else:
        # Count from the scheduled start when there is one, else from now.
        first_step = self.profiler_start_forward_ct if start_step else self.forward_ct
        self.profiler_target_forward_ct = first_step + num_steps
        # The caller will be notified when reaching profiler_target_forward_ct

    return ProfileReqOutput(success=True, message="Succeeded")
|
|
||||||
|
|
||||||
def start_profile(
    self, stage: Optional[ForwardMode] = None
) -> Optional[ProfileReqOutput]:
    """Start the profilers selected by the stored configuration.

    Reads the activities, stack and shape options previously stored on
    ``self``. ``"RPD"`` takes precedence over the torch profiler
    (``"CPU"`` / ``"GPU"``); ``"MEM"`` and ``"CUDA_PROFILER"`` are handled
    independently of that choice.

    Args:
        stage: Optional forward mode, used only to annotate the log line.

    Returns:
        A successful ``ProfileReqOutput``.
    """
    stage_str = f" for {stage.__str__()}" if stage else ""
    logger.info(
        f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
    )

    activities = self.profiler_activities
    with_stack = self.torch_profiler_with_stack
    record_shapes = self.torch_profiler_record_shapes

    activity_map = {
        "CPU": torch.profiler.ProfilerActivity.CPU,
        "GPU": torch.profiler.ProfilerActivity.CUDA,
    }
    torchprof_activities = [
        activity_map[a] for a in activities if a in activity_map
    ]

    if "RPD" in activities:
        from rpdTracerControl import rpdTracerControl

        rpdTracerControl.skipCreate()

        self.rpd_profile_path = os.path.join(
            self.torch_profiler_output_dir,
            "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
        )

        if self.tp_rank == 0:
            import sqlite3

            from rocpd.schema import RocpdSchema

            # Rank 0 recreates the trace database from scratch.
            if os.path.exists("trace.rpd"):
                os.unlink("trace.rpd")
            schema = RocpdSchema()
            connection = sqlite3.connect("trace.rpd")
            try:
                schema.writeSchema(connection)
                connection.commit()
            finally:
                # Close explicitly rather than relying on garbage collection.
                connection.close()
        # Every rank waits here so the database exists before tracing starts.
        torch.distributed.barrier(self.tp_cpu_group)

        self.rpd_profiler = rpdTracerControl()
        self.rpd_profiler.setPythonTrace(True)
        self.rpd_profiler.start()
        self.rpd_profiler.rangePush("", "rpd profile range", "")
        self.profile_in_progress = True
    elif torchprof_activities:
        self.torch_profiler = torch.profiler.profile(
            activities=torchprof_activities,
            with_stack=with_stack if with_stack is not None else True,
            record_shapes=record_shapes if record_shapes is not None else False,
        )
        self.torch_profiler.start()
        self.profile_in_progress = True

    if "MEM" in activities:
        torch.cuda.memory._record_memory_history(max_entries=100000)
        self.profile_in_progress = True

    if "CUDA_PROFILER" in activities:
        torch.cuda.cudart().cudaProfilerStart()
        self.profile_in_progress = True

    return ProfileReqOutput(success=True, message="Succeeded")
|
|
||||||
|
|
||||||
def stop_profile(
    self, stage: Optional[ForwardMode] = None
) -> Optional[ProfileReqOutput]:
    """Stop whichever profilers are running and write their traces.

    Args:
        stage: Optional forward mode; when given it is appended to the
            trace file names and to the log line.

    Returns:
        A failed ``ProfileReqOutput`` if no profiling is in progress,
        otherwise a successful one after the traces have been exported and
        the profiler state on ``self`` has been reset.
    """
    if not self.profile_in_progress:
        return ProfileReqOutput(
            success=False,
            message="Profiling is not in progress. Call /start_profile first.",
        )

    # exist_ok makes a separate existence check unnecessary.
    Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)

    stage_suffix = f"-{stage.__str__()}" if stage else ""
    logger.info("Stop profiling" + stage_suffix + "...")
    if self.torch_profiler is not None:
        self.torch_profiler.stop()
        self.torch_profiler.export_chrome_trace(
            os.path.join(
                self.torch_profiler_output_dir,
                self.profile_id
                + f"-TP-{self.tp_rank}"
                + stage_suffix
                + ".trace.json.gz",
            )
        )
        torch.distributed.barrier(self.tp_cpu_group)

    if self.rpd_profiler is not None:
        self.rpd_profiler.rangePop()
        self.rpd_profiler.stop()
        self.rpd_profiler.flush()

        # All ranks flush before rank 0 converts the shared trace database.
        torch.distributed.barrier(self.tp_cpu_group)
        if self.tp_rank == 0:
            from sglang.srt.utils import rpd_to_chrome_trace

            rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
        self.rpd_profiler = None
        # Reset the attribute start_profile actually sets (the original
        # cleared a misspelled ``rpd_profiler_path``).
        self.rpd_profile_path = None

    # Treat a missing activity list as empty for both checks below.
    activities = self.profiler_activities or []

    if "MEM" in activities:
        memory_profile_path = os.path.join(
            self.torch_profiler_output_dir,
            str(time.time())
            + f"-TP-{self.tp_rank}-memory"
            + stage_suffix
            + ".pickle",
        )
        torch.cuda.memory._dump_snapshot(memory_profile_path)
        torch.cuda.memory._record_memory_history(enabled=None)

    if "CUDA_PROFILER" in activities:
        torch.cuda.cudart().cudaProfilerStop()

    logger.info(
        "Profiling done. Traces are saved to: %s",
        self.torch_profiler_output_dir,
    )
    self.torch_profiler = None
    self.profile_in_progress = False
    self.profiler_start_forward_ct = None

    return ProfileReqOutput(success=True, message="Succeeded.")
|
|
||||||
|
|
||||||
def _profile_batch_predicate(self, batch):
|
|
||||||
if self.profile_by_stage:
|
|
||||||
if batch.forward_mode.is_prefill():
|
|
||||||
if self.profiler_prefill_ct == 0:
|
|
||||||
self.start_profile(batch.forward_mode)
|
|
||||||
self.profiler_prefill_ct += 1
|
|
||||||
if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
|
|
||||||
if self.profile_in_progress:
|
|
||||||
self.stop_profile(stage=ForwardMode.EXTEND)
|
|
||||||
elif batch.forward_mode.is_decode():
|
|
||||||
if self.profiler_decode_ct == 0:
|
|
||||||
if self.profile_in_progress:
|
|
||||||
# force trace flush
|
|
||||||
self.stop_profile(ForwardMode.EXTEND)
|
|
||||||
self.start_profile(batch.forward_mode)
|
|
||||||
self.profiler_decode_ct += 1
|
|
||||||
if self.profiler_decode_ct > self.profiler_target_decode_ct:
|
|
||||||
if self.profile_in_progress:
|
|
||||||
self.stop_profile(stage=ForwardMode.DECODE)
|
|
||||||
elif batch.forward_mode.is_idle():
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
|
|
||||||
else:
|
|
||||||
# Check profiler
|
|
||||||
if (
|
|
||||||
self.profiler_target_forward_ct
|
|
||||||
and self.profiler_target_forward_ct <= self.forward_ct
|
|
||||||
):
|
|
||||||
self.stop_profile()
|
|
||||||
if (
|
|
||||||
self.profiler_start_forward_ct
|
|
||||||
and self.profiler_start_forward_ct == self.forward_ct
|
|
||||||
):
|
|
||||||
self.start_profile()
|
|
||||||
|
|
||||||
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
|
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
|
||||||
if recv_req == ExpertDistributionReq.START_RECORD:
|
if recv_req == ExpertDistributionReq.START_RECORD:
|
||||||
get_global_expert_distribution_recorder().start_record()
|
get_global_expert_distribution_recorder().start_record()
|
||||||
@@ -2879,7 +2275,7 @@ class Scheduler(
|
|||||||
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
|
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
|
||||||
get_global_expert_distribution_recorder().dump_record()
|
get_global_expert_distribution_recorder().dump_record()
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unrecognized ExpertDistributionReq value")
|
raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
|
||||||
return ExpertDistributionReqOutput()
|
return ExpertDistributionReqOutput()
|
||||||
|
|
||||||
def open_session(self, recv_req: OpenSessionReqInput):
|
def open_session(self, recv_req: OpenSessionReqInput):
|
||||||
@@ -2915,12 +2311,33 @@ class Scheduler(
|
|||||||
prefix += f" PP{self.pp_rank}"
|
prefix += f" PP{self.pp_rank}"
|
||||||
return prefix
|
return prefix
|
||||||
|
|
||||||
def _publish_kv_events(self):
|
def current_scheduler_metrics_enabled(self):
|
||||||
if self.enable_kv_cache_events:
|
return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
|
||||||
events = self.tree_cache.take_events()
|
|
||||||
if events:
|
def maybe_sleep_on_idle(self):
|
||||||
batch = KVEventBatch(ts=time.time(), events=events)
|
if self.idle_sleeper is not None:
|
||||||
self.kv_event_publisher.publish(batch)
|
self.idle_sleeper.maybe_sleep()
|
||||||
|
|
||||||
|
|
||||||
|
class IdleSleeper:
    """
    Lets an idle scheduler block instead of spinning.

    In setups with long periods of inactivity it is desirable to lower
    system power consumption while sglang has nothing to do. Besides saving
    power, this leaves more CPU thermal headroom for when a request does
    arrive, which matters when several GPUs are attached, since each GPU
    would otherwise pin one thread at 100% CPU usage.

    The simplest approach is a zmq.Poller over every socket that may
    receive data needing immediate handling.
    """

    def __init__(self, sockets):
        """Register every socket in ``sockets`` for incoming-data polling."""
        poller = zmq.Poller()
        for sock in sockets:
            poller.register(sock, zmq.POLLIN)
        self.poller = poller

    def maybe_sleep(self):
        """Poll the registered sockets with a timeout argument of 1000."""
        self.poller.poll(1000)
|
||||||
|
|
||||||
|
|
||||||
def is_health_check_generate_req(recv_req):
|
def is_health_check_generate_req(recv_req):
|
||||||
@@ -2931,20 +2348,6 @@ def is_work_request(recv_req):
|
|||||||
return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
|
return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
|
||||||
|
|
||||||
|
|
||||||
def _export_static_state(model):
|
|
||||||
return dict(
|
|
||||||
buffers=[
|
|
||||||
(name, buffer.detach().clone()) for name, buffer in model.named_buffers()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _import_static_state(model, static_params):
|
|
||||||
self_named_buffers = dict(model.named_buffers())
|
|
||||||
for name, tensor in static_params["buffers"]:
|
|
||||||
self_named_buffers[name][...] = tensor
|
|
||||||
|
|
||||||
|
|
||||||
def run_scheduler_process(
|
def run_scheduler_process(
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
|
|||||||
229
python/sglang/srt/managers/scheduler_metrics_mixin.py
Normal file
229
python/sglang/srt/managers/scheduler_metrics_mixin.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
|
||||||
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
|
from sglang.srt.managers.schedule_policy import PrefillAdder
|
||||||
|
from sglang.srt.managers.scheduler import Req, ScheduleBatch
|
||||||
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||||
|
from sglang.srt.utils import get_bool_env_var
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
|
||||||
|
|
||||||
|
|
||||||
|
class KvMetrics:
    """Plain container for KV-cache metrics; every field starts as ``None``."""

    def __init__(self):
        # Fields are created in this fixed order and filled in by the caller.
        for field_name in (
            "request_active_slots",
            "request_total_slots",
            "kv_active_blocks",
            "kv_total_blocks",
            "num_requests_waiting",
            "gpu_cache_usage_perc",
            "gpu_prefix_cache_hit_rate",
            "data_parallel_rank",
        ):
            setattr(self, field_name, None)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerMetricsMixin:
    """Metrics, logging and KV-event helpers mixed into the scheduler.

    The methods read and write attributes on ``self`` that are expected to
    be provided by the class this mixin is combined with (for example
    ``enable_metrics``, ``waiting_queue``, ``running_batch``).
    """

    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
        """Reset the metric counters and, if enabled, build the collector.

        Args:
            tp_rank: Value for the ``tp_rank`` collector label.
            pp_rank: Value for the ``pp_rank`` collector label.
            dp_rank: Value for the ``dp_rank`` label; omitted when ``None``.
        """
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.total_retracted_reqs = 0
        self.stats = SchedulerStats()
        if self.enable_metrics:
            engine_type = "unified"
            labels = {
                "model_name": self.server_args.served_model_name,
                "engine_type": engine_type,
                "tp_rank": tp_rank,
                "pp_rank": pp_rank,
            }
            if dp_rank is not None:
                labels["dp_rank"] = dp_rank
            self.metrics_collector = SchedulerMetricsCollector(labels=labels)

    def init_kv_events(self, kv_events_config: Optional[str]):
        """Create the KV event publisher when KV cache events are enabled."""
        if self.enable_kv_cache_events:
            self.kv_event_publisher = EventPublisherFactory.create(
                kv_events_config, self.attn_dp_rank
            )

    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log one prefill batch and update the prefill-side statistics.

        Args:
            adder: Source of the input-token and cache-hit-token counts.
            can_run_list: Requests admitted into this prefill batch.
            running_bs: Number of currently running requests.
        """
        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.perf_counter()
        self.last_input_throughput = self.last_prefill_tokens / gap_latency
        self.last_prefill_tokens = adder.log_input_tokens

        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                full_token_usage,
                swa_token_usage,
                _,
                _,
                _,
                _,
            ) = self._get_swa_token_info()
            num_used = max(full_num_used, swa_num_used)
            token_usage = max(full_token_usage, swa_token_usage)
            token_msg = (
                f"full token usage: {full_token_usage:.2f}, "
                f"swa token usage: {swa_token_usage:.2f}, "
            )
        else:
            num_used, token_usage, _, _ = self._get_token_info()
            token_msg = f"token usage: {token_usage:.2f}, "

        num_new_seq = len(can_run_list)
        msg = (
            f"Prefill batch. "
            f"#new-seq: {num_new_seq}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"{token_msg}"
        )

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            msg += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
            msg += f"#queue-req: {len(self.waiting_queue)}, "
            msg += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
            msg += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
        else:
            msg += f"#running-req: {running_bs}, "
            msg += f"#queue-req: {len(self.waiting_queue)}, "

        logger.info(msg)

        if self.enable_metrics:
            total_tokens = adder.log_input_tokens + adder.log_hit_tokens

            cache_hit_rate = (
                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate

            total_queue_latency = sum(
                req.queue_time_end - req.queue_time_start for req in can_run_list
            )
            # An empty batch has no queue latency to average.
            self.stats.avg_request_queue_latency = (
                total_queue_latency / num_new_seq if num_new_seq else 0.0
            )

            self.metrics_collector.log_stats(self.stats)
            self._emit_kv_metrics()
        self._publish_kv_events()

    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: Optional[ScheduleBatch] = None
    ):
        """Log one decode interval and update the decode-side statistics.

        Args:
            can_run_cuda_graph: Reported verbatim in the log line.
            running_batch: Batch to report on; falls back to
                ``self.running_batch`` when falsy.
        """
        batch = running_batch or self.running_batch

        gap_latency = time.perf_counter() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.perf_counter()
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                full_token_usage,
                swa_token_usage,
                _,
                _,
                _,
                _,
            ) = self._get_swa_token_info()
            num_used = max(full_num_used, swa_num_used)
            token_usage = max(full_token_usage, swa_token_usage)
            token_msg = (
                f"#full token: {full_num_used}, "
                f"full token usage: {full_token_usage:.2f}, "
                f"#swa token: {swa_num_used}, "
                f"swa token usage: {swa_token_usage:.2f}, "
            )
        else:
            num_used, token_usage, _, _ = self._get_token_info()
            token_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            forward_ct = self.spec_num_total_forward_ct
            # No speculative forward since the last log: report 0 rather
            # than dividing by zero.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / forward_ct if forward_ct else 0
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.stats.total_retracted_reqs = self.total_retracted_reqs
            self.metrics_collector.log_stats(self.stats)
            self._emit_kv_metrics()
        self._publish_kv_events()

    def _emit_kv_metrics(self):
        """Build a ``KvMetrics`` from ``self.stats`` and send it if possible."""
        kv_metrics = KvMetrics()
        kv_metrics.request_active_slots = self.stats.num_running_reqs
        kv_metrics.request_total_slots = self.max_running_requests
        kv_metrics.kv_active_blocks = int(
            self.stats.token_usage * self.max_total_num_tokens
        )
        kv_metrics.kv_total_blocks = self.max_total_num_tokens
        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0

        # Skip the send once the socket has been closed.
        if not self.send_metrics_from_scheduler.closed:
            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)

    def _publish_kv_events(self):
        """Publish pending tree-cache events when KV cache events are enabled."""
        if self.enable_kv_cache_events:
            events = self.tree_cache.take_events()
            if events:
                batch = KVEventBatch(ts=time.time(), events=events)
                self.kv_event_publisher.publish(batch)
|
||||||
279
python/sglang/srt/managers/scheduler_profiler_mixin.py
Normal file
279
python/sglang/srt/managers/scheduler_profiler_mixin.py
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerProfilerMixin:
|
||||||
|
|
||||||
|
def init_profier(self):
|
||||||
|
self.torch_profiler = None
|
||||||
|
self.torch_profiler_output_dir: Optional[str] = None
|
||||||
|
self.profiler_activities: Optional[List[str]] = None
|
||||||
|
self.profile_id: Optional[str] = None
|
||||||
|
self.profiler_start_forward_ct: Optional[int] = None
|
||||||
|
self.profiler_target_forward_ct: Optional[int] = None
|
||||||
|
self.profiler_target_prefill_ct: Optional[int] = None
|
||||||
|
self.profiler_target_decode_ct: Optional[int] = None
|
||||||
|
self.profiler_prefill_ct: Optional[int] = None
|
||||||
|
self.profiler_decode_ct: Optional[int] = None
|
||||||
|
self.profile_by_stage: bool = False
|
||||||
|
self.profile_steps: Optional[int] = None
|
||||||
|
self.profile_in_progress: bool = False
|
||||||
|
self.rpd_profiler = None
|
||||||
|
|
||||||
|
def init_profile(
|
||||||
|
self,
|
||||||
|
output_dir: Optional[str],
|
||||||
|
start_step: Optional[int],
|
||||||
|
num_steps: Optional[int],
|
||||||
|
activities: Optional[List[str]],
|
||||||
|
with_stack: Optional[bool],
|
||||||
|
record_shapes: Optional[bool],
|
||||||
|
profile_by_stage: bool,
|
||||||
|
profile_id: str,
|
||||||
|
) -> ProfileReqOutput:
|
||||||
|
if self.profile_in_progress:
|
||||||
|
return ProfileReqOutput(
|
||||||
|
success=False,
|
||||||
|
message="Profiling is already in progress. Call /stop_profile first.",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.profile_by_stage = profile_by_stage
|
||||||
|
|
||||||
|
if output_dir is None:
|
||||||
|
output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
|
||||||
|
if activities is None:
|
||||||
|
activities = ["CPU", "GPU"]
|
||||||
|
|
||||||
|
self.torch_profiler_output_dir = output_dir
|
||||||
|
self.torch_profiler_with_stack = with_stack
|
||||||
|
self.torch_profiler_record_shapes = record_shapes
|
||||||
|
self.profiler_activities = activities
|
||||||
|
self.profile_id = profile_id
|
||||||
|
|
||||||
|
if start_step:
|
||||||
|
self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)
|
||||||
|
|
||||||
|
if num_steps:
|
||||||
|
self.profile_steps = num_steps
|
||||||
|
if self.profile_by_stage:
|
||||||
|
self.profiler_target_prefill_ct = num_steps
|
||||||
|
self.profiler_target_decode_ct = num_steps
|
||||||
|
self.profiler_prefill_ct = 0
|
||||||
|
self.profiler_decode_ct = 0
|
||||||
|
elif start_step:
|
||||||
|
self.profiler_target_forward_ct = (
|
||||||
|
self.profiler_start_forward_ct + num_steps
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.profiler_target_forward_ct = self.forward_ct + num_steps
|
||||||
|
# The caller will be notified when reaching profiler_target_forward_ct
|
||||||
|
else:
|
||||||
|
self.profiler_target_forward_ct = None
|
||||||
|
|
||||||
|
return ProfileReqOutput(success=True, message="Succeeded")
|
||||||
|
|
||||||
|
def start_profile(
    self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
    """Start the profiler backends selected by ``self.profiler_activities``.

    Reads the settings stored on ``self`` (output directory, profile id,
    activities, with_stack, record_shapes) and starts, in this order:
    the RPD tracer if "RPD" is requested, otherwise the torch profiler if
    "CPU" and/or "GPU" are requested; then CUDA memory-history recording
    for "MEM"; then the CUDA profiler for "CUDA_PROFILER".

    Args:
        stage: Optional forward mode, used only in the start log message.

    Returns:
        A successful ProfileReqOutput. ``self.profile_in_progress`` is set
        to True only when at least one backend was actually started.
    """
    stage_str = f" for {stage.__str__()}" if stage else ""
    logger.info(
        f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
    )

    activities = self.profiler_activities
    with_stack = self.torch_profiler_with_stack
    record_shapes = self.torch_profiler_record_shapes

    # Only "CPU" and "GPU" map to torch profiler activities; other names
    # ("RPD", "MEM", "CUDA_PROFILER") are handled separately below.
    activity_map = {
        "CPU": torch.profiler.ProfilerActivity.CPU,
        "GPU": torch.profiler.ProfilerActivity.CUDA,
    }
    torchprof_activities = [
        activity_map[a] for a in activities if a in activity_map
    ]

    if "RPD" in activities:
        # RPD takes precedence: the torch profiler branch below is an elif.
        from rpdTracerControl import rpdTracerControl

        rpdTracerControl.skipCreate()

        self.rpd_profile_path = os.path.join(
            self.torch_profiler_output_dir,
            "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
        )

        if self.tp_rank == 0:
            import sqlite3

            from rocpd.schema import RocpdSchema

            # Rank 0 recreates "trace.rpd" in the current working directory
            # and writes the schema into it.
            if os.path.exists("trace.rpd"):
                os.unlink("trace.rpd")
            schema = RocpdSchema()
            connection = sqlite3.connect("trace.rpd")
            schema.writeSchema(connection)
            connection.commit()
            del connection
        # All ranks wait here, so rank 0 has finished the setup above
        # before any rank creates its tracer.
        torch.distributed.barrier(self.tp_cpu_group)

        self.rpd_profiler = rpdTracerControl()
        self.rpd_profiler.setPythonTrace(True)
        self.rpd_profiler.start()
        self.rpd_profiler.rangePush("", "rpd profile range", "")
        self.profile_in_progress = True
    elif torchprof_activities:
        # Unset options fall back to with_stack=True and record_shapes=False.
        self.torch_profiler = torch.profiler.profile(
            activities=torchprof_activities,
            with_stack=with_stack if with_stack is not None else True,
            record_shapes=record_shapes if record_shapes is not None else False,
        )
        self.torch_profiler.start()
        self.profile_in_progress = True

    if "MEM" in activities:
        # Start recording CUDA memory history; stop_profile dumps the snapshot.
        torch.cuda.memory._record_memory_history(max_entries=100000)
        self.profile_in_progress = True

    if "CUDA_PROFILER" in activities:
        torch.cuda.cudart().cudaProfilerStart()
        self.profile_in_progress = True

    return ProfileReqOutput(success=True, message="Succeeded")
|
||||||
|
|
||||||
|
def stop_profile(
    self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
    """Stop every active profiler backend and write its output.

    Args:
        stage: Optional forward mode. When given, its string form is
            appended to the trace / memory-snapshot file names and to the
            log message.

    Returns:
        A failed ProfileReqOutput if no profiling is in progress,
        otherwise a successful one after all traces have been written.
    """
    if not self.profile_in_progress:
        return ProfileReqOutput(
            success=False,
            message="Profiling is not in progress. Call /start_profile first.",
        )

    if not Path(self.torch_profiler_output_dir).exists():
        Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)

    stage_suffix = f"-{stage.__str__()}" if stage else ""
    logger.info("Stop profiling" + stage_suffix + "...")
    if self.torch_profiler is not None:
        self.torch_profiler.stop()
        self.torch_profiler.export_chrome_trace(
            os.path.join(
                self.torch_profiler_output_dir,
                self.profile_id
                + f"-TP-{self.tp_rank}"
                + stage_suffix
                + ".trace.json.gz",
            )
        )
        torch.distributed.barrier(self.tp_cpu_group)

    if self.rpd_profiler is not None:
        self.rpd_profiler.rangePop()
        self.rpd_profiler.stop()
        self.rpd_profiler.flush()

        # Every rank must have flushed before rank 0 converts the shared trace.
        torch.distributed.barrier(self.tp_cpu_group)
        if self.tp_rank == 0:
            from sglang.srt.utils import rpd_to_chrome_trace

            rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
        self.rpd_profiler = None
        # Reset the attribute that start_profile actually sets. The previous
        # code cleared a misspelled `rpd_profiler_path`, leaving a stale path.
        self.rpd_profile_path = None

    # Treat a missing activity list as empty for both checks below.
    activities = self.profiler_activities or []

    if "MEM" in activities:
        memory_profile_path = os.path.join(
            self.torch_profiler_output_dir,
            str(time.time())
            + f"-TP-{self.tp_rank}-memory"
            + stage_suffix
            + ".pickle",
        )
        torch.cuda.memory._dump_snapshot(memory_profile_path)
        torch.cuda.memory._record_memory_history(enabled=None)

    if "CUDA_PROFILER" in activities:
        torch.cuda.cudart().cudaProfilerStop()

    logger.info(
        "Profiling done. Traces are saved to: %s",
        self.torch_profiler_output_dir,
    )
    self.torch_profiler = None
    self.profile_in_progress = False
    self.profiler_start_forward_ct = None

    return ProfileReqOutput(success=True, message="Succeeded.")
|
||||||
|
|
||||||
|
def _profile_batch_predicate(self, batch):
|
||||||
|
if self.profile_by_stage:
|
||||||
|
if batch.forward_mode.is_prefill():
|
||||||
|
if self.profiler_prefill_ct == 0:
|
||||||
|
self.start_profile(batch.forward_mode)
|
||||||
|
self.profiler_prefill_ct += 1
|
||||||
|
if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
|
||||||
|
if self.profile_in_progress:
|
||||||
|
self.stop_profile(stage=ForwardMode.EXTEND)
|
||||||
|
elif batch.forward_mode.is_decode():
|
||||||
|
if self.profiler_decode_ct == 0:
|
||||||
|
if self.profile_in_progress:
|
||||||
|
# force trace flush
|
||||||
|
self.stop_profile(ForwardMode.EXTEND)
|
||||||
|
self.start_profile(batch.forward_mode)
|
||||||
|
self.profiler_decode_ct += 1
|
||||||
|
if self.profiler_decode_ct > self.profiler_target_decode_ct:
|
||||||
|
if self.profile_in_progress:
|
||||||
|
self.stop_profile(stage=ForwardMode.DECODE)
|
||||||
|
elif batch.forward_mode.is_idle():
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
|
||||||
|
else:
|
||||||
|
# Check profiler
|
||||||
|
if (
|
||||||
|
self.profiler_target_forward_ct
|
||||||
|
and self.profiler_target_forward_ct <= self.forward_ct
|
||||||
|
):
|
||||||
|
self.stop_profile()
|
||||||
|
if (
|
||||||
|
self.profiler_start_forward_ct
|
||||||
|
and self.profiler_start_forward_ct == self.forward_ct
|
||||||
|
):
|
||||||
|
self.start_profile()
|
||||||
|
|
||||||
|
def profile(self, recv_req: ProfileReq):
    """Handle a profile request by starting or stopping profiling.

    A START_PROFILE request always stores the requested configuration via
    ``init_profile``. If the request asks for per-stage profiling or a
    delayed start (``start_step``), only the configuration is stored and
    the actual start happens later; otherwise profiling starts right away.
    Any other request type stops profiling.

    Args:
        recv_req: The incoming request; its fields are forwarded to
            ``init_profile`` unchanged.

    Returns:
        The ProfileReqOutput produced by ``init_profile``,
        ``start_profile`` or ``stop_profile``.
    """
    if recv_req.type != ProfileReqType.START_PROFILE:
        return self.stop_profile()

    init_result = self.init_profile(
        recv_req.output_dir,
        recv_req.start_step,
        recv_req.num_steps,
        recv_req.activities,
        recv_req.with_stack,
        recv_req.record_shapes,
        recv_req.profile_by_stage,
        recv_req.profile_id,
    )

    if recv_req.profile_by_stage or recv_req.start_step:
        # Deferred start: the per-batch predicate triggers profiling later.
        return init_result

    if not init_result.success:
        # E.g. profiling is already running: report the failure instead of
        # starting a second profiler on top of the active one.
        return init_result

    # No stage applies to an immediate start, so do not pass one
    # (previously `True` was passed and ended up in the log message).
    return self.start_profile()
|
||||||
142
python/sglang/srt/managers/scheduler_update_weights_mixin.py
Normal file
142
python/sglang/srt/managers/scheduler_update_weights_mixin.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
|
||||||
|
from sglang.srt.managers.io_struct import (
|
||||||
|
GetWeightsByNameReqInput,
|
||||||
|
GetWeightsByNameReqOutput,
|
||||||
|
InitWeightsUpdateGroupReqInput,
|
||||||
|
InitWeightsUpdateGroupReqOutput,
|
||||||
|
ReleaseMemoryOccupationReqInput,
|
||||||
|
ReleaseMemoryOccupationReqOutput,
|
||||||
|
ResumeMemoryOccupationReqInput,
|
||||||
|
ResumeMemoryOccupationReqOutput,
|
||||||
|
UpdateWeightFromDiskReqInput,
|
||||||
|
UpdateWeightFromDiskReqOutput,
|
||||||
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
UpdateWeightsFromDistributedReqOutput,
|
||||||
|
UpdateWeightsFromTensorReqInput,
|
||||||
|
UpdateWeightsFromTensorReqOutput,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerUpdateWeightsMixin:
    """Scheduler methods for updating, inspecting, offloading and saving weights.

    The methods delegate the actual work to ``self.tp_worker`` and use
    ``self.flush_cache``, ``self.memory_saver_adapter`` and
    ``self.tp_cpu_group`` provided by the class this mixin is combined with.
    """

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        The cache is always flushed after a successful update; a failure is
        logged. Returns an UpdateWeightFromDiskReqOutput with the result.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightFromDiskReqOutput(success, message, 0)

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group."""
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(success, message)

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        The cache is flushed after a successful update only when
        ``recv_req.flush_cache`` is set; a failure is logged.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            if recv_req.flush_cache:
                flush_cache_success = self.flush_cache()
                assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        Behaves like ``update_weights_from_distributed`` and additionally
        synchronizes on ``self.tp_cpu_group`` before returning, whether or
        not the update succeeded.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if success:
            if recv_req.flush_cache:
                flush_cache_success = self.flush_cache()
                assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        torch.distributed.barrier(group=self.tp_cpu_group)
        return UpdateWeightsFromTensorReqOutput(success, message)

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a parameter through the TP worker and wrap it in the reply type."""
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return GetWeightsByNameReqOutput(parameter)

    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Pause the memory regions named by ``recv_req.tags``.

        With no tags (None or empty) both the weights and the KV cache are
        released. For the KV cache the region is paused and the cache is
        flushed. For the weights the model's buffers are first copied into
        ``self.stashed_model_static_state`` so that
        ``resume_memory_occupation`` can restore them.
        """
        tags = recv_req.tags

        if not tags:
            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
            self.flush_cache()

        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            # Snapshot the buffers before the weights region is paused.
            self.stashed_model_static_state = _export_static_state(
                self.tp_worker.worker.model_runner.model
            )
            torch.distributed.barrier(self.tp_cpu_group)
            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

        return ReleaseMemoryOccupationReqOutput()

    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Resume the memory regions named by ``recv_req.tags``.

        With no tags (None or empty) both the weights and the KV cache are
        resumed. Resuming the weights copies the buffers stashed by
        ``release_memory_occupation`` back into the model and then drops
        the stash.
        """
        tags = recv_req.tags

        if not tags:
            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
            torch.distributed.barrier(self.tp_cpu_group)
            _import_static_state(
                self.tp_worker.worker.model_runner.model,
                self.stashed_model_static_state,
            )
            del self.stashed_model_static_state

        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)

        return ResumeMemoryOccupationReqOutput()

    def save_remote_model(self, params):
        """Save the model to ``params["url"]`` through the model runner."""
        url = params["url"]

        worker = self.tp_worker.worker

        worker.model_runner.save_remote_model(url)

    def save_sharded_model(self, params):
        """Save a sharded model using the "path", "pattern" and "max_size" entries of ``params``."""
        worker = self.tp_worker.worker

        worker.model_runner.save_sharded_model(
            path=params["path"],
            pattern=params["pattern"],
            max_size=params["max_size"],
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _export_static_state(model):
|
||||||
|
return dict(
|
||||||
|
buffers=[
|
||||||
|
(name, buffer.detach().clone()) for name, buffer in model.named_buffers()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _import_static_state(model, static_params):
|
||||||
|
self_named_buffers = dict(model.named_buffers())
|
||||||
|
for name, tensor in static_params["buffers"]:
|
||||||
|
self_named_buffers[name][...] = tensor
|
||||||
@@ -170,16 +170,6 @@ class ReqState:
|
|||||||
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
|
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
|
|
||||||
is_cross_node = server_args.dist_init_addr
|
|
||||||
|
|
||||||
if is_cross_node:
|
|
||||||
# Fallback to default CPU transport for multi-node
|
|
||||||
return "default"
|
|
||||||
else:
|
|
||||||
return "cuda_ipc"
|
|
||||||
|
|
||||||
|
|
||||||
class TokenizerManager:
|
class TokenizerManager:
|
||||||
"""TokenizerManager is a process that tokenizes the text."""
|
"""TokenizerManager is a process that tokenizes the text."""
|
||||||
|
|
||||||
@@ -199,16 +189,6 @@ class TokenizerManager:
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.crash_dump_folder = server_args.crash_dump_folder
|
self.crash_dump_folder = server_args.crash_dump_folder
|
||||||
self.crash_dump_performed = False # Flag to ensure dump is only called once
|
|
||||||
|
|
||||||
# Init inter-process communication
|
|
||||||
context = zmq.asyncio.Context(2)
|
|
||||||
self.recv_from_detokenizer = get_zmq_socket(
|
|
||||||
context, zmq.PULL, port_args.tokenizer_ipc_name, True
|
|
||||||
)
|
|
||||||
self.send_to_scheduler = get_zmq_socket(
|
|
||||||
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Read model args
|
# Read model args
|
||||||
self.model_path = server_args.model_path
|
self.model_path = server_args.model_path
|
||||||
@@ -218,8 +198,7 @@ class TokenizerManager:
|
|||||||
self.is_image_gen = self.model_config.is_image_gen
|
self.is_image_gen = self.model_config.is_image_gen
|
||||||
self.context_len = self.model_config.context_len
|
self.context_len = self.model_config.context_len
|
||||||
self.image_token_id = self.model_config.image_token_id
|
self.image_token_id = self.model_config.image_token_id
|
||||||
self._updating = False
|
self.max_req_input_len = None # Will be set later in engine.py
|
||||||
self._cond = asyncio.Condition()
|
|
||||||
|
|
||||||
if self.model_config.is_multimodal:
|
if self.model_config.is_multimodal:
|
||||||
import_processors()
|
import_processors()
|
||||||
@@ -258,39 +237,57 @@ class TokenizerManager:
|
|||||||
revision=server_args.revision,
|
revision=server_args.revision,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
|
# Init inter-process communication
|
||||||
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
|
context = zmq.asyncio.Context(2)
|
||||||
# serves as the source of truth for available adapters and maps user-friendly LoRA names
|
self.recv_from_detokenizer = get_zmq_socket(
|
||||||
# to internally used unique LoRA IDs.
|
context, zmq.PULL, port_args.tokenizer_ipc_name, True
|
||||||
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
|
)
|
||||||
|
self.send_to_scheduler = get_zmq_socket(
|
||||||
|
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
|
||||||
|
)
|
||||||
|
|
||||||
# Store states
|
# Request states
|
||||||
self.no_create_loop = False
|
self.no_create_loop = False
|
||||||
self.rid_to_state: Dict[str, ReqState] = {}
|
self.rid_to_state: Dict[str, ReqState] = {}
|
||||||
|
self.asyncio_tasks = set()
|
||||||
|
|
||||||
|
# Health check
|
||||||
self.health_check_failed = False
|
self.health_check_failed = False
|
||||||
self.gracefully_exit = False
|
self.gracefully_exit = False
|
||||||
self.last_receive_tstamp = 0
|
self.last_receive_tstamp = 0
|
||||||
|
|
||||||
|
# Dumping
|
||||||
self.dump_requests_folder = "" # By default do not dump
|
self.dump_requests_folder = "" # By default do not dump
|
||||||
self.dump_requests_threshold = 1000
|
self.dump_requests_threshold = 1000
|
||||||
self.dump_request_list: List[Tuple] = []
|
self.dump_request_list: List[Tuple] = []
|
||||||
self.crash_dump_request_list: deque[Tuple] = deque()
|
|
||||||
self.log_request_metadata = self.get_log_request_metadata()
|
self.log_request_metadata = self.get_log_request_metadata()
|
||||||
self.session_futures = {} # session_id -> asyncio event
|
self.crash_dump_request_list: deque[Tuple] = deque()
|
||||||
self.max_req_input_len = None
|
self.crash_dump_performed = False # Flag to ensure dump is only called once
|
||||||
self.asyncio_tasks = set()
|
|
||||||
|
|
||||||
|
# Session
|
||||||
|
self.session_futures = {} # session_id -> asyncio event
|
||||||
|
|
||||||
|
# Weight updates
|
||||||
# The event to notify the weight sync is finished.
|
# The event to notify the weight sync is finished.
|
||||||
self.model_update_lock = RWLock()
|
self.model_update_lock = RWLock()
|
||||||
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
|
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
|
self._is_updating = False
|
||||||
|
self._is_updating_cond = asyncio.Condition()
|
||||||
|
|
||||||
|
# LoRA
|
||||||
|
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
|
||||||
|
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
|
||||||
|
# serves as the source of truth for available adapters and maps user-friendly LoRA names
|
||||||
|
# to internally used unique LoRA IDs.
|
||||||
|
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
|
||||||
# Lock to serialize LoRA update operations.
|
# Lock to serialize LoRA update operations.
|
||||||
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
|
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
|
||||||
# LoRA updates and inference to overlap.
|
# LoRA updates and inference to overlap.
|
||||||
self.lora_update_lock = asyncio.Lock()
|
self.lora_update_lock = asyncio.Lock()
|
||||||
|
|
||||||
# For pd disaggregtion
|
# For PD disaggregation
|
||||||
self.disaggregation_mode = DisaggregationMode(
|
self.disaggregation_mode = DisaggregationMode(
|
||||||
self.server_args.disaggregation_mode
|
self.server_args.disaggregation_mode
|
||||||
)
|
)
|
||||||
@@ -458,17 +455,11 @@ class TokenizerManager:
|
|||||||
request: Optional[fastapi.Request] = None,
|
request: Optional[fastapi.Request] = None,
|
||||||
):
|
):
|
||||||
created_time = time.time()
|
created_time = time.time()
|
||||||
async with self._cond:
|
|
||||||
await self._cond.wait_for(lambda: not self._updating)
|
|
||||||
|
|
||||||
self.auto_create_handle_loop()
|
self.auto_create_handle_loop()
|
||||||
obj.normalize_batch_and_arguments()
|
obj.normalize_batch_and_arguments()
|
||||||
|
|
||||||
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
async with self._is_updating_cond:
|
||||||
raise ValueError(
|
await self._is_updating_cond.wait_for(lambda: not self._is_updating)
|
||||||
"This model does not appear to be an embedding model by default. "
|
|
||||||
"Please add `--is-embedding` when launching the server or try another model."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
max_length, skip_names, _ = self.log_request_metadata
|
max_length, skip_names, _ = self.log_request_metadata
|
||||||
@@ -567,6 +558,12 @@ class TokenizerManager:
|
|||||||
f"model's context length ({self.context_len} tokens)."
|
f"model's context length ({self.context_len} tokens)."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
||||||
|
raise ValueError(
|
||||||
|
"This model does not appear to be an embedding model by default. "
|
||||||
|
"Please add `--is-embedding` when launching the server or try another model."
|
||||||
|
)
|
||||||
|
|
||||||
# Check total tokens (input + max_new_tokens)
|
# Check total tokens (input + max_new_tokens)
|
||||||
max_new_tokens = obj.sampling_params.get("max_new_tokens")
|
max_new_tokens = obj.sampling_params.get("max_new_tokens")
|
||||||
if (
|
if (
|
||||||
@@ -959,14 +956,14 @@ class TokenizerManager:
|
|||||||
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
||||||
|
|
||||||
async def pause_generation(self):
|
async def pause_generation(self):
|
||||||
async with self._cond:
|
async with self._is_updating_cond:
|
||||||
self._updating = True
|
self._is_updating = True
|
||||||
self.abort_request(abort_all=True)
|
self.abort_request(abort_all=True)
|
||||||
|
|
||||||
async def continue_generation(self):
|
async def continue_generation(self):
|
||||||
async with self._cond:
|
async with self._is_updating_cond:
|
||||||
self._updating = False
|
self._is_updating = False
|
||||||
self._cond.notify_all()
|
self._is_updating_cond.notify_all()
|
||||||
|
|
||||||
async def update_weights_from_disk(
|
async def update_weights_from_disk(
|
||||||
self,
|
self,
|
||||||
@@ -1208,14 +1205,6 @@ class TokenizerManager:
|
|||||||
# Many DP ranks
|
# Many DP ranks
|
||||||
return [res.internal_state for res in responses]
|
return [res.internal_state for res in responses]
|
||||||
|
|
||||||
async def get_load(self) -> dict:
|
|
||||||
# TODO(lsyin): fake load report server
|
|
||||||
if not self.current_load_lock.locked():
|
|
||||||
async with self.current_load_lock:
|
|
||||||
internal_state = await self.get_internal_state()
|
|
||||||
self.current_load = internal_state[0]["load"]
|
|
||||||
return {"load": self.current_load}
|
|
||||||
|
|
||||||
async def set_internal_state(
|
async def set_internal_state(
|
||||||
self, obj: SetInternalStateReq
|
self, obj: SetInternalStateReq
|
||||||
) -> SetInternalStateReqOutput:
|
) -> SetInternalStateReqOutput:
|
||||||
@@ -1224,6 +1213,14 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
return [res.internal_state for res in responses]
|
return [res.internal_state for res in responses]
|
||||||
|
|
||||||
|
async def get_load(self) -> dict:
    """Return the current load as ``{"load": value}``.

    The value is refreshed from ``get_internal_state()`` unless
    ``current_load_lock`` is already held, in which case the last stored
    ``current_load`` is returned without waiting.
    """
    # TODO(lsyin): fake load report server
    load_lock = self.current_load_lock
    if load_lock.locked():
        # The lock is taken, so skip the refresh and serve the stored value.
        return {"load": self.current_load}

    async with load_lock:
        states = await self.get_internal_state()
        self.current_load = states[0]["load"]
    return {"load": self.current_load}
|
||||||
|
|
||||||
def get_log_request_metadata(self):
|
def get_log_request_metadata(self):
|
||||||
max_length = None
|
max_length = None
|
||||||
skip_names = None
|
skip_names = None
|
||||||
@@ -1343,11 +1340,24 @@ class TokenizerManager:
|
|||||||
"SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
|
"SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
|
|
||||||
self.crash_dump_performed = True
|
|
||||||
if not self.crash_dump_folder:
|
if not self.crash_dump_folder:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
|
||||||
|
self.crash_dump_performed = True
|
||||||
|
|
||||||
|
# Check if NFS directory is available
|
||||||
|
# expected_nfs_dir = "/" + self.crash_dump_folder.lstrip("/").split("/")[0]
|
||||||
|
# use_nfs_dir = os.path.isdir(expected_nfs_dir) and os.access(
|
||||||
|
# expected_nfs_dir, os.W_OK
|
||||||
|
# )
|
||||||
|
use_nfs_dir = False
|
||||||
|
if not use_nfs_dir:
|
||||||
|
logger.error(
|
||||||
|
f"Expected NFS directory is not available or writable. Uploading to GCS."
|
||||||
|
)
|
||||||
|
|
||||||
data_to_dump = []
|
data_to_dump = []
|
||||||
if self.crash_dump_request_list:
|
if self.crash_dump_request_list:
|
||||||
data_to_dump.extend(self.crash_dump_request_list)
|
data_to_dump.extend(self.crash_dump_request_list)
|
||||||
@@ -1357,7 +1367,12 @@ class TokenizerManager:
|
|||||||
for rid, state in self.rid_to_state.items():
|
for rid, state in self.rid_to_state.items():
|
||||||
if not state.finished:
|
if not state.finished:
|
||||||
unfinished_requests.append(
|
unfinished_requests.append(
|
||||||
(state.obj, {}, state.created_time, time.time())
|
(
|
||||||
|
state.obj,
|
||||||
|
state.out_list[-1] if state.out_list else {},
|
||||||
|
state.created_time,
|
||||||
|
time.time(),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if unfinished_requests:
|
if unfinished_requests:
|
||||||
data_to_dump.extend(unfinished_requests)
|
data_to_dump.extend(unfinished_requests)
|
||||||
@@ -1365,10 +1380,11 @@ class TokenizerManager:
|
|||||||
if not data_to_dump:
|
if not data_to_dump:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
|
||||||
filename = os.path.join(
|
filename = os.path.join(
|
||||||
self.crash_dump_folder,
|
self.crash_dump_folder,
|
||||||
os.getenv("HOSTNAME", None),
|
os.getenv("HOSTNAME", None),
|
||||||
f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
|
object_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||||
@@ -1383,6 +1399,24 @@ class TokenizerManager:
|
|||||||
f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
|
f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
|
||||||
|
from google.cloud import storage
|
||||||
|
|
||||||
|
client = storage.Client()
|
||||||
|
bucket = client.bucket(bucket_name)
|
||||||
|
blob = bucket.blob(object_name)
|
||||||
|
blob.upload_from_filename(source_file_path, if_generation_match=0)
|
||||||
|
logger.error(
|
||||||
|
f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not use_nfs_dir:
|
||||||
|
_upload_file_to_gcs(
|
||||||
|
"sglang_crash_dump",
|
||||||
|
filename,
|
||||||
|
os.getenv("HOSTNAME", None) + "/" + object_name,
|
||||||
|
)
|
||||||
|
|
||||||
async def sigterm_watchdog(self):
|
async def sigterm_watchdog(self):
|
||||||
while not self.gracefully_exit:
|
while not self.gracefully_exit:
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
@@ -1426,7 +1460,7 @@ class TokenizerManager:
|
|||||||
while True:
|
while True:
|
||||||
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
||||||
self._result_dispatcher(recv_obj)
|
self._result_dispatcher(recv_obj)
|
||||||
self.last_receive_tstamp = time.perf_counter()
|
self.last_receive_tstamp = time.time()
|
||||||
|
|
||||||
def _handle_batch_output(
|
def _handle_batch_output(
|
||||||
self,
|
self,
|
||||||
@@ -1697,24 +1731,13 @@ class TokenizerManager:
|
|||||||
self.dump_requests_folder,
|
self.dump_requests_folder,
|
||||||
datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
|
datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
|
||||||
)
|
)
|
||||||
logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
|
self._dump_data_to_file(
|
||||||
|
data_list=self.dump_request_list,
|
||||||
to_dump = self.dump_request_list
|
filename=filename,
|
||||||
|
log_message=f"Dump {len(self.dump_request_list)} requests to {filename}",
|
||||||
|
)
|
||||||
self.dump_request_list = []
|
self.dump_request_list = []
|
||||||
|
|
||||||
to_dump_with_server_args = {
|
|
||||||
"server_args": self.server_args,
|
|
||||||
"requests": to_dump,
|
|
||||||
}
|
|
||||||
|
|
||||||
def background_task():
|
|
||||||
os.makedirs(self.dump_requests_folder, exist_ok=True)
|
|
||||||
with open(filename, "wb") as f:
|
|
||||||
pickle.dump(to_dump_with_server_args, f)
|
|
||||||
|
|
||||||
# Schedule the task to run in the background without awaiting it
|
|
||||||
asyncio.create_task(asyncio.to_thread(background_task))
|
|
||||||
|
|
||||||
def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
|
def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
self.crash_dump_request_list.append(
|
self.crash_dump_request_list.append(
|
||||||
@@ -1727,6 +1750,22 @@ class TokenizerManager:
|
|||||||
):
|
):
|
||||||
self.crash_dump_request_list.popleft()
|
self.crash_dump_request_list.popleft()
|
||||||
|
|
||||||
|
def _dump_data_to_file(
    self, data_list: List[Tuple], filename: str, log_message: str
):
    """Pickle ``data_list`` together with the server args in the background.

    The list is copied before scheduling, so the caller may clear or reuse
    it immediately. The file is written from a worker thread via
    ``asyncio.to_thread``; this method must therefore be called while an
    event loop is running.

    Args:
        data_list: Records to dump; stored under the "requests" key.
        filename: Destination path. Its parent directory is created if
            needed.
        log_message: Message logged at INFO level before scheduling.
    """
    logger.info(log_message)
    to_dump_with_server_args = {
        "server_args": self.server_args,
        "requests": data_list.copy(),
    }

    def background_task():
        # dirname is "" for a bare file name, and os.makedirs("") raises.
        parent_dir = os.path.dirname(filename)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(filename, "wb") as f:
            pickle.dump(to_dump_with_server_args, f)

    # The event loop keeps only weak references to tasks, so hold a strong
    # one until the dump finishes; otherwise the task may be collected
    # mid-execution.
    # NOTE(review): reuses self.asyncio_tasks (a set created in __init__) -
    # confirm nothing iterates it across awaits.
    dump_task = asyncio.create_task(asyncio.to_thread(background_task))
    self.asyncio_tasks.add(dump_task)
    dump_task.add_done_callback(self.asyncio_tasks.discard)
|
||||||
|
|
||||||
def _handle_abort_req(self, recv_obj):
|
def _handle_abort_req(self, recv_obj):
|
||||||
state = self.rid_to_state[recv_obj.rid]
|
state = self.rid_to_state[recv_obj.rid]
|
||||||
state.finished = True
|
state.finished = True
|
||||||
@@ -1862,6 +1901,16 @@ class TokenizerManager:
|
|||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
    """Pick the tensor transport mode from the server arguments.

    Returns "default" when ``server_args.dist_init_addr`` is set (truthy),
    and "cuda_ipc" otherwise.
    """
    # A configured dist_init_addr is taken as the sign of a cross-node setup.
    if server_args.dist_init_addr:
        # Fallback to default CPU transport for multi-node
        return "default"
    return "cuda_ipc"
|
||||||
|
|
||||||
|
|
||||||
async def print_exception_wrapper(func):
|
async def print_exception_wrapper(func):
|
||||||
"""
|
"""
|
||||||
Sometimes an asyncio function does not print exception.
|
Sometimes an asyncio function does not print exception.
|
||||||
|
|||||||
@@ -2071,6 +2071,9 @@ class PortArgs:
|
|||||||
|
|
||||||
dist_init_host, dist_init_port = dist_init_addr
|
dist_init_host, dist_init_port = dist_init_addr
|
||||||
port_base = int(dist_init_port) + 1
|
port_base = int(dist_init_port) + 1
|
||||||
|
detokenizer_port = port_base + 1
|
||||||
|
rpc_port = port_base + 2
|
||||||
|
metrics_ipc_name = port_base + 3
|
||||||
if dp_rank is None:
|
if dp_rank is None:
|
||||||
# TokenizerManager to DataParallelController
|
# TokenizerManager to DataParallelController
|
||||||
scheduler_input_port = port_base + 4
|
scheduler_input_port = port_base + 4
|
||||||
@@ -2080,10 +2083,10 @@ class PortArgs:
|
|||||||
return PortArgs(
|
return PortArgs(
|
||||||
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
|
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
|
||||||
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
|
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
|
||||||
detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
|
detokenizer_ipc_name=f"tcp://{dist_init_host}:{detokenizer_port}",
|
||||||
nccl_port=nccl_port,
|
nccl_port=nccl_port,
|
||||||
rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
|
rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
|
||||||
metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
|
metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_ipc_name}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -291,17 +291,6 @@ def find_printable_text(text: str):
|
|||||||
return text[: text.rfind(" ") + 1]
|
return text[: text.rfind(" ") + 1]
|
||||||
|
|
||||||
|
|
||||||
def graceful_registry(sub_module_name: str):
    """Register a SIGTERM handler that logs the shutdown request.

    The handler only writes log lines prefixed with ``sub_module_name``.
    """

    def graceful_shutdown(signum, frame):
        shutdown_notice = f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
        logger.info(shutdown_notice)
        if signum != signal.SIGTERM:
            return
        logger.info(f"{sub_module_name} receive sigterm")

    signal.signal(signal.SIGTERM, graceful_shutdown)
|
|
||||||
|
|
||||||
|
|
||||||
class LazyImport:
|
class LazyImport:
|
||||||
"""Lazy import to make `import sglang` run faster."""
|
"""Lazy import to make `import sglang` run faster."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user