2024-11-22 22:16:53 +08:00
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
2024-09-29 02:36:12 -07:00
""" A scheduler that manages a tensor parallel GPU worker. """
2025-01-13 01:39:14 -08:00
import faulthandler
2024-09-29 02:36:12 -07:00
import logging
2024-09-29 17:42:45 -07:00
import os
2024-11-28 00:22:39 -08:00
import signal
2025-03-03 00:12:04 -08:00
import sys
2024-10-27 02:00:50 -07:00
import threading
2024-09-29 17:42:45 -07:00
import time
import warnings
2025-03-03 00:12:04 -08:00
from collections import defaultdict , deque
2024-11-12 21:17:38 -08:00
from concurrent import futures
2025-01-16 12:51:11 -08:00
from dataclasses import dataclass
2025-01-18 19:37:30 -08:00
from http import HTTPStatus
2024-10-12 21:35:30 -07:00
from types import SimpleNamespace
2025-01-16 12:51:11 -08:00
from typing import Dict , List , Optional , Tuple , Union
2024-09-29 02:36:12 -07:00
2024-11-28 00:22:39 -08:00
import psutil
2024-12-08 01:06:15 -08:00
import setproctitle
2024-09-29 17:42:45 -07:00
import torch
2024-09-29 02:36:12 -07:00
import zmq
2025-03-14 15:40:44 +08:00
from torch . distributed import barrier
2024-09-29 02:36:12 -07:00
2024-09-29 17:42:45 -07:00
from sglang . global_config import global_config
2024-11-24 04:47:10 -08:00
from sglang . srt . configs . model_config import ModelConfig
2025-01-19 17:10:29 -08:00
from sglang . srt . constrained . base_grammar_backend import create_grammar_backend
2025-03-21 14:47:47 -07:00
from sglang . srt . disaggregation . decode import (
DecodePreallocQueue ,
DecodeTransferQueue ,
SchedulerDisaggregationDecodeMixin ,
)
from sglang . srt . disaggregation . prefill import (
PrefillBootstrapQueue ,
SchedulerDisaggregationPrefillMixin ,
)
from sglang . srt . disaggregation . utils import (
DisaggregationMode ,
ReqToMetadataIdxAllocator ,
)
2024-09-29 17:42:45 -07:00
from sglang . srt . hf_transformers_utils import get_processor , get_tokenizer
2025-01-16 11:15:00 -08:00
from sglang . srt . layers . dp_attention import compute_dp_attention_world_info
2024-09-29 17:42:45 -07:00
from sglang . srt . layers . logits_processor import LogitsProcessorOutput
2025-03-27 04:21:25 +08:00
from sglang . srt . managers . expert_distribution import ExpertDistributionRecorder
2024-09-29 17:42:45 -07:00
from sglang . srt . managers . io_struct import (
AbortReq ,
2024-11-20 00:36:53 -08:00
CloseSessionReqInput ,
2025-03-24 21:34:19 -07:00
ExpertDistributionReq ,
2025-03-25 16:17:03 +08:00
ExpertDistributionReqOutput ,
2024-09-29 17:42:45 -07:00
FlushCacheReq ,
2025-03-03 00:12:04 -08:00
GetInternalStateReq ,
GetInternalStateReqOutput ,
2024-11-29 23:36:38 -08:00
GetWeightsByNameReqInput ,
GetWeightsByNameReqOutput ,
2025-03-03 00:12:04 -08:00
HealthCheckOutput ,
2024-12-01 23:23:18 -08:00
InitWeightsUpdateGroupReqInput ,
InitWeightsUpdateGroupReqOutput ,
2024-11-20 00:36:53 -08:00
OpenSessionReqInput ,
OpenSessionReqOutput ,
2024-10-11 17:34:25 +08:00
ProfileReq ,
2025-03-03 00:12:04 -08:00
ProfileReqOutput ,
ProfileReqType ,
2025-01-14 03:38:51 +08:00
ReleaseMemoryOccupationReqInput ,
ReleaseMemoryOccupationReqOutput ,
ResumeMemoryOccupationReqInput ,
ResumeMemoryOccupationReqOutput ,
2025-03-14 15:40:44 +08:00
RpcReqInput ,
RpcReqOutput ,
2025-03-03 00:12:04 -08:00
SetInternalStateReq ,
SetInternalStateReqOutput ,
2024-09-29 17:42:45 -07:00
TokenizedEmbeddingReqInput ,
TokenizedGenerateReqInput ,
2024-11-29 17:17:00 -08:00
UpdateWeightFromDiskReqInput ,
UpdateWeightFromDiskReqOutput ,
2024-12-01 23:23:18 -08:00
UpdateWeightsFromDistributedReqInput ,
UpdateWeightsFromDistributedReqOutput ,
2024-12-29 05:30:27 +08:00
UpdateWeightsFromTensorReqInput ,
UpdateWeightsFromTensorReqOutput ,
2024-09-29 17:42:45 -07:00
)
from sglang . srt . managers . schedule_batch import (
FINISH_ABORT ,
2025-03-25 11:08:40 +08:00
MultimodalInputs ,
2024-09-29 17:42:45 -07:00
Req ,
ScheduleBatch ,
2024-10-19 23:19:26 -07:00
global_server_args_dict ,
2024-09-29 17:42:45 -07:00
)
2024-10-04 18:00:18 -07:00
from sglang . srt . managers . schedule_policy import (
AddReqResult ,
PrefillAdder ,
SchedulePolicy ,
)
2025-03-12 16:21:49 -07:00
from sglang . srt . managers . scheduler_output_processor_mixin import (
SchedulerOutputProcessorMixin ,
)
2024-11-20 00:36:53 -08:00
from sglang . srt . managers . session_controller import Session
2024-09-30 02:41:11 -07:00
from sglang . srt . managers . tp_worker import TpModelWorker
2024-10-20 00:29:29 -07:00
from sglang . srt . managers . tp_worker_overlap_thread import TpModelWorkerClient
2025-03-27 04:21:25 +08:00
from sglang . srt . managers . utils import validate_input_length
2024-09-29 17:42:45 -07:00
from sglang . srt . mem_cache . chunk_cache import ChunkCache
2025-02-23 21:56:30 -08:00
from sglang . srt . mem_cache . hiradix_cache import HiRadixCache
2024-09-29 17:42:45 -07:00
from sglang . srt . mem_cache . radix_cache import RadixCache
2024-11-10 04:39:32 -08:00
from sglang . srt . metrics . collector import SchedulerMetricsCollector , SchedulerStats
2025-03-12 22:22:39 -07:00
from sglang . srt . model_executor . forward_batch_info import ForwardBatch , ForwardMode
2024-09-29 02:36:12 -07:00
from sglang . srt . server_args import PortArgs , ServerArgs
2025-01-02 02:09:08 -08:00
from sglang . srt . speculative . spec_info import SpeculativeAlgorithm
2025-01-13 14:24:00 -08:00
from sglang . srt . torch_memory_saver_adapter import TorchMemorySaverAdapter
2024-09-29 17:42:45 -07:00
from sglang . srt . utils import (
2025-03-17 13:54:16 +08:00
DynamicGradMode ,
2024-09-29 17:42:45 -07:00
broadcast_pyobj ,
configure_logger ,
2024-11-17 22:18:11 -08:00
crash_on_warnings ,
2024-11-27 02:52:46 -08:00
get_bool_env_var ,
2024-10-25 23:07:07 -07:00
get_zmq_socket ,
2025-03-12 22:22:39 -07:00
kill_itself_when_parent_died ,
2025-03-03 00:12:04 -08:00
pyspy_dump_schedulers ,
2024-11-27 02:52:46 -08:00
set_gpu_proc_affinity ,
2024-09-29 17:42:45 -07:00
set_random_seed ,
suppress_other_loggers ,
)
2025-01-19 12:13:27 +08:00
from sglang . utils import TypeBasedDispatcher , get_exception_traceback
2024-09-29 02:36:12 -07:00
2025-03-24 21:34:19 -07:00
# Module-level singletons and debug flags.
expert_distribution_recorder = ExpertDistributionRecorder()

logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
2024-10-13 20:32:37 -07:00
2024-09-29 02:36:12 -07:00
2025-01-16 12:51:11 -08:00
@dataclass
class GenerationBatchResult:
    """Outputs produced by running one generation batch on the TP worker."""

    logits_output: LogitsProcessorOutput
    next_token_ids: List[int]
    extend_input_len_per_req: List[int]
    extend_logprob_start_len_per_req: List[int]
    # Batch id used to match this result back to its ScheduleBatch.
    bid: int
@dataclass
class EmbeddingBatchResult:
    """Outputs produced by running one embedding batch on the TP worker."""

    embeddings: torch.Tensor
    # Batch id used to match this result back to its ScheduleBatch.
    bid: int
