diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 310e92c23..098a3d1e3 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -57,6 +57,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     kill_process_tree,
+    launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
@@ -400,14 +401,16 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
 
         if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
             # When using `Engine` as a Python API, we don't want to block here.
-            return
+            return None, None
+
+        launch_dummy_health_check_server(server_args.host, server_args.port)
 
         for proc in scheduler_procs:
             proc.join()
             logger.error(
                 f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
             )
-        return
+        return None, None
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py
index 65efa0feb..36b87ca0b 100644
--- a/python/sglang/srt/layers/dp_attention.py
+++ b/python/sglang/srt/layers/dp_attention.py
@@ -22,6 +22,8 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
 def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
 
+    from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
+
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
@@ -35,7 +37,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
         ],
         tp_rank,
         torch.distributed.get_backend(tp_group.device_group),
-        False,
+        SYNC_TOKEN_IDS_ACROSS_TP,
         False,
         False,
         False,
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 972f9595b..a8ded73bc 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -201,6 +201,7 @@ class DetokenizerManager:
                     prompt_tokens=recv_obj.prompt_tokens,
                     completion_tokens=recv_obj.completion_tokens,
                     cached_tokens=recv_obj.cached_tokens,
+                    spec_verify_ct=recv_obj.spec_verify_ct,
                     input_token_logprobs_val=recv_obj.input_token_logprobs_val,
                     input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
                     output_token_logprobs_val=recv_obj.output_token_logprobs_val,
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index eee9b6722..a2f25abc2 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -354,10 +354,13 @@ class BatchTokenIDOut:
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     no_stop_trim: List[bool]
 
+    # Token counts
     prompt_tokens: List[int]
    completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]
 
+    # Logprobs
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]
@@ -382,6 +385,7 @@ class BatchStrOut:
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]
 
     # Logprobs
     input_token_logprobs_val: List[float]
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 2a342c5df..bdf780e4f 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -252,7 +252,6 @@ class Req:
 
         # Sampling info
         self.sampling_params = sampling_params
-        self.lora_path = lora_path
         self.custom_logit_processor = custom_logit_processor
 
         # Memory pool info
@@ -300,7 +299,7 @@ class Req:
         self.logprob_start_len = 0
         self.top_logprobs_num = top_logprobs_num
 
-        # Logprobs (return value)
+        # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
@@ -329,10 +328,15 @@ class Req:
         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
 
-        # The number of cached tokens, that were already cached in the KV cache
+        # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
         self.already_computed = 0
 
+        # The number of verification forward passes in the speculative decoding.
+        # This is used to compute the average acceptance length per request.
+        self.spec_verify_ct = 0
+        self.lora_path = lora_path
+
     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
             self.image_inputs = image_inputs
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 9cfa14c30..3e354a971 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -281,6 +281,7 @@ class Scheduler:
         # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}"
@@ -408,6 +409,11 @@ class Scheduler:
                 },
             )
 
+        # The largest prefill length of a single request
+        self._largest_prefill_len: int = 0
+        # The largest context length (prefill + generation) of a single request
+        self._largest_prefill_decode_len: int = 0
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -1371,6 +1377,7 @@ class Scheduler:
             prompt_tokens = []
             completion_tokens = []
             cached_tokens = []
+            spec_verify_ct = []
 
             if return_logprob:
                 input_token_logprobs_val = []
@@ -1424,6 +1431,9 @@ class Scheduler:
                 completion_tokens.append(len(req.output_ids))
                 cached_tokens.append(req.cached_tokens)
 
+                if not self.spec_algorithm.is_none():
+                    spec_verify_ct.append(req.spec_verify_ct)
+
                 if return_logprob:
                     input_token_logprobs_val.append(req.input_token_logprobs_val)
                     input_token_logprobs_idx.append(req.input_token_logprobs_idx)
@@ -1451,6 +1461,7 @@ class Scheduler:
                     prompt_tokens,
                     completion_tokens,
                     cached_tokens,
+                    spec_verify_ct,
                     input_token_logprobs_val,
                     input_token_logprobs_idx,
                     output_token_logprobs_val,
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 2be2e532d..53e1f4eda 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -785,6 +785,9 @@ class TokenizerManager:
                     i,
                 )
 
+            if self.server_args.speculative_algorithm:
+                meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+
             if not isinstance(recv_obj, BatchEmbeddingOut):
                 meta_info.update(
                     {
@@ -809,6 +812,7 @@ class TokenizerManager:
                     "embedding": recv_obj.embeddings[i],
                     "meta_info": meta_info,
                 }
+
             state.out_list.append(out_dict)
             state.finished = recv_obj.finished_reasons[i] is not None
             state.event.set()
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 169b64343..93b4d0ea5 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 
-def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
+def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    if batch_size == 1:
+                    if num_tokens == 1:
                         # The performance of torch.compile on this layer is not always good when bs > 1,
                         # so we decide to only use torch.compile when bs =1
                         sub._forward_method = fused_moe_forward_native
@@ -55,14 +55,14 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse, batch_size)
+            _to_torch(sub, reverse, num_tokens)
 
 
 @contextmanager
 def patch_model(
     model: torch.nn.Module,
     enable_compile: bool,
-    batch_size: int,
+    num_tokens: int,
     tp_group: GroupCoordinator,
 ):
     """Patch the model to make it compatible with with torch.compile"""
@@ -70,7 +70,7 @@ def patch_model(
 
     try:
         if enable_compile:
-            _to_torch(model, reverse=False, batch_size=batch_size)
+            _to_torch(model, reverse=False, num_tokens=num_tokens)
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -85,7 +85,7 @@ def patch_model(
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True, batch_size=batch_size)
+            _to_torch(model, reverse=True, num_tokens=num_tokens)
             tp_group.ca_comm = backup_ca_comm
 
 
@@ -283,8 +283,8 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
-                bs,
-                self.model_runner.tp_group,
+                num_tokens=bs * self.num_tokens_per_bs,
+                tp_group=self.model_runner.tp_group,
             ) as forward:
                 (
                     graph,
diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py
index 049ba2275..97cdb2640 100644
--- a/python/sglang/srt/speculative/eagle_utils.py
+++ b/python/sglang/srt/speculative/eagle_utils.py
@@ -603,6 +603,7 @@ class EagleVerifyInput(SpecInfo):
                 if not req.finished():
                     new_accept_index.extend(new_accept_index_)
                     unfinished_index.append(i)
+                    req.spec_verify_ct += 1
             accept_length = (accept_index != -1).sum(dim=1) - 1
 
             accept_index = accept_index[accept_index != -1]
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index f1d57e906..0568f0fd4 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -14,6 +14,7 @@
 """Common utilities."""
 
 import base64
+import ctypes
 import dataclasses
 import io
 import ipaddress
@@ -29,6 +30,7 @@ import shutil
 import signal
 import socket
 import subprocess
+import sys
 import tempfile
 import time
 import warnings
@@ -59,7 +61,6 @@ from triton.runtime.cache import (
     default_dump_dir,
     default_override_dir,
 )
-from uvicorn.config import LOGGING_CONFIG
 
 logger = logging.getLogger(__name__)
 
@@ -1366,7 +1367,33 @@ def nullable_str(val: str):
     return val
 
 
+def pyspy_dump_schedulers():
+    """py-spy dump on all schedulers in a local node."""
+    try:
+        pid = psutil.Process().pid
+        # Command to run py-spy with the PID
+        cmd = f"py-spy dump --pid {pid}"
+        result = subprocess.run(
+            cmd, shell=True, capture_output=True, text=True, check=True
+        )
+        logger.info(f"Profile for PID {pid}:\n{result.stdout}")
+    except subprocess.CalledProcessError as e:
+        logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}")
+
+
+def kill_itself_when_parent_died():
+    if sys.platform == "linux":
+        # sigkill this process when parent worker manager dies
+        PR_SET_PDEATHSIG = 1
+        libc = ctypes.CDLL("libc.so.6")
+        libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
+    else:
+        logger.warning("kill_itself_when_parent_died is only supported in linux.")
+
+
 def set_uvicorn_logging_configs():
+    from uvicorn.config import LOGGING_CONFIG
+
     LOGGING_CONFIG["formatters"]["default"][
         "fmt"
     ] = "[%(asctime)s] %(levelprefix)s %(message)s"
@@ -1449,3 +1476,28 @@
 def rank0_print(msg: str):
     if get_tensor_model_parallel_rank() == 0:
         print(msg, flush=True)
+
+
+def launch_dummy_health_check_server(host, port):
+    import uvicorn
+    from fastapi import FastAPI, Response
+
+    app = FastAPI()
+
+    @app.get("/health")
+    async def health():
+        """Check the health of the http server."""
+        return Response(status_code=200)
+
+    @app.get("/health_generate")
+    async def health_generate():
+        """Check the health of the http server."""
+        return Response(status_code=200)
+
+    uvicorn.run(
+        app,
+        host=host,
+        port=port,
+        timeout_keep_alive=5,
+        loop="uvloop",
+    )
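Usage note: a minimal client-side sketch of how the new `spec_verify_ct` field could be consumed. It assumes a server launched with a speculative algorithm enabled (e.g. `--speculative-algorithm EAGLE`); the endpoint, port, and request payload below are illustrative defaults and are not part of this patch. Per the comment added to `Req.spec_verify_ct`, dividing `completion_tokens` by the number of verification forward passes approximates the average acceptance length for one request.

import requests

# Send one generation request to a locally running sglang server (default port assumed).
resp = requests.post(
    "http://localhost:30000/generate",
    json={"text": "The capital of France is", "sampling_params": {"max_new_tokens": 64}},
).json()

meta = resp["meta_info"]
if meta.get("spec_verify_ct", 0) > 0:
    # Average number of tokens produced per draft-verify step for this request.
    print("avg acceptance length:", meta["completion_tokens"] / meta["spec_verify_ct"])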