2025-03-21 14:47:47 -07:00
class Scheduler (
SchedulerOutputProcessorMixin ,
SchedulerDisaggregationDecodeMixin ,
SchedulerDisaggregationPrefillMixin ,
) :
2024-09-29 02:36:12 -07:00
""" A scheduler that manages a tensor parallel GPU worker. """
def __init__ (
self ,
server_args : ServerArgs ,
port_args : PortArgs ,
gpu_id : int ,
tp_rank : int ,
2024-10-19 17:39:38 -07:00
dp_rank : Optional [ int ] ,
2024-09-29 02:36:12 -07:00
) :
# Parse args
2024-09-29 17:42:45 -07:00
self . server_args = server_args
2024-09-29 02:36:12 -07:00
self . tp_rank = tp_rank
self . tp_size = server_args . tp_size
2024-09-29 17:42:45 -07:00
self . schedule_policy = server_args . schedule_policy
self . lora_paths = server_args . lora_paths
self . max_loras_per_batch = server_args . max_loras_per_batch
2024-11-19 22:07:58 -08:00
self . enable_overlap = not server_args . disable_overlap_schedule
2024-10-25 18:51:59 -07:00
self . skip_tokenizer_init = server_args . skip_tokenizer_init
2024-11-10 04:39:32 -08:00
self . enable_metrics = server_args . enable_metrics
2025-03-03 00:12:04 -08:00
self . stream_interval = server_args . stream_interval
2025-01-02 02:09:08 -08:00
self . spec_algorithm = SpeculativeAlgorithm . from_string (
server_args . speculative_algorithm
)
2025-03-03 00:12:04 -08:00
self . gpu_id = gpu_id
self . enable_hierarchical_cache = server_args . enable_hierarchical_cache
2025-03-12 22:22:39 -07:00
self . page_size = server_args . page_size
2024-09-29 02:36:12 -07:00
2025-01-19 17:10:29 -08:00
# Distributed rank info
2025-01-16 11:15:00 -08:00
self . dp_size = server_args . dp_size
self . attn_tp_rank , self . attn_tp_size , self . dp_rank = (
compute_dp_attention_world_info (
server_args . enable_dp_attention ,
self . tp_rank ,
self . tp_size ,
self . dp_size ,
)
)
2025-01-19 17:10:29 -08:00
# Init inter-process communication
context = zmq . Context ( 2 )
2025-01-16 11:15:00 -08:00
if self . attn_tp_rank == 0 :
2024-10-25 23:07:07 -07:00
self . recv_from_tokenizer = get_zmq_socket (
2025-01-16 14:36:07 -08:00
context , zmq . PULL , port_args . scheduler_input_ipc_name , False
2024-10-25 23:07:07 -07:00
)
2024-11-15 05:02:44 -08:00
self . send_to_tokenizer = get_zmq_socket (
2025-01-16 14:36:07 -08:00
context , zmq . PUSH , port_args . tokenizer_ipc_name , False
2024-11-15 05:02:44 -08:00
)
2024-09-29 02:36:12 -07:00
2024-10-25 18:51:59 -07:00
if server_args . skip_tokenizer_init :
2025-02-24 14:47:59 -08:00
# Directly send to the TokenizerManager
2024-10-25 23:07:07 -07:00
self . send_to_detokenizer = get_zmq_socket (
2025-01-16 14:36:07 -08:00
context , zmq . PUSH , port_args . tokenizer_ipc_name , False
2024-10-25 18:51:59 -07:00
)
else :
2024-12-29 00:45:57 -08:00
# Send to the DetokenizerManager
2024-10-25 23:07:07 -07:00
self . send_to_detokenizer = get_zmq_socket (
2025-01-16 14:36:07 -08:00
context , zmq . PUSH , port_args . detokenizer_ipc_name , False
2024-10-25 18:51:59 -07:00
)
2025-03-14 15:40:44 +08:00
self . recv_from_rpc = get_zmq_socket (
context , zmq . DEALER , port_args . rpc_ipc_name , False
)
2024-09-29 02:36:12 -07:00
else :
2024-10-12 21:35:30 -07:00
self . recv_from_tokenizer = None
2025-03-14 15:40:44 +08:00
self . recv_from_rpc = None
2024-11-22 15:46:16 -08:00
self . send_to_tokenizer = SimpleNamespace ( send_pyobj = lambda x : None )
self . send_to_detokenizer = SimpleNamespace ( send_pyobj = lambda x : None )
2024-09-29 17:42:45 -07:00
# Init tokenizer
2025-03-06 00:13:20 -08:00
self . init_tokenizer ( )
2024-09-29 02:36:12 -07:00
2024-11-19 22:07:58 -08:00
# Check whether overlap can be enabled
if not self . is_generation :
self . enable_overlap = False
logger . info ( " Overlap scheduler is disabled for embedding models. " )
2024-11-27 23:44:33 -08:00
if self . model_config . is_multimodal :
self . enable_overlap = False
logger . info ( " Overlap scheduler is disabled for multimodal models. " )
2024-09-29 17:42:45 -07:00
# Launch a tensor parallel worker
2024-10-20 18:17:41 -07:00
if self . enable_overlap :
2024-10-20 00:29:29 -07:00
TpWorkerClass = TpModelWorkerClient
2024-10-19 23:19:26 -07:00
else :
TpWorkerClass = TpModelWorker
2024-10-20 00:29:29 -07:00
2024-10-19 23:19:26 -07:00
self . tp_worker = TpWorkerClass (
2024-10-19 17:39:38 -07:00
server_args = server_args ,
2024-09-29 02:36:12 -07:00
gpu_id = gpu_id ,
tp_rank = tp_rank ,
2024-10-19 17:39:38 -07:00
dp_rank = dp_rank ,
2024-10-11 07:22:48 -07:00
nccl_port = port_args . nccl_port ,
2024-09-29 02:36:12 -07:00
)
2024-09-29 17:42:45 -07:00
2025-03-03 00:12:04 -08:00
# Launch a draft worker for speculative decoding
2025-01-02 02:09:08 -08:00
if self . spec_algorithm . is_eagle ( ) :
from sglang . srt . speculative . eagle_worker import EAGLEWorker
self . draft_worker = EAGLEWorker (
gpu_id = gpu_id ,
tp_rank = tp_rank ,
server_args = server_args ,
nccl_port = port_args . nccl_port ,
target_worker = self . tp_worker ,
dp_rank = dp_rank ,
)
else :
self . draft_worker = None
2024-10-06 00:10:48 -07:00
# Get token and memory info from the model worker
2024-09-29 17:42:45 -07:00
(
self . max_total_num_tokens ,
self . max_prefill_tokens ,
self . max_running_requests ,
2024-10-21 16:12:04 -07:00
self . max_req_len ,
2024-09-29 17:42:45 -07:00
self . max_req_input_len ,
self . random_seed ,
2024-10-19 17:39:38 -07:00
self . device ,
2024-10-19 23:19:26 -07:00
worker_global_server_args_dict ,
_ ,
_ ,
_ ,
) = self . tp_worker . get_worker_info ( )
2024-10-19 17:39:38 -07:00
self . tp_cpu_group = self . tp_worker . get_tp_cpu_group ( )
2025-01-16 11:15:00 -08:00
self . attn_tp_cpu_group = self . tp_worker . get_attention_tp_cpu_group ( )
2024-10-19 17:39:38 -07:00
self . pad_input_ids_func = self . tp_worker . get_pad_input_ids_func ( )
2024-10-19 23:19:26 -07:00
global_server_args_dict . update ( worker_global_server_args_dict )
2024-09-29 17:42:45 -07:00
set_random_seed ( self . random_seed )
2025-03-03 00:12:04 -08:00
2024-09-29 17:42:45 -07:00
# Print debug info
logger . info (
f " max_total_num_tokens= { self . max_total_num_tokens } , "
2025-01-26 04:51:54 -08:00
f " chunked_prefill_size= { server_args . chunked_prefill_size } , "
2024-09-29 17:42:45 -07:00
f " max_prefill_tokens= { self . max_prefill_tokens } , "
f " max_running_requests= { self . max_running_requests } , "
f " context_len= { self . model_config . context_len } "
)
2025-03-12 22:22:39 -07:00
# Init memory pool and cache
2025-03-06 00:13:20 -08:00
self . init_memory_pool_and_cache ( )
2024-09-29 17:42:45 -07:00
# Init running status
self . waiting_queue : List [ Req ] = [ ]
2024-11-19 15:04:43 -08:00
# The running decoding batch for continuous batching
2025-03-12 22:22:39 -07:00
self . running_batch : ScheduleBatch = ScheduleBatch ( reqs = [ ] , batch_is_full = False )
2024-11-19 15:04:43 -08:00
# The current forward batch
2024-10-16 01:33:20 -07:00
self . cur_batch : Optional [ ScheduleBatch ] = None
2025-03-12 22:22:39 -07:00
# The last forward batch
2024-11-19 15:04:43 -08:00
self . last_batch : Optional [ ScheduleBatch ] = None
2024-10-27 02:00:50 -07:00
self . forward_ct = 0
self . forward_ct_decode = 0
2024-09-29 17:42:45 -07:00
self . num_generated_tokens = 0
2025-03-12 22:22:39 -07:00
self . num_prefill_tokens = 0
2024-11-10 04:39:32 -08:00
self . last_decode_stats_tic = time . time ( )
2025-03-12 22:22:39 -07:00
self . last_prefill_stats_tic = time . time ( )
2025-03-03 00:12:04 -08:00
self . return_health_check_ct = 0
2024-12-06 05:49:29 -08:00
self . current_stream = torch . get_device_module ( self . device ) . current_stream ( )
2025-01-17 13:22:53 +08:00
if self . device == " cpu " :
self . current_stream . synchronize = lambda : None # No-op for CPU
2024-12-06 05:49:29 -08:00
2025-03-06 00:13:20 -08:00
# Init session info
2024-12-22 06:25:57 -08:00
self . sessions : Dict [ str , Session ] = { }
2024-09-29 17:42:45 -07:00
# Init chunked prefill
self . chunked_prefill_size = server_args . chunked_prefill_size
2024-11-29 16:03:32 -08:00
if self . chunked_prefill_size < = 0 : # -1 means disable
self . chunked_prefill_size = None
2025-03-03 00:12:04 -08:00
self . chunked_req = None
2024-09-29 17:42:45 -07:00
self . is_mixed_chunk = (
self . chunked_prefill_size is not None and server_args . enable_mixed_chunk
)
2024-11-12 21:17:38 -08:00
# Init the grammar backend for constrained generation
2024-11-13 01:45:28 +09:00
self . grammar_queue : List [ Req ] = [ ]
2024-10-25 10:24:44 -07:00
if not server_args . skip_tokenizer_init :
2025-01-19 17:10:29 -08:00
self . grammar_backend = create_grammar_backend (
server_args , self . tokenizer , self . model_config . vocab_size
)
2024-11-12 21:17:38 -08:00
else :
self . grammar_backend = None
2024-09-29 17:42:45 -07:00
2025-03-06 00:13:20 -08:00
# Init schedule policy and new token estimation
2025-03-12 11:22:35 -07:00
self . policy = SchedulePolicy (
2025-03-12 22:22:39 -07:00
self . schedule_policy ,
self . tree_cache ,
self . enable_hierarchical_cache ,
2025-03-12 11:22:35 -07:00
)
2024-10-25 10:24:44 -07:00
assert (
server_args . schedule_conservativeness > = 0
) , " Invalid schedule_conservativeness "
2024-10-26 16:39:41 -07:00
self . init_new_token_ratio = min (
global_config . default_init_new_token_ratio
2024-10-25 10:24:44 -07:00
* server_args . schedule_conservativeness ,
1.0 ,
2024-09-29 17:42:45 -07:00
)
2024-10-26 16:39:41 -07:00
self . min_new_token_ratio = min (
self . init_new_token_ratio
* global_config . default_min_new_token_ratio_factor ,
1.0 ,
)
self . new_token_ratio_decay = (
self . init_new_token_ratio - self . min_new_token_ratio
) / global_config . default_new_token_ratio_decay_steps
self . new_token_ratio = self . init_new_token_ratio
2024-10-27 02:00:50 -07:00
# Init watchdog thread
self . watchdog_timeout = server_args . watchdog_timeout
t = threading . Thread ( target = self . watchdog_thread , daemon = True )
t . start ( )
2024-11-28 00:22:39 -08:00
self . parent_process = psutil . Process ( ) . parent ( )
2024-10-27 02:00:50 -07:00
2025-03-03 00:12:04 -08:00
# Init memory saver
2025-01-14 03:38:51 +08:00
self . memory_saver_adapter = TorchMemorySaverAdapter . create (
enable = server_args . enable_memory_saver
)
2024-10-16 11:20:17 -07:00
# Init profiler
2025-03-03 00:12:04 -08:00
self . torch_profiler = None
self . torch_profiler_output_dir : Optional [ str ] = None
self . torch_profiler_activities : Optional [ List [ str ] ] = None
self . profiler_target_forward_ct : Optional [ int ] = None
2024-11-10 04:39:32 -08:00
2024-11-06 12:42:53 +08:00
# Init metrics stats
2025-03-06 00:13:20 -08:00
self . init_metrics ( )
2024-10-11 17:34:25 +08:00
2025-01-19 17:10:29 -08:00
# Init request dispatcher
self . _request_dispatcher = TypeBasedDispatcher (
2025-01-19 12:13:27 +08:00
[
( TokenizedGenerateReqInput , self . handle_generate_request ) ,
( TokenizedEmbeddingReqInput , self . handle_embedding_request ) ,
( FlushCacheReq , self . flush_cache_wrapped ) ,
( AbortReq , self . abort_request ) ,
2025-03-03 00:12:04 -08:00
( OpenSessionReqInput , self . open_session ) ,
( CloseSessionReqInput , self . close_session ) ,
2025-01-19 12:13:27 +08:00
( UpdateWeightFromDiskReqInput , self . update_weights_from_disk ) ,
( InitWeightsUpdateGroupReqInput , self . init_weights_update_group ) ,
(
UpdateWeightsFromDistributedReqInput ,
self . update_weights_from_distributed ,
) ,
( UpdateWeightsFromTensorReqInput , self . update_weights_from_tensor ) ,
( GetWeightsByNameReqInput , self . get_weights_by_name ) ,
2025-03-03 00:12:04 -08:00
( ReleaseMemoryOccupationReqInput , self . release_memory_occupation ) ,
( ResumeMemoryOccupationReqInput , self . resume_memory_occupation ) ,
2025-01-19 12:13:27 +08:00
( ProfileReq , self . profile ) ,
2025-03-03 00:12:04 -08:00
( GetInternalStateReq , self . get_internal_state ) ,
2025-03-06 00:13:20 -08:00
( SetInternalStateReq , self . set_internal_state ) ,
2025-03-14 15:40:44 +08:00
( RpcReqInput , self . handle_rpc_request ) ,
2025-03-24 21:34:19 -07:00
( ExpertDistributionReq , self . expert_distribution_handle ) ,
2025-01-19 12:13:27 +08:00
]
)
2025-03-21 14:47:47 -07:00
self . disaggregation_mode = DisaggregationMode (
self . server_args . disaggregation_mode
)
self . init_disaggregation ( )
2025-03-06 00:13:20 -08:00
def init_tokenizer ( self ) :
server_args = self . server_args
2024-10-27 02:00:50 -07:00
2025-03-06 00:13:20 -08:00
self . model_config = ModelConfig (
server_args . model_path ,
trust_remote_code = server_args . trust_remote_code ,
revision = server_args . revision ,
context_length = server_args . context_length ,
model_override_args = server_args . json_model_override_args ,
is_embedding = server_args . is_embedding ,
dtype = server_args . dtype ,
quantization = server_args . quantization ,
)
self . is_generation = self . model_config . is_generation
2025-03-03 00:12:04 -08:00
2025-03-06 00:13:20 -08:00
if server_args . skip_tokenizer_init :
self . tokenizer = self . processor = None
else :
if self . model_config . is_multimodal :
self . processor = get_processor (
server_args . tokenizer_path ,
tokenizer_mode = server_args . tokenizer_mode ,
trust_remote_code = server_args . trust_remote_code ,
revision = server_args . revision ,
)
self . tokenizer = self . processor . tokenizer
else :
self . tokenizer = get_tokenizer (
server_args . tokenizer_path ,
tokenizer_mode = server_args . tokenizer_mode ,
trust_remote_code = server_args . trust_remote_code ,
revision = server_args . revision ,
)
def init_memory_pool_and_cache ( self ) :
server_args = self . server_args
self . req_to_token_pool , self . token_to_kv_pool_allocator = (
self . tp_worker . get_memory_pool ( )
)
if (
server_args . chunked_prefill_size is not None
and server_args . disable_radix_cache
) :
self . tree_cache = ChunkCache (
req_to_token_pool = self . req_to_token_pool ,
token_to_kv_pool_allocator = self . token_to_kv_pool_allocator ,
)
else :
if self . enable_hierarchical_cache :
self . tree_cache = HiRadixCache (
req_to_token_pool = self . req_to_token_pool ,
token_to_kv_pool_allocator = self . token_to_kv_pool_allocator ,
2025-03-12 11:22:35 -07:00
tp_cache_group = self . tp_worker . get_tp_cpu_group ( ) ,
2025-03-13 15:50:49 -07:00
page_size = self . page_size ,
2025-03-17 17:45:00 -07:00
hicache_ratio = server_args . hicache_ratio ,
2025-03-06 00:13:20 -08:00
)
else :
self . tree_cache = RadixCache (
req_to_token_pool = self . req_to_token_pool ,
token_to_kv_pool_allocator = self . token_to_kv_pool_allocator ,
2025-03-12 22:22:39 -07:00
page_size = self . page_size ,
2025-03-06 00:13:20 -08:00
disable = server_args . disable_radix_cache ,
)
self . decode_mem_cache_buf_multiplier = (
1
if self . spec_algorithm . is_none ( )
else (
server_args . speculative_num_draft_tokens
+ (
server_args . speculative_eagle_topk
* server_args . speculative_num_steps
)
)
2025-03-03 00:12:04 -08:00
)
2025-03-06 00:13:20 -08:00
def init_metrics ( self ) :
# The largest prefill length of a single request
self . _largest_prefill_len : int = 0
# The largest context length (prefill + generation) of a single request
self . _largest_prefill_decode_len : int = 0
self . last_gen_throughput : float = 0.0
2025-03-12 22:22:39 -07:00
self . last_input_throughput : float = 0.0
2025-03-06 00:13:20 -08:00
self . step_time_dict = defaultdict ( list ) # Dict[batch size -> step time]
self . spec_num_total_accepted_tokens = 0
self . spec_num_total_forward_ct = 0
self . cum_spec_accept_length = 0
self . cum_spec_accept_count = 0
self . stats = SchedulerStats ( )
if self . enable_metrics :
engine_type = " unified "
self . metrics_collector = SchedulerMetricsCollector (
labels = {
" model_name " : self . server_args . served_model_name ,
" engine_type " : engine_type ,
} ,
)
2024-10-27 02:00:50 -07:00
2025-03-21 14:47:47 -07:00
def init_disaggregation ( self ) :
if (
self . disaggregation_mode == DisaggregationMode . DECODE
) : # *2 for the headroom.
buffer_size = ( self . req_to_token_pool . size ) * 2
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator (
buffer_size
)
aux_dtype = torch . int32
# A list of metadata buffers. The shape is (b, metadata_size) where
# b corresponds to a max running requests. The last shape * dtype.itemsize
# should be larger than 64 bytes to work with RDMA, so we pad it.
output_id_buffer = torch . zeros (
( buffer_size , 16 ) , dtype = aux_dtype , device = " cpu "
)
metadata_buffers = [ output_id_buffer ]
# The decode requests polling kv cache
self . disagg_decode_transfer_queue = DecodeTransferQueue (
gloo_group = self . tp_worker . get_attention_tp_cpu_group ( ) ,
req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator ,
metadata_buffers = metadata_buffers ,
)
# The decode requests pending for pre-allocation
self . disagg_decode_prealloc_queue = DecodePreallocQueue (
req_to_token_pool = self . req_to_token_pool ,
token_to_kv_pool_allocator = self . token_to_kv_pool_allocator ,
req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator ,
metadata_buffers = metadata_buffers ,
aux_dtype = aux_dtype ,
scheduler = self ,
transfer_queue = self . disagg_decode_transfer_queue ,
tree_cache = self . tree_cache ,
gloo_group = self . tp_worker . get_attention_tp_cpu_group ( ) ,
tp_rank = self . tp_rank ,
tp_size = self . tp_size ,
bootstrap_port = self . server_args . disaggregation_bootstrap_port ,
)
elif self . disaggregation_mode == DisaggregationMode . PREFILL :
# *2 for the headroom.
buffer_size = self . max_running_requests * 2
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator (
buffer_size
)
aux_dtype = torch . int32
# A list of metadata buffers. The shape is (b, metadata_size) where
# b corresponds to a max running requests. The last shape * dtype.itemsize
# should be larger than 64 bytes to work with RDMA, so we pad it.
output_id_buffer = torch . zeros (
( buffer_size , 16 ) , dtype = aux_dtype , device = " cpu "
)
metadata_buffers = [ output_id_buffer ]
self . disagg_prefill_pending_queue = PrefillBootstrapQueue (
token_to_kv_pool = self . token_to_kv_pool_allocator . get_kvcache ( ) ,
req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator ,
metadata_buffers = metadata_buffers ,
aux_dtype = aux_dtype ,
tp_rank = self . tp_rank ,
tp_size = self . tp_size ,
bootstrap_port = self . server_args . disaggregation_bootstrap_port ,
gloo_group = self . tp_worker . get_attention_tp_cpu_group ( ) ,
)
# The prefill requests that are in the middle of kv sending
self . disagg_prefill_infight_queue : List [ Req ] = [ ]
2025-03-17 13:54:16 +08:00
@DynamicGradMode ( )
2024-10-13 19:54:02 -07:00
def event_loop_normal ( self ) :
2024-11-19 15:04:43 -08:00
""" A normal scheduler loop. """
2024-09-29 02:36:12 -07:00
while True :
2024-10-06 03:24:04 -07:00
recv_reqs = self . recv_requests ( )
self . process_input_requests ( recv_reqs )
2024-09-29 02:36:12 -07:00
2024-10-12 21:35:30 -07:00
batch = self . get_next_batch_to_run ( )
2024-10-27 02:00:50 -07:00
self . cur_batch = batch
2024-10-12 21:35:30 -07:00
if batch :
result = self . run_batch ( batch )
self . process_batch_result ( batch , result )
2024-10-14 05:25:00 -07:00
else :
2025-03-12 22:22:39 -07:00
# When the server is idle, do self-check and re-init some states
2024-10-14 05:25:00 -07:00
self . check_memory ( )
2024-10-26 16:39:41 -07:00
self . new_token_ratio = self . init_new_token_ratio
2024-10-12 21:35:30 -07:00
self . last_batch = batch
2024-09-29 02:36:12 -07:00
2025-03-17 13:54:16 +08:00
@DynamicGradMode ( )
2024-10-16 01:33:20 -07:00
def event_loop_overlap ( self ) :
2024-10-19 23:19:26 -07:00
""" A scheduler loop that overlaps the CPU processing and GPU computation. """
2025-01-27 03:00:41 -08:00
self . result_queue = deque ( )
2024-10-16 01:33:20 -07:00
while True :
recv_reqs = self . recv_requests ( )
self . process_input_requests ( recv_reqs )
batch = self . get_next_batch_to_run ( )
self . cur_batch = batch
2024-12-29 00:45:57 -08:00
2024-10-16 01:33:20 -07:00
if batch :
result = self . run_batch ( batch )
2025-01-27 03:00:41 -08:00
self . result_queue . append ( ( batch . copy ( ) , result ) )
2024-10-16 01:33:20 -07:00
2024-11-19 15:04:43 -08:00
if self . last_batch is None :
2025-01-19 17:10:29 -08:00
# Create a dummy first batch to start the pipeline for overlap schedule.
2024-11-19 15:04:43 -08:00
# It is now used for triggering the sampling_info_done event.
tmp_batch = ScheduleBatch (
reqs = None ,
forward_mode = ForwardMode . DUMMY_FIRST ,
next_batch_sampling_info = self . tp_worker . cur_sampling_info ,
)
self . process_batch_result ( tmp_batch , None )
2024-10-16 01:33:20 -07:00
if self . last_batch :
2024-12-29 00:45:57 -08:00
# Process the results of the last batch
2025-01-27 03:00:41 -08:00
tmp_batch , tmp_result = self . result_queue . popleft ( )
2024-11-19 15:04:43 -08:00
tmp_batch . next_batch_sampling_info = (
self . tp_worker . cur_sampling_info if batch else None
)
2024-10-16 01:33:20 -07:00
self . process_batch_result ( tmp_batch , tmp_result )
elif batch is None :
2025-03-12 22:22:39 -07:00
# When the server is idle, do self-check and re-init some states
2024-10-16 01:33:20 -07:00
self . check_memory ( )
2024-10-26 16:39:41 -07:00
self . new_token_ratio = self . init_new_token_ratio
2024-10-16 01:33:20 -07:00
self . last_batch = batch
2025-03-21 14:47:47 -07:00
@torch.no_grad ( )
def event_loop_normal_disagg_prefill ( self ) :
""" A normal scheduler loop for prefill worker in disaggregation mode. """
while True :
recv_reqs = self . recv_requests ( )
self . process_input_requests ( recv_reqs )
self . waiting_queue . extend (
self . disagg_prefill_pending_queue . pop_bootstrapped ( )
)
self . process_prefill_chunk ( )
batch = self . get_new_batch_prefill ( )
self . cur_batch = batch
if batch :
result = self . run_batch ( batch )
self . process_batch_result_disagg_prefill ( batch , result )
if len ( self . disagg_prefill_infight_queue ) > 0 :
self . process_disagg_prefill_infight_queue ( )
if batch is None and len ( self . disagg_prefill_infight_queue ) == 0 :
self . check_memory ( )
self . new_token_ratio = self . init_new_token_ratio
self . last_batch = batch
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
# Otherwise, it hangs under high concurrency
self . running_batch . batch_is_full = False
@torch.no_grad ( )
def event_loop_normal_disagg_decode ( self ) :
""" A normal scheduler loop for decode worker in disaggregation mode. """
while True :
recv_reqs = self . recv_requests ( )
self . process_input_requests ( recv_reqs )
# polling and allocating kv cache
self . process_decode_queue ( )
batch = self . get_next_disagg_decode_batch_to_run ( )
self . cur_batch = batch
if batch :
# Generate fake extend output.
if batch . forward_mode . is_extend ( ) :
# Note: Logprobs should be handled on the prefill engine.
self . stream_output (
batch . reqs , [ False for _ in range ( len ( batch . reqs ) ) ]
)
else :
result = self . run_batch ( batch )
self . process_batch_result ( batch , result )
if batch is None and (
len ( self . disagg_decode_transfer_queue . queue )
+ len ( self . disagg_decode_prealloc_queue . queue )
== 0
) :
# When the server is idle, do self-check and re-init some states
self . check_memory ( )
self . new_token_ratio = self . init_new_token_ratio
self . last_batch = batch
2024-12-29 00:45:57 -08:00
def recv_requests ( self ) - > List [ Req ] :
""" Receive results at tp_rank = 0 and broadcast it to all other TP ranks. """
2025-01-16 11:15:00 -08:00
if self . attn_tp_rank == 0 :
2024-10-06 03:24:04 -07:00
recv_reqs = [ ]
2024-12-08 03:55:27 -08:00
while True :
try :
recv_req = self . recv_from_tokenizer . recv_pyobj ( zmq . NOBLOCK )
except zmq . ZMQError :
break
2024-12-08 04:08:04 -08:00
recv_reqs . append ( recv_req )
2025-03-14 15:40:44 +08:00
while True :
try :
recv_rpc = self . recv_from_rpc . recv_pyobj ( zmq . NOBLOCK )
except zmq . ZMQError :
break
recv_reqs . append ( recv_rpc )
2024-10-06 03:24:04 -07:00
else :
recv_reqs = None
2024-09-29 02:36:12 -07:00
2025-01-16 11:15:00 -08:00
if self . server_args . enable_dp_attention :
if self . attn_tp_rank == 0 :
work_reqs = [
req
for req in recv_reqs
if isinstance (
req , ( TokenizedGenerateReqInput , TokenizedEmbeddingReqInput )
)
]
control_reqs = [
req
for req in recv_reqs
if not isinstance (
req , ( TokenizedGenerateReqInput , TokenizedEmbeddingReqInput )
)
]
else :
work_reqs = None
control_reqs = None
if self . attn_tp_size != 1 :
attn_tp_rank_0 = self . dp_rank * self . attn_tp_size
work_reqs = broadcast_pyobj (
work_reqs ,
self . attn_tp_rank ,
self . attn_tp_cpu_group ,
src = attn_tp_rank_0 ,
)
if self . tp_size != 1 :
control_reqs = broadcast_pyobj (
control_reqs , self . tp_rank , self . tp_cpu_group
)
recv_reqs = work_reqs + control_reqs
elif self . tp_size != 1 :
2024-10-07 13:05:53 -07:00
recv_reqs = broadcast_pyobj ( recv_reqs , self . tp_rank , self . tp_cpu_group )
2024-09-29 02:36:12 -07:00
return recv_reqs
2024-10-06 03:24:04 -07:00
def process_input_requests ( self , recv_reqs : List ) :
2024-09-29 17:42:45 -07:00
for recv_req in recv_reqs :
2025-03-03 00:12:04 -08:00
# If it is a health check generation request and there are running requests, ignore it.
if is_health_check_generate_req ( recv_req ) and (
2025-03-12 22:22:39 -07:00
self . chunked_req is not None or not self . running_batch . is_empty ( )
2025-03-03 00:12:04 -08:00
) :
self . return_health_check_ct + = 1
continue
2025-01-19 17:10:29 -08:00
output = self . _request_dispatcher ( recv_req )
2025-01-19 12:13:27 +08:00
if output is not None :
2025-03-14 15:40:44 +08:00
if isinstance ( output , RpcReqOutput ) :
if self . recv_from_rpc is not None :
self . recv_from_rpc . send_pyobj ( output )
else :
self . send_to_tokenizer . send_pyobj ( output )
2024-09-29 17:42:45 -07:00
def handle_generate_request (
self ,
recv_req : TokenizedGenerateReqInput ,
) :
2024-11-29 03:15:58 -08:00
# Create a new request
2024-12-29 02:10:27 -08:00
if (
recv_req . session_params is None
or recv_req . session_params . id is None
or recv_req . session_params . id not in self . sessions
) :
2024-11-25 19:35:04 -05:00
if recv_req . input_embeds is not None :
# Generate fake input_ids based on the length of input_embeds
seq_length = len ( recv_req . input_embeds )
fake_input_ids = [ 1 ] * seq_length
recv_req . input_ids = fake_input_ids
2025-01-19 14:46:53 -08:00
# Handle custom logit processor passed to the request
custom_logit_processor = recv_req . custom_logit_processor
if (
not self . server_args . enable_custom_logit_processor
and custom_logit_processor is not None
) :
logger . warning (
" The SGLang server is not configured to enable custom logit processor. "
" The custom logit processor passed in will be ignored. "
" Please set --enable-custom-logits-processor to enable this feature. "
)
custom_logit_processor = None
2024-11-20 00:36:53 -08:00
req = Req (
recv_req . rid ,
recv_req . input_text ,
recv_req . input_ids ,
recv_req . sampling_params ,
2024-12-08 12:27:13 -08:00
return_logprob = recv_req . return_logprob ,
top_logprobs_num = recv_req . top_logprobs_num ,
2025-03-03 00:12:04 -08:00
token_ids_logprob = recv_req . token_ids_logprob ,
2024-12-08 12:27:13 -08:00
stream = recv_req . stream ,
2024-11-20 00:36:53 -08:00
lora_path = recv_req . lora_path ,
2024-11-25 19:35:04 -05:00
input_embeds = recv_req . input_embeds ,
2025-01-19 14:46:53 -08:00
custom_logit_processor = custom_logit_processor ,
2025-03-01 20:51:29 -05:00
return_hidden_states = recv_req . return_hidden_states ,
2024-12-27 11:23:46 -08:00
eos_token_ids = self . model_config . hf_eos_token_id ,
2024-11-20 00:36:53 -08:00
)
req . tokenizer = self . tokenizer
2024-11-25 16:38:43 -08:00
2024-12-29 02:10:27 -08:00
if (
recv_req . session_params is not None
and recv_req . session_params . id is not None
) :
2024-11-20 00:36:53 -08:00
req . finished_reason = FINISH_ABORT (
2024-12-29 02:10:27 -08:00
f " Invalid request: session id { recv_req . session_params . id } does not exist "
2024-11-20 00:36:53 -08:00
)
2025-03-03 00:12:04 -08:00
self . _add_request_to_queue ( req )
2024-11-20 00:36:53 -08:00
return
else :
2024-12-29 02:10:27 -08:00
# Create a new request from a previous session
session = self . sessions [ recv_req . session_params . id ]
2024-11-25 12:32:51 -08:00
req = session . create_req ( recv_req , self . tokenizer )
2024-11-20 00:36:53 -08:00
if isinstance ( req . finished_reason , FINISH_ABORT ) :
2025-03-03 00:12:04 -08:00
self . _add_request_to_queue ( req )
2024-11-20 00:36:53 -08:00
return
2024-09-29 17:42:45 -07:00
2025-01-27 03:00:41 -08:00
# Handle multimodal inputs
2025-03-25 11:08:40 +08:00
if recv_req . mm_inputs is not None :
image_inputs = MultimodalInputs . from_dict ( recv_req . mm_inputs )
2024-11-29 03:15:58 -08:00
# Expand a single image token into multiple dummy tokens for receiving image embeddings
2024-09-30 06:41:49 -07:00
req . origin_input_ids = self . pad_input_ids_func (
2024-11-27 00:03:29 -08:00
req . origin_input_ids , image_inputs
2024-09-29 17:42:45 -07:00
)
2024-11-29 03:15:58 -08:00
req . extend_image_inputs ( image_inputs )
2024-09-29 17:42:45 -07:00
2024-11-29 04:24:20 -08:00
if len ( req . origin_input_ids ) > = self . max_req_input_len :
2025-01-18 19:37:30 -08:00
error_msg = (
2024-11-29 04:24:20 -08:00
" Multimodal prompt is too long after expanding multimodal tokens. "
2025-01-18 19:37:30 -08:00
f " After expanding { len ( req . origin_input_ids_unpadded ) =} => { len ( req . origin_input_ids ) } >= { self . max_req_input_len } . "
2024-11-21 19:05:41 -08:00
)
2025-01-18 19:37:30 -08:00
logger . error ( error_msg )
2024-11-28 02:22:15 -08:00
req . origin_input_ids = [ 0 ]
2025-03-25 11:08:40 +08:00
req . multimodal_inputs = None
2024-11-21 19:05:41 -08:00
req . sampling_params . max_new_tokens = 0
2024-11-29 04:24:20 -08:00
req . finished_reason = FINISH_ABORT (
2025-01-18 19:37:30 -08:00
error_msg , HTTPStatus . BAD_REQUEST , " BadRequestError "
2024-11-29 04:24:20 -08:00
)
2025-03-03 00:12:04 -08:00
self . _add_request_to_queue ( req )
2024-11-21 19:05:41 -08:00
return
2025-01-16 14:51:19 -08:00
# Validate prompts length
error_msg = validate_input_length (
req ,
self . max_req_input_len ,
self . server_args . allow_auto_truncate ,
)
if error_msg :
2025-02-27 22:59:43 -08:00
req . origin_input_ids = [ 0 ]
req . sampling_params . max_new_tokens = 0
2025-03-03 00:12:04 -08:00
self . _add_request_to_queue ( req )
2025-01-16 14:51:19 -08:00
return
2024-10-21 16:12:04 -07:00
2025-01-26 01:39:28 -08:00
# Copy more attributes
2025-03-03 00:12:04 -08:00
if recv_req . logprob_start_len == - 1 or not recv_req . return_logprob :
2025-01-26 01:39:28 -08:00
# By default, only return the logprobs for output tokens
req . logprob_start_len = len ( req . origin_input_ids ) - 1
else :
req . logprob_start_len = recv_req . logprob_start_len
2025-03-03 00:12:04 -08:00
if req . logprob_start_len > = len ( req . origin_input_ids ) :
req . finished_reason = FINISH_ABORT (
f " logprob_start_len, ( { req . logprob_start_len } ) is higher than the number of input tokens ( { len ( req . origin_input_ids ) } ). Request with a lower logprob_start_len. " ,
HTTPStatus . BAD_REQUEST ,
" BadRequestError " ,
)
req . logprob_start_len = len ( req . origin_input_ids ) - 1
self . _add_request_to_queue ( req )
return
2024-09-29 17:42:45 -07:00
req . sampling_params . max_new_tokens = min (
(
req . sampling_params . max_new_tokens
if req . sampling_params . max_new_tokens is not None
else 1 << 30
) ,
2024-10-21 16:12:04 -07:00
self . max_req_len - len ( req . origin_input_ids ) - 1 ,
2024-09-29 17:42:45 -07:00
)
2024-11-13 01:49:45 -08:00
# Init grammar cache for this request
add_to_grammar_queue = False
if (
req . sampling_params . json_schema is not None
or req . sampling_params . regex is not None
2024-12-26 18:42:41 +05:30
or req . sampling_params . ebnf is not None
2025-02-28 15:33:41 +08:00
or req . sampling_params . structural_tag is not None
2024-11-13 01:49:45 -08:00
) :
assert self . grammar_backend is not None
if req . sampling_params . json_schema is not None :
key = ( " json " , req . sampling_params . json_schema )
elif req . sampling_params . regex is not None :
key = ( " regex " , req . sampling_params . regex )
2024-12-26 18:42:41 +05:30
elif req . sampling_params . ebnf is not None :
key = ( " ebnf " , req . sampling_params . ebnf )
2025-02-28 15:33:41 +08:00
elif req . sampling_params . structural_tag :
key = ( " structural_tag " , req . sampling_params . structural_tag )
2024-11-13 01:49:45 -08:00
req . grammar = self . grammar_backend . get_cached_value ( key )
if not req . grammar :
req . grammar = self . grammar_backend . get_future_value ( key )
add_to_grammar_queue = True
if add_to_grammar_queue :
2024-11-13 01:45:28 +09:00
self . grammar_queue . append ( req )
else :
2025-03-03 00:12:04 -08:00
self . _add_request_to_queue ( req )
def _add_request_to_queue ( self , req : Req ) :
2025-03-21 14:47:47 -07:00
if self . disaggregation_mode == DisaggregationMode . PREFILL :
self . disagg_prefill_pending_queue . add ( req )
2025-03-03 00:12:04 -08:00
2025-03-21 14:47:47 -07:00
elif self . disaggregation_mode == DisaggregationMode . DECODE :
self . disagg_decode_prealloc_queue . add ( req )
else :
self . waiting_queue . append ( req )
def _extend_requests_to_queue ( self , reqs : List [ Req ] , is_retracted : bool = False ) :
if self . disaggregation_mode == DisaggregationMode . DECODE :
self . disagg_decode_prealloc_queue . extend ( reqs )
else :
self . waiting_queue . extend ( reqs )
2024-09-29 17:42:45 -07:00
def handle_embedding_request (
self ,
2024-11-03 08:38:26 -08:00
recv_req : TokenizedEmbeddingReqInput ,
2024-09-29 17:42:45 -07:00
) :
req = Req (
recv_req . rid ,
recv_req . input_text ,
recv_req . input_ids ,
recv_req . sampling_params ,
)
req . tokenizer = self . tokenizer
2025-03-07 08:46:20 +08:00
# Handle multimodal inputs
if recv_req . image_inputs is not None :
2025-03-25 11:08:40 +08:00
image_inputs = MultimodalInputs . from_dict ( recv_req . image_inputs )
2025-03-07 08:46:20 +08:00
# Expand a single image token into multiple dummy tokens for receiving image embeddings
req . origin_input_ids = self . pad_input_ids_func (
req . origin_input_ids , image_inputs
)
req . extend_image_inputs ( image_inputs )
if len ( req . origin_input_ids ) > = self . max_req_input_len :
error_msg = (
" Multimodal prompt is too long after expanding multimodal tokens. "
f " After expanding { len ( req . origin_input_ids_unpadded ) =} => { len ( req . origin_input_ids ) } >= { self . max_req_input_len } . "
)
logger . error ( error_msg )
req . origin_input_ids = [ 0 ]
2025-03-25 11:08:40 +08:00
req . multimodal_inputs = None
2025-03-07 08:46:20 +08:00
req . sampling_params . max_new_tokens = 0
req . finished_reason = FINISH_ABORT (
error_msg , HTTPStatus . BAD_REQUEST , " BadRequestError "
)
self . waiting_queue . append ( req )
return
2025-01-16 14:51:19 -08:00
# Validate prompts length
2025-01-26 01:39:28 -08:00
error_msg = validate_input_length (
2025-01-16 14:51:19 -08:00
req ,
self . max_req_input_len ,
self . server_args . allow_auto_truncate ,
)
2025-01-26 01:39:28 -08:00
if error_msg :
2025-03-03 00:12:04 -08:00
self . _add_request_to_queue ( req )
2025-01-26 01:39:28 -08:00
return
2024-09-29 17:42:45 -07:00
2025-01-26 01:39:28 -08:00
# Copy more attributes
req . logprob_start_len = len ( req . origin_input_ids ) - 1
2025-03-03 00:12:04 -08:00
self . _add_request_to_queue ( req )
2024-09-29 17:42:45 -07:00
2025-01-27 03:00:41 -08:00
def log_prefill_stats (
self ,
adder : PrefillAdder ,
can_run_list : List [ Req ] ,
2025-03-03 00:12:04 -08:00
running_bs : int ,
2025-01-27 03:00:41 -08:00
) :
2025-03-12 22:22:39 -07:00
gap_latency = time . time ( ) - self . last_prefill_stats_tic
self . last_prefill_stats_tic = time . time ( )
self . last_input_throughput = self . num_prefill_tokens / gap_latency
self . num_prefill_tokens = 0
2024-11-10 04:39:32 -08:00
num_used = self . max_total_num_tokens - (
2025-03-05 08:06:07 -08:00
self . token_to_kv_pool_allocator . available_size ( )
+ self . tree_cache . evictable_size ( )
2024-11-10 04:39:32 -08:00
)
2025-03-03 00:12:04 -08:00
self . _largest_prefill_len = max (
self . _largest_prefill_len , adder . log_input_tokens
)
2024-11-10 04:39:32 -08:00
2025-03-03 00:12:04 -08:00
f = (
2024-11-10 04:39:32 -08:00
f " Prefill batch. "
f " #new-seq: { len ( can_run_list ) } , "
f " #new-token: { adder . log_input_tokens } , "
f " #cached-token: { adder . log_hit_tokens } , "
f " token usage: { num_used / self . max_total_num_tokens : .2f } , "
f " #running-req: { running_bs } , "
2025-03-03 00:12:04 -08:00
f " #queue-req: { len ( self . waiting_queue ) } , "
2024-11-10 04:39:32 -08:00
)
2025-03-03 00:12:04 -08:00
logger . info ( f )
2024-11-10 04:39:32 -08:00
if self . enable_metrics :
2025-03-03 00:12:04 -08:00
cache_hit_rate = adder . log_hit_tokens / (
adder . log_input_tokens + adder . log_hit_tokens
)
2024-11-10 04:39:32 -08:00
self . stats . num_running_reqs = running_bs
self . stats . num_used_tokens = num_used
self . stats . token_usage = round ( num_used / self . max_total_num_tokens , 2 )
2025-03-03 00:12:04 -08:00
self . stats . num_queue_reqs = len ( self . waiting_queue )
self . stats . cache_hit_rate = cache_hit_rate
2024-11-10 04:39:32 -08:00
self . metrics_collector . log_stats ( self . stats )
    def log_decode_stats(self):
        """Log decode-phase throughput/usage statistics and update metrics.

        Called periodically from the decode loop; resets the generated-token
        counter and (for speculative decoding) the acceptance counters.
        """
        # Generation throughput over the window since the previous call.
        gap_latency = time.time() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.time()
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(self.running_batch.reqs)
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            # Record average per-step latency bucketed by running batch size.
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        if self.spec_algorithm.is_none():
            msg = (
                f"Decode batch. "
                f"#running-req: {num_running_reqs}, "
                f"#token: {num_used}, "
                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
                f"#queue-req: {len(self.waiting_queue)}, "
            )
            spec_accept_length = 0
        else:
            # Speculative decoding: report average accepted tokens per forward
            # and fold the window counters into the cumulative totals.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg = (
                f"Decode batch. "
                f"#running-req: {num_running_reqs}, "
                f"#token: {num_used}, "
                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                f"accept len: {spec_accept_length:.2f}, "
                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
                f"#queue-req: {len(self.waiting_queue)}, "
            )
        logger.info(msg)

        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)
2024-10-06 03:24:04 -07:00
    def check_memory(self):
        """Self-check memory-pool invariants while the server is idle.

        When idle, every token slot should be free or evictable (minus the
        slots protected by the hierarchical cache), and every request slot
        should be free; any mismatch indicates a leak.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        memory_leak = available_size != (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            else self.max_total_num_tokens - protected_size
        )
        if memory_leak:
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            warnings.warn(msg)
            # crash_on_warnings() escalates the warning to a hard failure
            # (used in tests / strict deployments).
            if crash_on_warnings():
                raise ValueError(msg)

        # Request-slot pool: all slots must be free when idle.
        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.time() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)
2024-12-11 12:51:50 -08:00
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Pick the batch to run this iteration: prefill if possible, else decode.

        Merges the previous extend batch into `running_batch` first, handling
        the in-flight chunked-prefill request specially.
        """
        # Merge the prefill batch into the running batch
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.chunked_req:
                # Move the chunked request out of the batch so that we can merge
                # only finished requests to running_batch.
                self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
                self.tree_cache.cache_unfinished_req(self.chunked_req)
                # chunked request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
                self.running_batch.batch_is_full = False

            # Filter batch
            last_bs = self.last_batch.batch_size()
            self.last_batch.filter_batch()
            if self.last_batch.batch_size() < last_bs:
                # Some requests finished, so the batch can admit more again.
                self.running_batch.batch_is_full = False

            # Merge the new batch into the running batch
            if not self.last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = self.last_batch
                else:
                    # Merge running_batch with prefill batch
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
            if not self.running_batch.is_empty():
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None
            else:
                ret = None

        # Handle DP attention
        if self.server_args.enable_dp_attention:
            ret, _ = self.prepare_dp_attn_batch(ret)

        return ret
2024-10-07 13:05:53 -07:00
2024-10-06 03:24:04 -07:00
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build a new prefill (extend) batch from the waiting queue, or return None.

        Admission is governed by `PrefillAdder` (token budget / chunked prefill)
        plus limits on running requests and LoRA adapters per batch.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        if running_bs >= self.max_running_requests:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            # Mixed chunking reserves decode budget for the running requests.
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            # Continue the in-flight chunked request before admitting new ones.
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = set([req.lora_path for req in self.running_batch.reqs])

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                # Admitting this request would exceed the LoRA adapter limit.
                self.running_batch.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.running_batch.batch_is_full = True
                break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (
                            self.running_batch is not None
                            and not self.running_batch.is_empty()
                        )
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

        if self.enable_hierarchical_cache:
            self.tree_cache.read_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            # Track how many rounds this request has been chunked.
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            # The running requests were folded into new_batch; start fresh but
            # keep the batch_is_full flag.
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch
2024-11-24 04:47:10 -08:00
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch."""
        initial_bs = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        # Check if decode out of memory
        # (TEST_RETRACT forces retraction on batches > 10 for testing.)
        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
            TEST_RETRACT and batch.batch_size() > 10
        ) :
            old_ratio = self.new_token_ratio
            # Retract some requests back to the waiting queue and raise the
            # new-token ratio so admission becomes more conservative.
            retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
            self.new_token_ratio = new_token_ratio

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted_reqs)
        else:
            # No memory pressure: decay the ratio back toward its minimum.
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if batch.batch_size() < initial_bs:
            # Requests left the batch, so it can admit more again.
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
2024-10-06 03:24:04 -07:00
2025-01-16 12:51:11 -08:00
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch."""
        self.forward_ct += 1

        # Check profiler
        if (
            self.profiler_target_forward_ct
            and self.profiler_target_forward_ct <= self.forward_ct
        ):
            self.stop_profile()

        # Run forward
        if self.is_generation:
            if self.spec_algorithm.is_none():
                # Normal autoregressive generation.
                model_worker_batch = batch.get_model_worker_batch()
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
                bid = model_worker_batch.bid
            else:
                # Speculative decoding path: the draft worker runs both draft
                # and verification and reports how many tokens were accepted.
                (
                    logits_output,
                    next_token_ids,
                    bid,
                    num_accepted_tokens,
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
                self.spec_num_total_accepted_tokens += (
                    num_accepted_tokens + batch.batch_size()
                )
                self.spec_num_total_forward_ct += batch.batch_size()
                self.num_generated_tokens += num_accepted_tokens
            batch.output_ids = next_token_ids

            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
            if batch.return_logprob:
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_input_len_per_req = None
                extend_logprob_start_len_per_req = None

            ret = GenerationBatchResult(
                logits_output=logits_output,
                next_token_ids=next_token_ids,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                bid=bid,
            )
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )
        return ret
2024-11-07 15:42:47 -08:00
2025-01-16 12:51:11 -08:00
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Dispatch a finished batch's result to the handler for its forward mode."""
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif batch.forward_mode.is_idle():
            if self.enable_overlap:
                # Overlap schedule: wait for the async result before touching
                # the sampling info of the next batch.
                self.tp_worker.resolve_batch_result(result.bid)
                if batch.next_batch_sampling_info:
                    batch.next_batch_sampling_info.update_regex_vocab_mask()
                    # Ensure the mask update is visible before signaling done.
                    self.current_stream.synchronize()
                    batch.next_batch_sampling_info.sampling_info_done.set()
        elif batch.forward_mode.is_dummy_first():
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

        if self.return_health_check_ct:
            # Return some signal for the health check.
            # This is used to prevent the health check signal being blocked by long context prefill.
            # However, one minor issue is that this code path does not check the status of detokenizer manager.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
2024-12-06 05:49:29 -08:00
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
    """Coordinate this DP-attention rank's batch with all other DP ranks.

    All-gathers four scalars per rank (token count, cuda-graph eligibility,
    logprob token count, is-extend flag) so every rank can (a) substitute an
    idle batch when any peer still has work, and (b) agree on whether CUDA
    graph replay is possible. Returns ``(local_batch, any_rank_is_extending)``.
    """
    # Check if other DP workers have running batches
    if local_batch is None:
        num_tokens = 0
        global_num_tokens_for_logprob = 0
    elif local_batch.forward_mode.is_decode():
        num_tokens = local_batch.batch_size()
        if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
            # Each decode step verifies multiple draft tokens under EAGLE.
            num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
        global_num_tokens_for_logprob = num_tokens
    else:
        num_tokens = local_batch.extend_num_tokens
        global_num_tokens_for_logprob = sum(
            [
                # We should have at least 1 token for sample in every case.
                max(extend_len - logprob_start_len, 1)
                for logprob_start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            ]
        )

    # CUDA graph replay is only considered for decode/idle batches.
    if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
        can_cuda_graph = 1
    else:
        can_cuda_graph = 0

    if not self.spec_algorithm.is_none():
        # TODO(sang): Support cuda graph when idle batch is there.
        if local_batch is None or local_batch.forward_mode.is_idle():
            can_cuda_graph = 0

    is_extend_in_batch = (
        local_batch.forward_mode.is_extend() if local_batch else False
    )
    # Pack the four local scalars and gather them from every rank.
    local_info = torch.tensor(
        [
            num_tokens,
            can_cuda_graph,
            global_num_tokens_for_logprob,
            is_extend_in_batch,
        ],
        dtype=torch.int64,
    )
    global_info = torch.empty(
        (self.server_args.dp_size, self.attn_tp_size, 4),
        dtype=torch.int64,
    )
    torch.distributed.all_gather_into_tensor(
        global_info.flatten(),
        local_info,
        group=self.tp_cpu_group,
    )
    # Use attention-TP rank 0's values for each DP rank.
    global_num_tokens = global_info[:, 0, 0].tolist()
    can_cuda_graph = min(global_info[:, 0, 1].tolist())
    global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
    is_extend_in_batch = global_info[:, 0, 3].tolist()

    # If any peer has tokens, an idle rank must still join the collective
    # forward pass with an empty batch.
    if local_batch is None and max(global_num_tokens) > 0:
        local_batch = self.get_idle_batch()

    if local_batch is not None:
        local_batch.global_num_tokens = global_num_tokens
        local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

        # Check forward mode for cuda graph
        if not self.server_args.disable_cuda_graph:
            local_batch.can_run_dp_cuda_graph = can_cuda_graph

    return local_batch, any(is_extend_in_batch)
2024-12-06 05:49:29 -08:00
def get_idle_batch(self):
    """Build an empty ScheduleBatch used to join collective ops while idle."""
    batch = ScheduleBatch.init_new(
        [],
        self.req_to_token_pool,
        self.token_to_kv_pool_allocator,
        self.tree_cache,
        self.model_config,
        self.enable_overlap,
        self.spec_algorithm,
        self.server_args.enable_custom_logit_processor,
    )
    batch.prepare_for_idle()
    return batch
2024-11-13 01:49:45 -08:00
def move_ready_grammar_requests(self):
    """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

    Polls each queued future briefly and stops at the first one still
    compiling, preserving FIFO order. When TP > 1, ranks sync on the MAX
    ready count so every rank dequeues the same requests.
    """
    num_ready_reqs = 0
    for req in self.grammar_queue:
        try:
            req.grammar = req.grammar.result(timeout=0.05)
            num_ready_reqs += 1
        # Fix: use the public `futures.TimeoutError` alias instead of the
        # private `futures._base.TimeoutError` (same class, supported API).
        except futures.TimeoutError:
            break

    if self.server_args.enable_dp_attention:
        tp_size = self.attn_tp_size
        tp_group = self.attn_tp_cpu_group
    else:
        tp_size = self.tp_size
        tp_group = self.tp_cpu_group

    if tp_size > 1:
        # Sync across TP ranks to make sure they have the same number of ready requests
        tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
        torch.distributed.all_reduce(
            tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
        )
        num_ready_reqs_max = tensor.item()
        # Block until this rank catches up with the fastest rank.
        for i in range(num_ready_reqs, num_ready_reqs_max):
            self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
        num_ready_reqs = num_ready_reqs_max

    self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
    self.grammar_queue = self.grammar_queue[num_ready_reqs:]
2025-03-06 00:13:20 -08:00
def watchdog_thread(self):
    """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
    self.watchdog_last_forward_ct = 0
    self.watchdog_last_time = time.time()
    while True:
        current = time.time()
        if self.cur_batch is not None:
            if self.watchdog_last_forward_ct == self.forward_ct:
                # No forward progress since the last check; trip once the
                # stall exceeds the configured timeout.
                if current > self.watchdog_last_time + self.watchdog_timeout:
                    logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                    break
            else:
                self.watchdog_last_forward_ct = self.forward_ct
                self.watchdog_last_time = current
        time.sleep(self.watchdog_timeout // 2)
    # Print batch size and memory pool info to check whether there are de-sync issues.
    logger.error(
        f"{self.cur_batch.batch_size()=}, "
        f"{self.cur_batch.reqs=}, "
        f"{self.token_to_kv_pool_allocator.available_size()=}, "
        f"{self.tree_cache.evictable_size()=}, "
    )
    # Wait for some time so that the parent process can print the error.
    pyspy_dump_schedulers()
    print(file=sys.stderr, flush=True)
    print(file=sys.stdout, flush=True)
    time.sleep(5)
    # Ask the parent to tear the whole server down.
    self.parent_process.send_signal(signal.SIGQUIT)
2025-01-19 12:13:27 +08:00
def flush_cache_wrapped(self, recv_req: FlushCacheReq):
    """Serve a FlushCacheReq by delegating to flush_cache()."""
    self.flush_cache()
2024-09-29 17:42:45 -07:00
def flush_cache ( self ) :
2024-10-19 23:19:26 -07:00
""" Flush the memory pool and cache. """
2025-03-12 22:22:39 -07:00
if len ( self . waiting_queue ) == 0 and self . running_batch . is_empty ( ) :
2025-03-03 00:12:04 -08:00
self . cur_batch = None
self . last_batch = None
2024-09-29 17:42:45 -07:00
self . tree_cache . reset ( )
2024-11-13 01:49:45 -08:00
if self . grammar_backend :
2024-11-12 21:17:38 -08:00
self . grammar_backend . reset ( )
2024-09-29 17:42:45 -07:00
self . req_to_token_pool . clear ( )
2025-03-05 08:06:07 -08:00
self . token_to_kv_pool_allocator . clear ( )
2025-01-20 20:25:13 -08:00
if not self . spec_algorithm . is_none ( ) :
self . draft_worker . model_runner . req_to_token_pool . clear ( )
2025-03-05 08:06:07 -08:00
self . draft_worker . model_runner . token_to_kv_pool_allocator . clear ( )
2025-01-20 20:25:13 -08:00
self . num_generated_tokens = 0
self . forward_ct_decode = 0
self . spec_num_total_accepted_tokens = 0
self . spec_num_total_forward_ct = 0
2025-03-03 00:12:04 -08:00
self . cum_spec_accept_length = 0
self . cum_spec_accept_count = 0
2024-09-29 17:42:45 -07:00
torch . cuda . empty_cache ( )
logger . info ( " Cache flushed successfully! " )
if_success = True
else :
logging . warning (
f " Cache not flushed because there are pending requests. "
f " #queue-req: { len ( self . waiting_queue ) } , "
2025-03-12 22:22:39 -07:00
f " #running-req: { len ( self . running_batch . reqs ) } "
2024-09-29 17:42:45 -07:00
)
if_success = False
return if_success
2025-03-03 00:12:04 -08:00
def get_internal_state(self, recv_req: GetInternalStateReq):
    """Report the global server args plus runtime statistics."""
    state = dict(global_server_args_dict)
    state["last_gen_throughput"] = self.last_gen_throughput
    has_spec_stats = (
        not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0
    )
    if has_spec_stats:
        state["avg_spec_accept_length"] = (
            self.cum_spec_accept_length / self.cum_spec_accept_count
        )
    if RECORD_STEP_TIME:
        state["step_time_dict"] = self.step_time_dict
    return GetInternalStateReqOutput(internal_state=state)
def set_internal_state(self, recv_req: SetInternalStateReq):
    """Update a whitelisted subset of global server args at runtime.

    Returns a SetInternalStateReqOutput whose `updated` flag reflects
    whether the request was actually applied.
    """
    server_args_dict = recv_req.server_args
    args_allow_update = set(
        [
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        ]
    )
    if_success = True
    for k, v in server_args_dict.items():
        if k not in args_allow_update:
            # Fix: use the module-level `logger` instead of the root logger.
            logger.warning(f"Updating {k} is not supported.")
            if_success = False
            break
    if if_success:
        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
            # Log and reset speculative-decoding acceptance stats, since the
            # thresholds that produced them are about to change.
            avg_spec_accept_length = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
            logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
        for k, v in server_args_dict.items():
            global_server_args_dict[k] = v
        logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
    # Bug fix: previously reported updated=True even when the request was
    # rejected; report the real outcome.
    return SetInternalStateReqOutput(
        updated=if_success,
        server_args=global_server_args_dict,
    )
2025-03-14 15:40:44 +08:00
def handle_rpc_request(self, recv_req: RpcReqInput):
    """Dispatch a generic RPC request to the scheduler method it names."""
    # Handle RPC requests
    logger.info(
        f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
    )

    success = True
    exc = None  # fix: renamed from `exec`, which shadowed the builtin
    try:
        func = getattr(self, recv_req.method)
        func(recv_req.parameters)
    except Exception as e:
        success = False
        exc = e
        logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

    # Keep all TP ranks in lockstep after handling the RPC.
    barrier()
    return RpcReqOutput(success, "" if not exc else str(exc))
def save_remote_model(self, params):
    """Persist the current model weights to the remote URL in params."""
    # Unwrap the overlap-scheduling client to reach the real worker.
    worker = (
        self.tp_worker.worker
        if isinstance(self.tp_worker, TpModelWorkerClient)
        else self.tp_worker
    )
    worker.model_runner.save_remote_model(params["url"])
def save_sharded_model(self, params):
    """Write the model weights to disk as shards described by params."""
    # Unwrap the overlap-scheduling client to reach the real worker.
    worker = (
        self.tp_worker.worker
        if isinstance(self.tp_worker, TpModelWorkerClient)
        else self.tp_worker
    )
    worker.model_runner.save_sharded_model(
        path=params["path"],
        pattern=params["pattern"],
        max_size=params["max_size"],
    )
2024-09-29 17:42:45 -07:00
def abort_request(self, recv_req: AbortReq):
    """Abort every request whose rid starts with recv_req.rid.

    Queued matches are removed from the waiting queue; if none were queued,
    a matching running request is flagged with `to_abort` so the scheduler
    finishes it early.
    """
    # Delete requests in the waiting queue.
    # Bug fix: the old code `break`-ed after the first match, which made the
    # collect-then-delete-in-reverse machinery dead code and missed sibling
    # requests sharing the rid prefix; it also `return`-ed inside the
    # deletion loop after a single pop.
    to_del = [
        i
        for i, req in enumerate(self.waiting_queue)
        if req.rid.startswith(recv_req.rid)
    ]

    # Sort in reverse order to avoid index issues when deleting
    for i in sorted(to_del, reverse=True):
        req = self.waiting_queue.pop(i)
        logger.debug(f"Abort queued request. {req.rid=}")
    if to_del:
        return

    # Delete requests in the running batch
    for req in self.running_batch.reqs:
        if req.rid.startswith(recv_req.rid) and not req.finished():
            logger.debug(f"Abort running request. {req.rid=}")
            req.to_abort = True
            return
2024-09-29 17:42:45 -07:00
2025-03-06 00:13:20 -08:00
def _pause_engine(self) -> Tuple[List[Req], int]:
    """Hook for subclasses; pausing is not supported on the base scheduler."""
    raise NotImplementedError()
2024-11-29 17:17:00 -08:00
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
    """In-place update of the weights from disk."""
    success, message = self.tp_worker.update_weights_from_disk(recv_req)
    if not success:
        logger.error(message)
    else:
        # Cached KV entries were computed with the old weights; drop them.
        flash_cache_success = self.flush_cache()
        assert flash_cache_success, "Cache flush failed after updating weights"
    return UpdateWeightFromDiskReqOutput(success, message, 0)
2024-09-29 17:42:45 -07:00
2024-12-01 23:23:18 -08:00
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
    """Initialize the online model parameter update group."""
    outcome = self.tp_worker.init_weights_update_group(recv_req)
    success, message = outcome
    return InitWeightsUpdateGroupReqOutput(success, message)
2024-12-01 23:23:18 -08:00
def update_weights_from_distributed(
    self,
    recv_req: UpdateWeightsFromDistributedReqInput,
) -> Tuple[bool, str]:
    """Update the online model parameter."""
    success, message = self.tp_worker.update_weights_from_distributed(recv_req)
    if not success:
        logger.error(message)
    else:
        # Cached KV entries were computed with the old weights; drop them.
        flash_cache_success = self.flush_cache()
        assert flash_cache_success, "Cache flush failed after updating weights"
    return UpdateWeightsFromDistributedReqOutput(success, message)
2024-12-01 23:23:18 -08:00
2024-12-29 05:30:27 +08:00
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Update the online model parameter from tensors."""
    success, message = self.tp_worker.update_weights_from_tensor(recv_req)
    # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
    if not success:
        logger.error(message)
    elif recv_req.flush_cache:
        flash_cache_success = self.flush_cache()
        assert flash_cache_success, "Cache flush failed after updating weights"
    return UpdateWeightsFromTensorReqOutput(success, message)
2024-12-29 05:30:27 +08:00
2024-11-29 23:36:38 -08:00
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Fetch a named parameter from the TP worker and wrap it in a response."""
    return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
2024-11-29 23:36:38 -08:00
2025-03-03 00:12:04 -08:00
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
    """Pause GPU memory usage after stashing the model's static state."""
    self.memory_saver_adapter.check_validity(
        caller_name="release_memory_occupation"
    )
    # Snapshot buffers so they can be restored on resume.
    self.stashed_model_static_state = _export_static_state(
        self.tp_worker.worker.model_runner.model
    )
    self.memory_saver_adapter.pause()
    self.flush_cache()
    return ReleaseMemoryOccupationReqOutput()
2025-01-14 03:38:51 +08:00
2025-03-03 00:12:04 -08:00
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
    """Re-enable GPU memory and restore the stashed static model state."""
    self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
    self.memory_saver_adapter.resume()
    _import_static_state(
        self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
    )
    # Drop the snapshot so its memory can be reclaimed.
    del self.stashed_model_static_state
    return ResumeMemoryOccupationReqOutput()
def profile(self, recv_req: ProfileReq):
    """Route a profiling request to start_profile or stop_profile."""
    if recv_req.type != ProfileReqType.START_PROFILE:
        return self.stop_profile()
    return self.start_profile(
        recv_req.output_dir, recv_req.num_steps, recv_req.activities
    )
def start_profile(
    self,
    output_dir: Optional[str],
    num_steps: Optional[int],
    activities: Optional[List[str]],
):
    """Start a torch profiler session.

    Args:
        output_dir: Directory for traces; defaults to $SGLANG_TORCH_PROFILER_DIR or /tmp.
        num_steps: If given, profiling auto-stops after this many forward steps.
        activities: Subset of {"CPU", "GPU", "MEM"}; defaults to ["CPU", "GPU"].

    Returns:
        ProfileReqOutput describing success or failure.
        (Fix: the previous ``-> None`` annotation was wrong.)
    """
    if self.torch_profiler_activities:
        return ProfileReqOutput(
            success=False,
            message="Profiling is already in progress. Call /stop_profile first.",
        )

    if output_dir is None:
        output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
    if activities is None:
        activities = ["CPU", "GPU"]

    self.torch_profiler_output_dir = output_dir
    self.torch_profiler_activities = activities
    logger.info(
        "Profiling starts. Traces will be saved to: %s",
        self.torch_profiler_output_dir,
    )

    activity_map = {
        "CPU": torch.profiler.ProfilerActivity.CPU,
        "GPU": torch.profiler.ProfilerActivity.CUDA,
    }
    torchprof_activities = [
        activity_map[a] for a in activities if a in activity_map
    ]

    if torchprof_activities:
        self.torch_profiler = torch.profiler.profile(
            activities=torchprof_activities,
            with_stack=True,
        )
        self.torch_profiler.start()

    if "MEM" in activities:
        # Track allocation history so stop_profile can dump a memory snapshot.
        torch.cuda.memory._record_memory_history(max_entries=100000)

    if num_steps:
        self.profiler_target_forward_ct = self.forward_ct + num_steps
        # The caller will be notified when reaching profiler_target_forward_ct
    else:
        self.profiler_target_forward_ct = None

    return ProfileReqOutput(success=True, message="Succeeded")
2024-10-11 17:34:25 +08:00
def stop_profile ( self ) - > None :
2025-03-03 00:12:04 -08:00
if self . torch_profiler_activities is None :
return
logger . info ( " Stop profiling... " )
if self . torch_profiler is not None :
self . torch_profiler . stop ( )
self . torch_profiler . export_chrome_trace (
os . path . join (
self . torch_profiler_output_dir ,
str ( time . time ( ) ) + f " -TP- { self . tp_rank } " + " .trace.json.gz " ,
)
)
if " MEM " in self . torch_profiler_activities :
memory_profile_path = os . path . join (
self . torch_profiler_trace_dir ,
str ( time . time ( ) ) + f " -TP- { self . tp_rank } -memory " + " .pickle " ,
)
torch . cuda . memory . _dump_snapshot ( memory_profile_path )
torch . cuda . memory . _record_memory_history ( enabled = None )
logger . info (
" Profiling done. Traces are saved to: %s " ,
self . torch_profiler_output_dir ,
2024-10-11 17:34:25 +08:00
)
2025-03-03 00:12:04 -08:00
self . torch_profiler = None
self . torch_profiler_output_dir = None
self . torch_profiler_activities = None
if self . profiler_target_forward_ct :
self . send_to_tokenizer . send_pyobj (
ProfileReqOutput ( success = True , message = " Succeeded. " )
)
2024-10-11 17:34:25 +08:00
2025-03-24 21:34:19 -07:00
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
    """Control expert-distribution recording (start/stop/dump)."""
    actions = {
        ExpertDistributionReq.START_RECORD: expert_distribution_recorder.start_record,
        ExpertDistributionReq.STOP_RECORD: expert_distribution_recorder.stop_record,
        ExpertDistributionReq.DUMP_RECORD: expert_distribution_recorder.dump_record,
    }
    try:
        action = actions[recv_req]
    except KeyError:
        raise ValueError("Unrecognized ExpertDistributionReq value") from None
    action()
    return ExpertDistributionReqOutput()
2025-03-24 21:34:19 -07:00
2025-01-19 12:13:27 +08:00
def open_session(self, recv_req: OpenSessionReqInput):
    """Create a new session; reply with (session_id, success)."""
    session_id = recv_req.session_id

    # Guard clauses: duplicate or missing session ids are rejected.
    if session_id in self.sessions:
        logger.warning(f"session id {session_id} already exist, cannot open.")
        return OpenSessionReqOutput(session_id, False)
    if session_id is None:
        logger.warning("session id is None, cannot open.")
        return OpenSessionReqOutput(session_id, False)

    self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
    return OpenSessionReqOutput(session_id, True)
2024-11-20 00:36:53 -08:00
def close_session(self, recv_req: CloseSessionReqInput):
    """Delete an existing session, warning if the id is unknown."""
    session_id = recv_req.session_id
    if session_id in self.sessions:
        del self.sessions[session_id]
    else:
        logger.warning(f"session id {session_id} does not exist, cannot delete.")
2024-09-29 02:36:12 -07:00
2025-03-03 00:12:04 -08:00
def is_health_check_generate_req(recv_req):
    """Return True if recv_req is a synthetic health-check generation request."""
    rid = getattr(recv_req, "rid", "")
    return rid.startswith("HEALTH_CHECK")
2025-01-14 03:38:51 +08:00
def _export_static_state ( model ) :
return dict (
buffers = [
( name , buffer . detach ( ) . clone ( ) ) for name , buffer in model . named_buffers ( )
]
)
def _import_static_state ( model , static_params ) :
self_named_buffers = dict ( model . named_buffers ( ) )
for name , tensor in static_params [ " buffers " ] :
self_named_buffers [ name ] [ . . . ] = tensor
2024-09-29 02:36:12 -07:00
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Configures the process (title, fault handler, logger, CPU affinity),
    builds a Scheduler for (gpu_id, tp_rank, dp_rank), reports readiness
    back to the parent through `pipe_writer`, then runs the appropriate
    event loop forever. On any exception, signals the parent with SIGQUIT.
    """
    # Generate the prefix
    if dp_rank is None:
        prefix = f" TP{tp_rank}"
    else:
        prefix = f" DP{dp_rank} TP{tp_rank}"

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        # Tell the parent this scheduler is ready and what limits apply.
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        # Pick the event loop variant based on the disaggregation mode.
        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
        if disaggregation_mode == DisaggregationMode.NULL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            scheduler.event_loop_normal_disagg_decode()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)