2024-11-22 22:16:53 +08:00
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
2024-09-29 02:36:12 -07:00
""" A scheduler that manages a tensor parallel GPU worker. """
2025-01-13 01:39:14 -08:00
import faulthandler
2024-09-29 02:36:12 -07:00
import logging
2024-09-29 17:42:45 -07:00
import os
2024-11-28 00:22:39 -08:00
import signal
2025-03-03 00:12:04 -08:00
import sys
2024-10-27 02:00:50 -07:00
import threading
2024-09-29 17:42:45 -07:00
import time
2025-03-03 00:12:04 -08:00
from collections import defaultdict , deque
2024-11-12 21:17:38 -08:00
from concurrent import futures
2025-01-16 12:51:11 -08:00
from dataclasses import dataclass
2025-01-18 19:37:30 -08:00
from http import HTTPStatus
2025-06-05 15:07:03 +08:00
from pathlib import Path
2024-10-12 21:35:30 -07:00
from types import SimpleNamespace
2025-01-16 12:51:11 -08:00
from typing import Dict , List , Optional , Tuple , Union
2024-09-29 02:36:12 -07:00
2024-11-28 00:22:39 -08:00
import psutil
2024-12-08 01:06:15 -08:00
import setproctitle
2024-09-29 17:42:45 -07:00
import torch
2024-09-29 02:36:12 -07:00
import zmq
2025-03-14 15:40:44 +08:00
from torch . distributed import barrier
2024-09-29 02:36:12 -07:00
2024-09-29 17:42:45 -07:00
from sglang . global_config import global_config
2024-11-24 04:47:10 -08:00
from sglang . srt . configs . model_config import ModelConfig
2025-06-19 00:56:37 -07:00
from sglang . srt . constants import GPU_MEMORY_TYPE_KV_CACHE , GPU_MEMORY_TYPE_WEIGHTS
2025-06-01 19:00:07 -07:00
from sglang . srt . constrained . base_grammar_backend import (
INVALID_GRAMMAR_OBJ ,
create_grammar_backend ,
)
2025-03-21 14:47:47 -07:00
from sglang . srt . disaggregation . decode import (
DecodePreallocQueue ,
DecodeTransferQueue ,
SchedulerDisaggregationDecodeMixin ,
)
2025-05-19 14:19:54 -07:00
from sglang . srt . disaggregation . kv_events import EventPublisherFactory , KVEventBatch
2025-03-21 14:47:47 -07:00
from sglang . srt . disaggregation . prefill import (
PrefillBootstrapQueue ,
SchedulerDisaggregationPrefillMixin ,
)
from sglang . srt . disaggregation . utils import (
DisaggregationMode ,
2025-05-23 14:29:20 -07:00
MetadataBuffers ,
2025-03-21 14:47:47 -07:00
ReqToMetadataIdxAllocator ,
2025-04-13 10:39:39 -07:00
TransferBackend ,
2025-05-21 21:44:25 -07:00
prepare_abort ,
2025-03-21 14:47:47 -07:00
)
2025-04-30 18:18:07 -07:00
from sglang . srt . distributed import get_pp_group , get_world_group
2025-05-02 13:38:59 +08:00
from sglang . srt . hf_transformers_utils import (
get_processor ,
get_tokenizer ,
get_tokenizer_from_processor ,
)
2025-01-16 11:15:00 -08:00
from sglang . srt . layers . dp_attention import compute_dp_attention_world_info
2024-09-29 17:42:45 -07:00
from sglang . srt . layers . logits_processor import LogitsProcessorOutput
2025-05-20 11:07:43 +08:00
from sglang . srt . managers . expert_distribution import (
get_global_expert_distribution_recorder ,
)
2024-09-29 17:42:45 -07:00
from sglang . srt . managers . io_struct import (
AbortReq ,
2024-11-20 00:36:53 -08:00
CloseSessionReqInput ,
2025-03-24 21:34:19 -07:00
ExpertDistributionReq ,
2025-03-25 16:17:03 +08:00
ExpertDistributionReqOutput ,
2025-04-21 09:15:03 +08:00
FlushCacheReqInput ,
FlushCacheReqOutput ,
2025-03-03 00:12:04 -08:00
GetInternalStateReq ,
GetInternalStateReqOutput ,
2024-11-29 23:36:38 -08:00
GetWeightsByNameReqInput ,
GetWeightsByNameReqOutput ,
2025-03-03 00:12:04 -08:00
HealthCheckOutput ,
2024-12-01 23:23:18 -08:00
InitWeightsUpdateGroupReqInput ,
InitWeightsUpdateGroupReqOutput ,
2024-11-20 00:36:53 -08:00
OpenSessionReqInput ,
OpenSessionReqOutput ,
2024-10-11 17:34:25 +08:00
ProfileReq ,
2025-03-03 00:12:04 -08:00
ProfileReqOutput ,
ProfileReqType ,
2025-01-14 03:38:51 +08:00
ReleaseMemoryOccupationReqInput ,
ReleaseMemoryOccupationReqOutput ,
ResumeMemoryOccupationReqInput ,
ResumeMemoryOccupationReqOutput ,
2025-03-14 15:40:44 +08:00
RpcReqInput ,
RpcReqOutput ,
2025-03-03 00:12:04 -08:00
SetInternalStateReq ,
SetInternalStateReqOutput ,
2025-05-08 16:03:08 +08:00
SlowDownReqInput ,
SlowDownReqOutput ,
2024-09-29 17:42:45 -07:00
TokenizedEmbeddingReqInput ,
TokenizedGenerateReqInput ,
2024-11-29 17:17:00 -08:00
UpdateWeightFromDiskReqInput ,
UpdateWeightFromDiskReqOutput ,
2024-12-01 23:23:18 -08:00
UpdateWeightsFromDistributedReqInput ,
UpdateWeightsFromDistributedReqOutput ,
2024-12-29 05:30:27 +08:00
UpdateWeightsFromTensorReqInput ,
UpdateWeightsFromTensorReqOutput ,
2024-09-29 17:42:45 -07:00
)
2025-05-22 20:32:41 -07:00
from sglang . srt . managers . mm_utils import init_embedding_cache
2024-09-29 17:42:45 -07:00
from sglang . srt . managers . schedule_batch import (
FINISH_ABORT ,
2025-03-25 11:08:40 +08:00
MultimodalInputs ,
2024-09-29 17:42:45 -07:00
Req ,
ScheduleBatch ,
2024-10-19 23:19:26 -07:00
global_server_args_dict ,
2024-09-29 17:42:45 -07:00
)
2024-10-04 18:00:18 -07:00
from sglang . srt . managers . schedule_policy import (
AddReqResult ,
PrefillAdder ,
SchedulePolicy ,
)
2025-03-12 16:21:49 -07:00
from sglang . srt . managers . scheduler_output_processor_mixin import (
SchedulerOutputProcessorMixin ,
)
2024-11-20 00:36:53 -08:00
from sglang . srt . managers . session_controller import Session
2024-09-30 02:41:11 -07:00
from sglang . srt . managers . tp_worker import TpModelWorker
2024-10-20 00:29:29 -07:00
from sglang . srt . managers . tp_worker_overlap_thread import TpModelWorkerClient
2025-03-27 04:21:25 +08:00
from sglang . srt . managers . utils import validate_input_length
2024-09-29 17:42:45 -07:00
from sglang . srt . mem_cache . chunk_cache import ChunkCache
2025-02-23 21:56:30 -08:00
from sglang . srt . mem_cache . hiradix_cache import HiRadixCache
2024-09-29 17:42:45 -07:00
from sglang . srt . mem_cache . radix_cache import RadixCache
2024-11-10 04:39:32 -08:00
from sglang . srt . metrics . collector import SchedulerMetricsCollector , SchedulerStats
2025-05-10 21:54:46 -07:00
from sglang . srt . model_executor . forward_batch_info import ForwardMode , PPProxyTensors
2025-04-08 12:46:47 +08:00
from sglang . srt . reasoning_parser import ReasoningParser
2024-09-29 02:36:12 -07:00
from sglang . srt . server_args import PortArgs , ServerArgs
2025-01-02 02:09:08 -08:00
from sglang . srt . speculative . spec_info import SpeculativeAlgorithm
2025-01-13 14:24:00 -08:00
from sglang . srt . torch_memory_saver_adapter import TorchMemorySaverAdapter
2025-05-25 08:39:07 +08:00
from sglang . srt . two_batch_overlap import TboDPAttentionPreparer
2024-09-29 17:42:45 -07:00
from sglang . srt . utils import (
2025-05-25 08:39:07 +08:00
DeepEPMode ,
2025-03-17 13:54:16 +08:00
DynamicGradMode ,
2024-09-29 17:42:45 -07:00
broadcast_pyobj ,
configure_logger ,
2025-05-10 21:54:46 -07:00
disable_request_logging ,
2025-05-31 15:53:55 -07:00
get_available_gpu_memory ,
2024-11-27 02:52:46 -08:00
get_bool_env_var ,
2024-10-25 23:07:07 -07:00
get_zmq_socket ,
2025-03-12 22:22:39 -07:00
kill_itself_when_parent_died ,
2025-04-30 18:18:07 -07:00
point_to_point_pyobj ,
2025-03-03 00:12:04 -08:00
pyspy_dump_schedulers ,
2024-11-27 02:52:46 -08:00
set_gpu_proc_affinity ,
2024-09-29 17:42:45 -07:00
set_random_seed ,
suppress_other_loggers ,
)
2025-01-19 12:13:27 +08:00
from sglang . utils import TypeBasedDispatcher , get_exception_traceback
2024-09-29 02:36:12 -07:00
# Module-level logger for the scheduler process.
logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# When set, record per-step timing (read once at import time).
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")

# Seconds to wait for a grammar object to finish compiling (default 300).
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
2024-09-29 02:36:12 -07:00
2025-01-16 12:51:11 -08:00
@dataclass
class GenerationBatchResult:
    """Outputs of running one generation forward batch.

    NOTE(review): several fields are Optional — presumably None on
    non-final pipeline-parallel ranks where only proxy tensors are
    produced; confirm against the PP forward path.
    """

    # Output of the logits processor, if produced by this rank.
    logits_output: Optional[LogitsProcessorOutput]
    # Hidden states handed to the next pipeline stage, if any.
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Sampled next-token ids for the requests in the batch.
    next_token_ids: Optional[List[int]]
    # Per-request extend lengths (used for logprob bookkeeping).
    extend_input_len_per_req: List[int]
    extend_logprob_start_len_per_req: List[int]
    # Batch id used to match this result back to the launched batch.
    bid: int
    # Whether the batch was run (or runnable) with a CUDA graph.
    can_run_cuda_graph: bool
2025-01-16 12:51:11 -08:00
@dataclass
class EmbeddingBatchResult:
    """Outputs of running one embedding forward batch."""

    # Embedding vectors for the requests in the batch.
    embeddings: torch.Tensor
    # Batch id used to match this result back to the launched batch.
    bid: int
2025-06-13 00:58:22 +03:00
class IdleSleeper:
    """Block briefly on a zmq poller so an idle scheduler loop does not spin.

    During long inactivity periods each GPU's scheduler would otherwise pin
    one CPU thread at 100% usage. Polling the receive sockets instead cuts
    power draw and leaves CPU thermal headroom for when a request eventually
    arrives. The poller watches every socket whose data must be handled
    immediately.
    """

    def __init__(self, sockets):
        self.poller = zmq.Poller()
        for sock in sockets:
            self.poller.register(sock, zmq.POLLIN)

    def maybe_sleep(self):
        # Wait up to 1000 ms for any registered socket to become readable.
        self.poller.poll(1000)
2025-03-21 14:47:47 -07:00
class Scheduler (
SchedulerOutputProcessorMixin ,
SchedulerDisaggregationDecodeMixin ,
SchedulerDisaggregationPrefillMixin ,
) :
2024-09-29 02:36:12 -07:00
""" A scheduler that manages a tensor parallel GPU worker. """
def __init__ (
self ,
server_args : ServerArgs ,
port_args : PortArgs ,
gpu_id : int ,
tp_rank : int ,
2025-04-30 18:18:07 -07:00
pp_rank : int ,
2024-10-19 17:39:38 -07:00
dp_rank : Optional [ int ] ,
2024-09-29 02:36:12 -07:00
) :
# Parse args
2024-09-29 17:42:45 -07:00
self . server_args = server_args
2024-09-29 02:36:12 -07:00
self . tp_rank = tp_rank
2025-04-30 18:18:07 -07:00
self . pp_rank = pp_rank
2024-09-29 02:36:12 -07:00
self . tp_size = server_args . tp_size
2025-04-30 18:18:07 -07:00
self . pp_size = server_args . pp_size
self . dp_size = server_args . dp_size
2024-09-29 17:42:45 -07:00
self . schedule_policy = server_args . schedule_policy
self . lora_paths = server_args . lora_paths
self . max_loras_per_batch = server_args . max_loras_per_batch
2024-11-19 22:07:58 -08:00
self . enable_overlap = not server_args . disable_overlap_schedule
2024-10-25 18:51:59 -07:00
self . skip_tokenizer_init = server_args . skip_tokenizer_init
2024-11-10 04:39:32 -08:00
self . enable_metrics = server_args . enable_metrics
2025-05-19 14:19:54 -07:00
self . enable_kv_cache_events = server_args . kv_events_config is not None
2025-03-03 00:12:04 -08:00
self . stream_interval = server_args . stream_interval
2025-01-02 02:09:08 -08:00
self . spec_algorithm = SpeculativeAlgorithm . from_string (
server_args . speculative_algorithm
)
2025-03-03 00:12:04 -08:00
self . gpu_id = gpu_id
self . enable_hierarchical_cache = server_args . enable_hierarchical_cache
2025-03-12 22:22:39 -07:00
self . page_size = server_args . page_size
2025-05-13 02:51:39 -04:00
self . dp_size = server_args . dp_size
self . attn_tp_rank , self . attn_tp_size , self . attn_dp_rank = (
2025-01-16 11:15:00 -08:00
compute_dp_attention_world_info (
server_args . enable_dp_attention ,
self . tp_rank ,
self . tp_size ,
self . dp_size ,
)
)
2025-01-19 17:10:29 -08:00
# Init inter-process communication
context = zmq . Context ( 2 )
2025-06-13 00:58:22 +03:00
self . idle_sleeper = None
2025-04-30 18:18:07 -07:00
if self . pp_rank == 0 and self . attn_tp_rank == 0 :
2024-10-25 23:07:07 -07:00
self . recv_from_tokenizer = get_zmq_socket (
2025-01-16 14:36:07 -08:00
context , zmq . PULL , port_args . scheduler_input_ipc_name , False
2024-10-25 23:07:07 -07:00
)
2024-11-15 05:02:44 -08:00
self . send_to_tokenizer = get_zmq_socket (
2025-01-16 14:36:07 -08:00
context , zmq . PUSH , port_args . tokenizer_ipc_name , False
2024-11-15 05:02:44 -08:00
)
2024-09-29 02:36:12 -07:00
2024-10-25 18:51:59 -07:00
if server_args . skip_tokenizer_init :
2025-02-24 14:47:59 -08:00
# Directly send to the TokenizerManager
2024-10-25 23:07:07 -07:00
self . send_to_detokenizer = get_zmq_socket (
2025-01-16 14:36:07 -08:00
context , zmq . PUSH , port_args . tokenizer_ipc_name , False
2024-10-25 18:51:59 -07:00
)
else :
2024-12-29 00:45:57 -08:00
# Send to the DetokenizerManager
2024-10-25 23:07:07 -07:00
self . send_to_detokenizer = get_zmq_socket (
2025-01-16 14:36:07 -08:00
context , zmq . PUSH , port_args . detokenizer_ipc_name , False
2024-10-25 18:51:59 -07:00
)
2025-03-14 15:40:44 +08:00
self . recv_from_rpc = get_zmq_socket (
context , zmq . DEALER , port_args . rpc_ipc_name , False
)
2025-06-13 00:58:22 +03:00
if self . server_args . sleep_on_idle :
self . idle_sleeper = IdleSleeper (
[
self . recv_from_tokenizer ,
self . recv_from_rpc ,
]
)
2024-09-29 02:36:12 -07:00
else :
2024-10-12 21:35:30 -07:00
self . recv_from_tokenizer = None
2025-03-14 15:40:44 +08:00
self . recv_from_rpc = None
2024-11-22 15:46:16 -08:00
self . send_to_tokenizer = SimpleNamespace ( send_pyobj = lambda x : None )
self . send_to_detokenizer = SimpleNamespace ( send_pyobj = lambda x : None )
2024-09-29 17:42:45 -07:00
# Init tokenizer
2025-03-06 00:13:20 -08:00
self . init_tokenizer ( )
2024-09-29 02:36:12 -07:00
2025-04-08 12:46:47 +08:00
# Set reasoning_parser and think_end_id if --reasoning_parser is enabled
if self . server_args . reasoning_parser and self . tokenizer :
reasoning_parser = ReasoningParser (
model_type = self . server_args . reasoning_parser , stream_reasoning = False
)
self . tokenizer . think_end_id = self . tokenizer . encode (
reasoning_parser . detector . think_end_token , add_special_tokens = False
) [ 0 ]
2024-11-19 22:07:58 -08:00
# Check whether overlap can be enabled
if not self . is_generation :
self . enable_overlap = False
logger . info ( " Overlap scheduler is disabled for embedding models. " )
2024-11-27 23:44:33 -08:00
2024-09-29 17:42:45 -07:00
# Launch a tensor parallel worker
2024-10-20 18:17:41 -07:00
if self . enable_overlap :
2024-10-20 00:29:29 -07:00
TpWorkerClass = TpModelWorkerClient
2024-10-19 23:19:26 -07:00
else :
TpWorkerClass = TpModelWorker
2024-10-20 00:29:29 -07:00
2024-10-19 23:19:26 -07:00
self . tp_worker = TpWorkerClass (
2024-10-19 17:39:38 -07:00
server_args = server_args ,
2024-09-29 02:36:12 -07:00
gpu_id = gpu_id ,
tp_rank = tp_rank ,
2025-04-30 18:18:07 -07:00
pp_rank = pp_rank ,
2024-10-19 17:39:38 -07:00
dp_rank = dp_rank ,
2024-10-11 07:22:48 -07:00
nccl_port = port_args . nccl_port ,
2024-09-29 02:36:12 -07:00
)
2024-09-29 17:42:45 -07:00
2025-03-03 00:12:04 -08:00
# Launch a draft worker for speculative decoding
2025-01-02 02:09:08 -08:00
if self . spec_algorithm . is_eagle ( ) :
from sglang . srt . speculative . eagle_worker import EAGLEWorker
self . draft_worker = EAGLEWorker (
gpu_id = gpu_id ,
tp_rank = tp_rank ,
server_args = server_args ,
nccl_port = port_args . nccl_port ,
target_worker = self . tp_worker ,
dp_rank = dp_rank ,
)
else :
self . draft_worker = None
2024-10-06 00:10:48 -07:00
# Get token and memory info from the model worker
2024-09-29 17:42:45 -07:00
(
self . max_total_num_tokens ,
self . max_prefill_tokens ,
self . max_running_requests ,
2024-10-21 16:12:04 -07:00
self . max_req_len ,
2024-09-29 17:42:45 -07:00
self . max_req_input_len ,
self . random_seed ,
2024-10-19 17:39:38 -07:00
self . device ,
2024-10-19 23:19:26 -07:00
worker_global_server_args_dict ,
_ ,
_ ,
_ ,
) = self . tp_worker . get_worker_info ( )
2025-04-30 18:18:07 -07:00
if global_server_args_dict [ " max_micro_batch_size " ] is None :
global_server_args_dict [ " max_micro_batch_size " ] = max (
self . max_running_requests / / server_args . pp_size , 1
)
self . tp_group = self . tp_worker . get_tp_group ( )
self . tp_cpu_group = self . tp_group . cpu_group
self . attn_tp_group = self . tp_worker . get_attention_tp_group ( )
2025-01-16 11:15:00 -08:00
self . attn_tp_cpu_group = self . tp_worker . get_attention_tp_cpu_group ( )
2025-04-30 18:18:07 -07:00
self . pp_group = get_pp_group ( )
self . world_group = get_world_group ( )
2024-10-19 17:39:38 -07:00
self . pad_input_ids_func = self . tp_worker . get_pad_input_ids_func ( )
2024-10-19 23:19:26 -07:00
global_server_args_dict . update ( worker_global_server_args_dict )
2024-09-29 17:42:45 -07:00
set_random_seed ( self . random_seed )
2025-03-03 00:12:04 -08:00
2024-09-29 17:42:45 -07:00
# Print debug info
2025-05-12 00:17:33 -07:00
if tp_rank == 0 :
2025-05-31 15:53:55 -07:00
avail_mem = get_available_gpu_memory (
self . device , self . gpu_id , empty_cache = False
)
2025-05-12 00:17:33 -07:00
logger . info (
f " max_total_num_tokens= { self . max_total_num_tokens } , "
f " chunked_prefill_size= { server_args . chunked_prefill_size } , "
f " max_prefill_tokens= { self . max_prefill_tokens } , "
f " max_running_requests= { self . max_running_requests } , "
2025-05-31 15:53:55 -07:00
f " context_len= { self . model_config . context_len } , "
f " available_gpu_mem= { avail_mem : .2f } GB "
2025-05-12 00:17:33 -07:00
)
2024-09-29 17:42:45 -07:00
2025-03-12 22:22:39 -07:00
# Init memory pool and cache
2025-03-06 00:13:20 -08:00
self . init_memory_pool_and_cache ( )
2024-09-29 17:42:45 -07:00
# Init running status
self . waiting_queue : List [ Req ] = [ ]
2024-11-19 15:04:43 -08:00
# The running decoding batch for continuous batching
2025-03-12 22:22:39 -07:00
self . running_batch : ScheduleBatch = ScheduleBatch ( reqs = [ ] , batch_is_full = False )
2024-11-19 15:04:43 -08:00
# The current forward batch
2024-10-16 01:33:20 -07:00
self . cur_batch : Optional [ ScheduleBatch ] = None
2025-03-12 22:22:39 -07:00
# The last forward batch
2024-11-19 15:04:43 -08:00
self . last_batch : Optional [ ScheduleBatch ] = None
2024-10-27 02:00:50 -07:00
self . forward_ct = 0
self . forward_ct_decode = 0
2024-09-29 17:42:45 -07:00
self . num_generated_tokens = 0
2025-06-16 23:30:26 +08:00
self . last_prefill_tokens = 0
2025-05-17 16:49:18 -07:00
self . last_decode_stats_tic = time . perf_counter ( )
self . last_prefill_stats_tic = time . perf_counter ( )
2025-03-03 00:12:04 -08:00
self . return_health_check_ct = 0
2024-12-06 05:49:29 -08:00
self . current_stream = torch . get_device_module ( self . device ) . current_stream ( )
2025-01-17 13:22:53 +08:00
if self . device == " cpu " :
self . current_stream . synchronize = lambda : None # No-op for CPU
2025-05-31 15:53:55 -07:00
self . forward_sleep_time = None
2024-12-06 05:49:29 -08:00
2025-03-06 00:13:20 -08:00
# Init session info
2024-12-22 06:25:57 -08:00
self . sessions : Dict [ str , Session ] = { }
2024-09-29 17:42:45 -07:00
# Init chunked prefill
self . chunked_prefill_size = server_args . chunked_prefill_size
2024-11-29 16:03:32 -08:00
if self . chunked_prefill_size < = 0 : # -1 means disable
self . chunked_prefill_size = None
2025-03-03 00:12:04 -08:00
self . chunked_req = None
2024-09-29 17:42:45 -07:00
self . is_mixed_chunk = (
self . chunked_prefill_size is not None and server_args . enable_mixed_chunk
)
2024-11-12 21:17:38 -08:00
# Init the grammar backend for constrained generation
2024-11-13 01:45:28 +09:00
self . grammar_queue : List [ Req ] = [ ]
2024-10-25 10:24:44 -07:00
if not server_args . skip_tokenizer_init :
2025-01-19 17:10:29 -08:00
self . grammar_backend = create_grammar_backend (
server_args , self . tokenizer , self . model_config . vocab_size
)
2024-11-12 21:17:38 -08:00
else :
self . grammar_backend = None
2024-09-29 17:42:45 -07:00
2025-03-06 00:13:20 -08:00
# Init schedule policy and new token estimation
2025-03-12 11:22:35 -07:00
self . policy = SchedulePolicy (
2025-03-12 22:22:39 -07:00
self . schedule_policy ,
self . tree_cache ,
self . enable_hierarchical_cache ,
2025-03-12 11:22:35 -07:00
)
2024-10-25 10:24:44 -07:00
assert (
server_args . schedule_conservativeness > = 0
) , " Invalid schedule_conservativeness "
2024-10-26 16:39:41 -07:00
self . init_new_token_ratio = min (
global_config . default_init_new_token_ratio
2024-10-25 10:24:44 -07:00
* server_args . schedule_conservativeness ,
1.0 ,
2024-09-29 17:42:45 -07:00
)
2024-10-26 16:39:41 -07:00
self . min_new_token_ratio = min (
self . init_new_token_ratio
* global_config . default_min_new_token_ratio_factor ,
1.0 ,
)
self . new_token_ratio_decay = (
self . init_new_token_ratio - self . min_new_token_ratio
) / global_config . default_new_token_ratio_decay_steps
self . new_token_ratio = self . init_new_token_ratio
2024-10-27 02:00:50 -07:00
# Init watchdog thread
self . watchdog_timeout = server_args . watchdog_timeout
t = threading . Thread ( target = self . watchdog_thread , daemon = True )
t . start ( )
2024-11-28 00:22:39 -08:00
self . parent_process = psutil . Process ( ) . parent ( )
2025-01-14 03:38:51 +08:00
self . memory_saver_adapter = TorchMemorySaverAdapter . create (
enable = server_args . enable_memory_saver
)
2024-10-16 11:20:17 -07:00
# Init profiler
2025-03-03 00:12:04 -08:00
self . torch_profiler = None
self . torch_profiler_output_dir : Optional [ str ] = None
2025-03-28 13:21:13 +08:00
self . profiler_activities : Optional [ List [ str ] ] = None
2025-06-03 02:17:22 +08:00
self . profile_id : Optional [ str ] = None
2025-03-03 00:12:04 -08:00
self . profiler_target_forward_ct : Optional [ int ] = None
2025-05-31 15:53:55 -07:00
self . profiler_target_prefill_ct : Optional [ int ] = None
self . profiler_target_decode_ct : Optional [ int ] = None
self . profiler_prefill_ct : Optional [ int ] = None
self . profiler_decode_ct : Optional [ int ] = None
self . profile_by_stage : bool = False
self . profile_steps : Optional [ int ] = None
self . profile_in_progress : bool = False
self . rpd_profiler = None
2025-05-08 16:03:08 +08:00
2024-11-06 12:42:53 +08:00
# Init metrics stats
2025-03-06 00:13:20 -08:00
self . init_metrics ( )
2025-05-19 14:19:54 -07:00
self . init_kv_events ( server_args . kv_events_config )
2024-10-11 17:34:25 +08:00
2025-01-19 17:10:29 -08:00
# Init request dispatcher
self . _request_dispatcher = TypeBasedDispatcher (
2025-01-19 12:13:27 +08:00
[
( TokenizedGenerateReqInput , self . handle_generate_request ) ,
( TokenizedEmbeddingReqInput , self . handle_embedding_request ) ,
2025-04-21 09:15:03 +08:00
( FlushCacheReqInput , self . flush_cache_wrapped ) ,
2025-01-19 12:13:27 +08:00
( AbortReq , self . abort_request ) ,
2025-03-03 00:12:04 -08:00
( OpenSessionReqInput , self . open_session ) ,
( CloseSessionReqInput , self . close_session ) ,
2025-01-19 12:13:27 +08:00
( UpdateWeightFromDiskReqInput , self . update_weights_from_disk ) ,
( InitWeightsUpdateGroupReqInput , self . init_weights_update_group ) ,
(
UpdateWeightsFromDistributedReqInput ,
self . update_weights_from_distributed ,
) ,
( UpdateWeightsFromTensorReqInput , self . update_weights_from_tensor ) ,
( GetWeightsByNameReqInput , self . get_weights_by_name ) ,
2025-03-03 00:12:04 -08:00
( ReleaseMemoryOccupationReqInput , self . release_memory_occupation ) ,
( ResumeMemoryOccupationReqInput , self . resume_memory_occupation ) ,
2025-05-08 16:03:08 +08:00
( SlowDownReqInput , self . slow_down ) ,
2025-01-19 12:13:27 +08:00
( ProfileReq , self . profile ) ,
2025-03-03 00:12:04 -08:00
( GetInternalStateReq , self . get_internal_state ) ,
2025-03-06 00:13:20 -08:00
( SetInternalStateReq , self . set_internal_state ) ,
2025-03-14 15:40:44 +08:00
( RpcReqInput , self . handle_rpc_request ) ,
2025-03-24 21:34:19 -07:00
( ExpertDistributionReq , self . expert_distribution_handle ) ,
2025-01-19 12:13:27 +08:00
]
)
2025-03-21 14:47:47 -07:00
self . disaggregation_mode = DisaggregationMode (
self . server_args . disaggregation_mode
)
self . init_disaggregation ( )
2025-06-13 00:58:22 +03:00
def maybe_sleep_on_idle ( self ) :
if self . idle_sleeper is not None :
self . idle_sleeper . maybe_sleep ( )
2025-03-06 00:13:20 -08:00
def init_tokenizer ( self ) :
server_args = self . server_args
2024-10-27 02:00:50 -07:00
2025-05-08 16:02:43 +08:00
self . model_config = ModelConfig . from_server_args ( server_args )
2025-03-06 00:13:20 -08:00
self . is_generation = self . model_config . is_generation
2025-03-03 00:12:04 -08:00
2025-03-06 00:13:20 -08:00
if server_args . skip_tokenizer_init :
self . tokenizer = self . processor = None
else :
if self . model_config . is_multimodal :
self . processor = get_processor (
server_args . tokenizer_path ,
tokenizer_mode = server_args . tokenizer_mode ,
trust_remote_code = server_args . trust_remote_code ,
revision = server_args . revision ,
2025-04-12 12:46:58 +08:00
use_fast = not server_args . disable_fast_image_processor ,
2025-03-06 00:13:20 -08:00
)
2025-05-02 13:38:59 +08:00
self . tokenizer = get_tokenizer_from_processor ( self . processor )
2025-03-06 00:13:20 -08:00
else :
self . tokenizer = get_tokenizer (
server_args . tokenizer_path ,
tokenizer_mode = server_args . tokenizer_mode ,
trust_remote_code = server_args . trust_remote_code ,
revision = server_args . revision ,
)
def init_memory_pool_and_cache ( self ) :
server_args = self . server_args
self . req_to_token_pool , self . token_to_kv_pool_allocator = (
self . tp_worker . get_memory_pool ( )
)
if (
server_args . chunked_prefill_size is not None
and server_args . disable_radix_cache
) :
self . tree_cache = ChunkCache (
req_to_token_pool = self . req_to_token_pool ,
token_to_kv_pool_allocator = self . token_to_kv_pool_allocator ,
2025-05-05 16:02:55 -07:00
page_size = self . page_size ,
2025-03-06 00:13:20 -08:00
)
else :
if self . enable_hierarchical_cache :
self . tree_cache = HiRadixCache (
req_to_token_pool = self . req_to_token_pool ,
token_to_kv_pool_allocator = self . token_to_kv_pool_allocator ,
2025-06-20 11:34:36 +08:00
tp_cache_group = (
self . attn_tp_cpu_group
if self . server_args . enable_dp_attention
else self . tp_cpu_group
) ,
2025-03-13 15:50:49 -07:00
page_size = self . page_size ,
2025-03-17 17:45:00 -07:00
hicache_ratio = server_args . hicache_ratio ,
2025-04-20 23:08:30 -07:00
hicache_size = server_args . hicache_size ,
hicache_write_policy = server_args . hicache_write_policy ,
2025-03-06 00:13:20 -08:00
)
2025-06-17 17:44:57 -07:00
self . tp_worker . register_hicache_layer_transfer_counter (
self . tree_cache . cache_controller . layer_done_counter
)
2025-03-06 00:13:20 -08:00
else :
self . tree_cache = RadixCache (
req_to_token_pool = self . req_to_token_pool ,
token_to_kv_pool_allocator = self . token_to_kv_pool_allocator ,
2025-03-12 22:22:39 -07:00
page_size = self . page_size ,
2025-03-06 00:13:20 -08:00
disable = server_args . disable_radix_cache ,
2025-05-19 14:19:54 -07:00
enable_kv_cache_events = self . enable_kv_cache_events ,
2025-03-06 00:13:20 -08:00
)
self . decode_mem_cache_buf_multiplier = (
1
if self . spec_algorithm . is_none ( )
else (
server_args . speculative_num_draft_tokens
+ (
server_args . speculative_eagle_topk
* server_args . speculative_num_steps
)
)
2025-03-03 00:12:04 -08:00
)
2025-03-06 00:13:20 -08:00
def init_metrics ( self ) :
self . last_gen_throughput : float = 0.0
2025-03-12 22:22:39 -07:00
self . last_input_throughput : float = 0.0
2025-03-06 00:13:20 -08:00
self . step_time_dict = defaultdict ( list ) # Dict[batch size -> step time]
self . spec_num_total_accepted_tokens = 0
self . spec_num_total_forward_ct = 0
self . cum_spec_accept_length = 0
self . cum_spec_accept_count = 0
self . stats = SchedulerStats ( )
if self . enable_metrics :
engine_type = " unified "
self . metrics_collector = SchedulerMetricsCollector (
labels = {
" model_name " : self . server_args . served_model_name ,
" engine_type " : engine_type ,
} ,
)
2024-10-27 02:00:50 -07:00
2025-05-19 14:19:54 -07:00
def init_kv_events ( self , kv_events_config : Optional [ str ] ) :
if self . enable_kv_cache_events :
2025-06-04 15:29:34 -07:00
self . kv_event_publisher = EventPublisherFactory . create (
kv_events_config , self . attn_dp_rank
)
2025-05-19 14:19:54 -07:00
2025-03-21 14:47:47 -07:00
    def init_disaggregation(self):
        """Set up prefill/decode disaggregation queues and metadata buffers.

        No-op (beyond selecting the transfer backend) when disaggregation is
        disabled; otherwise builds the queues for whichever side — DECODE or
        PREFILL — this scheduler plays.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        if (
            self.disaggregation_mode == DisaggregationMode.DECODE
        ):  # *2 for the headroom.
            buffer_size = (self.req_to_token_pool.size) * 2
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
                hidden_size=self.model_config.hf_text_config.hidden_size,
                dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                tp_rank=self.tp_rank,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                # Draft KV pool only exists when a speculative draft worker runs.
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                max_total_num_tokens=self.max_total_num_tokens,
                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
                num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
                transfer_backend=self.transfer_backend,
            )

            # Metric for pre-allocation
            self.num_tokens_pre_allocated = 0
        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
                hidden_size=self.model_config.hf_text_config.hidden_size,
                dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                # Draft KV pool only exists when a speculative draft worker runs.
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                max_total_num_tokens=self.max_total_num_tokens,
                decode_tp_size=self.server_args.disaggregation_decode_tp,
                decode_dp_size=self.server_args.disaggregation_decode_dp,
                scheduler=self,
                pp_rank=self.pp_rank,
                pp_size=self.pp_size,
                transfer_backend=self.transfer_backend,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []
2025-03-21 14:47:47 -07:00
2025-03-17 13:54:16 +08:00
@DynamicGradMode()
def event_loop_normal(self):
    """Run the blocking (non-overlap) scheduler loop forever.

    Each iteration drains incoming requests, picks the next batch,
    runs it, and processes its results synchronously.
    """
    while True:
        # Drain newly arrived requests and dispatch them.
        incoming = self.recv_requests()
        self.process_input_requests(incoming)

        next_batch = self.get_next_batch_to_run()
        self.cur_batch = next_batch

        if not next_batch:
            # Idle: run self-checks and reset the adaptive token ratio.
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio
            self.maybe_sleep_on_idle()
        else:
            outcome = self.run_batch(next_batch)
            self.process_batch_result(next_batch, outcome)

        self.last_batch = next_batch
2024-09-29 02:36:12 -07:00
2025-03-17 13:54:16 +08:00
@DynamicGradMode()
def event_loop_overlap(self):
    """A scheduler loop that overlaps the CPU processing and GPU computation."""
    # Queue of (batch copy, result) pairs whose results are processed one
    # iteration later, so CPU post-processing overlaps the next GPU launch.
    self.result_queue = deque()

    while True:
        recv_reqs = self.recv_requests()
        self.process_input_requests(recv_reqs)
        batch = self.get_next_batch_to_run()
        self.cur_batch = batch

        if batch:
            # Event set once this batch's kernels are launched; consumers of
            # the (delayed) results wait on it.
            batch.launch_done = threading.Event()
            result = self.run_batch(batch)
            self.result_queue.append((batch.copy(), result))

            if self.last_batch is None:
                # Create a dummy first batch to start the pipeline for overlap schedule.
                # It is now used for triggering the sampling_info_done event.
                tmp_batch = ScheduleBatch(
                    reqs=None,
                    forward_mode=ForwardMode.DUMMY_FIRST,
                    next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                )
                self.process_batch_result(tmp_batch, None, batch.launch_done)

        if self.last_batch:
            # Process the results of the last batch
            tmp_batch, tmp_result = self.result_queue.popleft()
            tmp_batch.next_batch_sampling_info = (
                self.tp_worker.cur_sampling_info if batch else None
            )
            # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
            self.process_batch_result(
                tmp_batch, tmp_result, batch.launch_done if batch else None
            )
        elif batch is None:
            # When the server is idle, do self-check and re-init some states
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio
            self.maybe_sleep_on_idle()

        self.last_batch = batch
2025-04-30 18:18:07 -07:00
@DynamicGradMode()
def event_loop_pp(self):
    """A non-overlap scheduler loop for pipeline parallelism."""
    # Per-microbatch state: the schedule batches in flight, the previous
    # batch per slot, per-slot running batches, and batch ids.
    mbs = [None] * self.pp_size
    last_mbs = [None] * self.pp_size
    self.running_mbs = [
        ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
    ]
    bids = [None] * self.pp_size
    pp_outputs: Optional[PPProxyTensors] = None
    while True:
        server_is_idle = True
        for mb_id in range(self.pp_size):
            # Swap in the state of this microbatch slot.
            self.running_batch = self.running_mbs[mb_id]
            self.last_batch = last_mbs[mb_id]

            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            mbs[mb_id] = self.get_next_batch_to_run()
            self.running_mbs[mb_id] = self.running_batch
            self.cur_batch = mbs[mb_id]
            if self.cur_batch:
                server_is_idle = False
                result = self.run_batch(self.cur_batch)

            # (last rank) send the outputs to the next step
            if self.pp_group.is_last_rank:
                if self.cur_batch:
                    next_token_ids, bids[mb_id] = (
                        result.next_token_ids,
                        result.bid,
                    )
                    pp_outputs = PPProxyTensors(
                        {
                            "next_token_ids": next_token_ids,
                        }
                    )
                    # send the output from the last round to let the next stage worker run post processing
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            # receive outputs and post-process (filter finished reqs) the coming microbatch
            next_mb_id = (mb_id + 1) % self.pp_size
            next_pp_outputs = None
            if mbs[next_mb_id] is not None:
                next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                    self.pp_group.recv_tensor_dict(
                        all_gather_group=self.attn_tp_group
                    )
                )
                mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                # NOTE(review): `result` is only bound when self.cur_batch ran
                # this iteration; if only the *next* microbatch exists, the
                # `result.can_run_cuda_graph` access below would raise
                # UnboundLocalError — confirm this combination cannot occur.
                output_result = GenerationBatchResult(
                    logits_output=None,
                    pp_hidden_states_proxy_tensors=None,
                    next_token_ids=next_pp_outputs["next_token_ids"],
                    extend_input_len_per_req=None,
                    extend_logprob_start_len_per_req=None,
                    bid=bids[next_mb_id],
                    can_run_cuda_graph=result.can_run_cuda_graph,
                )
                self.process_batch_result(mbs[next_mb_id], output_result)
                last_mbs[next_mb_id] = mbs[next_mb_id]

            # (not last rank)
            if not self.pp_group.is_last_rank:
                if self.cur_batch:
                    bids[mb_id] = result.bid
                # carry the outputs to the next stage
                # send the outputs from the last round to let the next stage worker run post processing
                if pp_outputs:
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

                # send out reqs to the next stage
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                if self.attn_tp_rank == 0:
                    point_to_point_pyobj(
                        recv_reqs,
                        self.pp_rank * self.tp_size + dp_offset,
                        self.world_group.cpu_group,
                        self.pp_rank * self.tp_size + dp_offset,
                        (self.pp_rank + 1) * self.tp_size + dp_offset,
                    )

                # send out proxy tensors to the next stage
                if self.cur_batch:
                    self.pp_group.send_tensor_dict(
                        result.pp_hidden_states_proxy_tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            pp_outputs = next_pp_outputs

        # When the server is idle, self-check and re-init some states
        if server_is_idle:
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio
            self.maybe_sleep_on_idle()
2025-04-30 18:18:07 -07:00
2024-12-29 00:45:57 -08:00
def recv_requests(self) -> List[Req]:
    """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
    if self.pp_rank == 0:
        if self.attn_tp_rank == 0:
            # First PP stage, first attn-TP rank: drain both ZMQ sockets
            # non-blockingly until each is empty.
            recv_reqs = []

            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    # No more pending tokenizer messages.
                    break
                recv_reqs.append(recv_req)

            while True:
                try:
                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    # No more pending RPC messages.
                    break
                recv_reqs.append(recv_rpc)
        else:
            recv_reqs = None
    else:
        # Later PP stages receive the request list from the previous stage.
        if self.attn_tp_rank == 0:
            dp_offset = self.attn_dp_rank * self.attn_tp_size
            recv_reqs = point_to_point_pyobj(
                [],
                self.pp_rank * self.tp_size + dp_offset,
                self.world_group.cpu_group,
                (self.pp_rank - 1) * self.tp_size + dp_offset,
                self.pp_rank * self.tp_size + dp_offset,
            )
        else:
            recv_reqs = None

    if self.server_args.enable_dp_attention:
        if self.attn_tp_rank == 0:
            # Split work (generate/embedding) requests, which are broadcast
            # only within the attention-TP group, from control requests,
            # which are broadcast across the whole TP group.
            work_reqs = [
                req
                for req in recv_reqs
                if isinstance(
                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                )
            ]
            control_reqs = [
                req
                for req in recv_reqs
                if not isinstance(
                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                )
            ]
        else:
            work_reqs = None
            control_reqs = None

        if self.attn_tp_size != 1:
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_group.rank,
                self.attn_tp_cpu_group,
                src=self.attn_tp_group.ranks[0],
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        recv_reqs = work_reqs + control_reqs
    elif self.tp_size != 1:
        # Plain TP: broadcast everything from the group's first rank.
        recv_reqs = broadcast_pyobj(
            recv_reqs,
            self.tp_group.rank,
            self.tp_cpu_group,
            src=self.tp_group.ranks[0],
        )
    return recv_reqs
2024-10-06 03:24:04 -07:00
def process_input_requests(self, recv_reqs: List):
    """Dispatch each received request and route any reply to its sender."""
    for req in recv_reqs:
        # Health-check generation probes are not executed while real work
        # is in flight; count them so they can be answered elsewhere.
        if is_health_check_generate_req(req) and (
            self.chunked_req is not None or not self.running_batch.is_empty()
        ):
            self.return_health_check_ct += 1
            continue

        reply = self._request_dispatcher(req)
        if reply is None:
            continue
        if isinstance(reply, RpcReqOutput):
            # RPC replies go back over the RPC socket (when one exists).
            if self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)
        else:
            self.send_to_tokenizer.send_pyobj(reply)
2024-09-29 17:42:45 -07:00
def handle_generate_request(
    self,
    recv_req: TokenizedGenerateReqInput,
):
    """Turn a tokenized generate request into a `Req` and enqueue it.

    Performs (in order): request construction (fresh or from a session),
    disaggregation validation, multimodal token expansion, input-length
    validation, logprob attribute setup, max_new_tokens clamping, and
    grammar-cache initialization. Invalid requests are aborted but still
    queued so their error is streamed back to the client.
    """
    # Create a new request
    if (
        recv_req.session_params is None
        or recv_req.session_params.id is None
        or recv_req.session_params.id not in self.sessions
    ):
        if recv_req.input_embeds is not None:
            # Generate fake input_ids based on the length of input_embeds
            seq_length = len(recv_req.input_embeds)
            fake_input_ids = [1] * seq_length
            recv_req.input_ids = fake_input_ids

        if recv_req.bootstrap_port is None:
            # Use default bootstrap port
            recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            return_logprob=recv_req.return_logprob,
            top_logprobs_num=recv_req.top_logprobs_num,
            token_ids_logprob=recv_req.token_ids_logprob,
            stream=recv_req.stream,
            lora_path=recv_req.lora_path,
            input_embeds=recv_req.input_embeds,
            custom_logit_processor=recv_req.custom_logit_processor,
            return_hidden_states=recv_req.return_hidden_states,
            eos_token_ids=self.model_config.hf_eos_token_id,
            bootstrap_host=recv_req.bootstrap_host,
            bootstrap_port=recv_req.bootstrap_port,
            bootstrap_room=recv_req.bootstrap_room,
            data_parallel_rank=recv_req.data_parallel_rank,
        )
        req.tokenizer = self.tokenizer

        if self.disaggregation_mode != DisaggregationMode.NULL:
            # Invalid request for disaggregated mode
            if recv_req.bootstrap_room is None:
                # NOTE(review): "boostrap" below is a typo in a runtime
                # error message ("bootstrap"); left as-is here since this
                # string is program output.
                error_msg = (
                    f"Invalid request: Disaggregated request received without "
                    f"boostrap room id. {req.rid=}"
                )
                logger.error(error_msg)
                prepare_abort(req, error_msg)
                self.stream_output([req], req.return_logprob)
                return

        if (
            recv_req.session_params is not None
            and recv_req.session_params.id is not None
        ):
            # The client referenced a session id that we do not know.
            req.finished_reason = FINISH_ABORT(
                f"Invalid request: session id {recv_req.session_params.id} does not exist"
            )
            self._add_request_to_queue(req)
            return
    else:
        # Create a new request from a previous session
        session = self.sessions[recv_req.session_params.id]
        req = session.create_req(recv_req, self.tokenizer)
        if isinstance(req.finished_reason, FINISH_ABORT):
            self._add_request_to_queue(req)
            return

    # Handle multimodal inputs
    if recv_req.mm_inputs is not None:
        image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
        # Expand a single image token into multiple dummy tokens for receiving image embeddings
        req.origin_input_ids = self.pad_input_ids_func(
            req.origin_input_ids, image_inputs
        )
        req.extend_image_inputs(image_inputs)

        if len(req.origin_input_ids) >= self.max_req_input_len:
            req.set_finish_with_abort(
                error_msg=(
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
            )
            self._add_request_to_queue(req)
            return

    # Validate prompt length
    error_msg = validate_input_length(
        req,
        self.max_req_input_len,
        self.server_args.allow_auto_truncate,
    )
    if error_msg:
        req.set_finish_with_abort(error_msg)
        self._add_request_to_queue(req)
        return

    # Copy more attributes
    if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
        # By default, only return the logprobs for output tokens
        req.logprob_start_len = len(req.origin_input_ids) - 1
    else:
        req.logprob_start_len = recv_req.logprob_start_len

    if req.logprob_start_len >= len(req.origin_input_ids):
        error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
        req.logprob_start_len = len(req.origin_input_ids) - 1
        req.set_finish_with_abort(error_msg)
        self._add_request_to_queue(req)
        return

    # Clamp max_new_tokens so input + output never exceeds max_req_len.
    req.sampling_params.max_new_tokens = min(
        (
            req.sampling_params.max_new_tokens
            if req.sampling_params.max_new_tokens is not None
            else 1 << 30
        ),
        self.max_req_len - len(req.origin_input_ids) - 1,
    )

    # Init grammar cache for this request
    add_to_grammar_queue = False
    if (
        req.sampling_params.json_schema is not None
        or req.sampling_params.regex is not None
        or req.sampling_params.ebnf is not None
        or req.sampling_params.structural_tag is not None
    ):
        assert self.grammar_backend is not None
        if req.sampling_params.json_schema is not None:
            key = ("json", req.sampling_params.json_schema)
        elif req.sampling_params.regex is not None:
            key = ("regex", req.sampling_params.regex)
        elif req.sampling_params.ebnf is not None:
            key = ("ebnf", req.sampling_params.ebnf)
        elif req.sampling_params.structural_tag:
            key = ("structural_tag", req.sampling_params.structural_tag)

        value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
        req.grammar = value
        if not cache_hit:
            # Grammar compilation is pending; park the request in the
            # grammar queue until the future resolves.
            req.grammar_key = key
            add_to_grammar_queue = True
        else:
            if value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                error_msg = f"Invalid grammar request with cache hit: {key=}"
                req.set_finish_with_abort(error_msg)

    if add_to_grammar_queue:
        req.queue_time_start = time.perf_counter()
        self.grammar_queue.append(req)
    else:
        self._add_request_to_queue(req)
def _add_request_to_queue(self, req: Req):
    """Stamp the queueing start time and enqueue *req* into the queue that
    matches this server's disaggregation role."""
    req.queue_time_start = time.perf_counter()
    mode = self.disaggregation_mode
    if mode == DisaggregationMode.PREFILL:
        # Prefill servers first bootstrap the KV transfer for the request.
        self.disagg_prefill_bootstrap_queue.add(
            req, self.model_config.num_key_value_heads
        )
    elif mode == DisaggregationMode.DECODE:
        # Decode servers pre-allocate KV space before the request can run.
        self.disagg_decode_prealloc_queue.add(req)
    else:
        # Unified (non-disaggregated) serving: plain waiting queue.
        self.waiting_queue.append(req)
2025-06-14 19:48:05 -07:00
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
    """Enqueue a list of requests at once, routed by disaggregation role.

    Unlike `_add_request_to_queue`, this does not stamp queue_time_start.
    """
    mode = self.disaggregation_mode
    if mode == DisaggregationMode.PREFILL:
        self.disagg_prefill_bootstrap_queue.extend(
            reqs, self.model_config.num_key_value_heads
        )
    elif mode == DisaggregationMode.DECODE:
        # If this is a decode server, we put the request to the decode pending prealloc queue
        self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
    else:
        self.waiting_queue.extend(reqs)
2024-09-29 17:42:45 -07:00
def handle_embedding_request(
    self,
    recv_req: TokenizedEmbeddingReqInput,
):
    """Turn a tokenized embedding request into a `Req` and enqueue it."""
    req = Req(
        recv_req.rid,
        recv_req.input_text,
        recv_req.input_ids,
        recv_req.sampling_params,
        token_type_ids=recv_req.token_type_ids,
    )
    req.tokenizer = self.tokenizer

    # Handle multimodal inputs
    if recv_req.image_inputs is not None:
        image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
        # Expand a single image token into multiple dummy tokens for receiving image embeddings
        req.origin_input_ids = self.pad_input_ids_func(
            req.origin_input_ids, image_inputs
        )
        req.extend_image_inputs(image_inputs)

        if len(req.origin_input_ids) >= self.max_req_input_len:
            req.set_finish_with_abort(
                error_msg=(
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
            )
            self._add_request_to_queue(req)
            return

    # Validate prompts length
    error_msg = validate_input_length(
        req,
        self.max_req_input_len,
        self.server_args.allow_auto_truncate,
    )
    if error_msg:
        # NOTE(review): unlike handle_generate_request, the request is queued
        # here without an explicit set_finish_with_abort(error_msg) call —
        # presumably validate_input_length marks the request itself; confirm.
        self._add_request_to_queue(req)
        return

    # Copy more attributes
    req.logprob_start_len = len(req.origin_input_ids) - 1
    self._add_request_to_queue(req)
2024-09-29 17:42:45 -07:00
2025-01-27 03:00:41 -08:00
def log_prefill_stats(
    self,
    adder: PrefillAdder,
    can_run_list: List[Req],
    running_bs: int,
):
    """Log a summary line (and optional metrics) for a launched prefill batch."""
    # Input throughput is measured over the interval since the last prefill log.
    gap_latency = time.perf_counter() - self.last_prefill_stats_tic
    self.last_prefill_stats_tic = time.perf_counter()
    self.last_input_throughput = self.last_prefill_tokens / gap_latency
    self.last_prefill_tokens = adder.log_input_tokens

    # Tokens currently held by requests = total minus (free + evictable).
    num_used = self.max_total_num_tokens - (
        self.token_to_kv_pool_allocator.available_size()
        + self.tree_cache.evictable_size()
    )

    num_new_seq = len(can_run_list)
    f = (
        f"Prefill batch. "
        f"#new-seq: {num_new_seq}, "
        f"#new-token: {adder.log_input_tokens}, "
        f"#cached-token: {adder.log_hit_tokens}, "
        f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
    )

    if self.disaggregation_mode == DisaggregationMode.PREFILL:
        f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
        f += f"#queue-req: {len(self.waiting_queue)}, "
        f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
        f += f"input throughput (token/s): {self.last_input_throughput:.2f} "
    else:
        f += f"#running-req: {running_bs}, "
        f += f"#queue-req: {len(self.waiting_queue)}"

    logger.info(f)

    if self.enable_metrics:
        # NOTE(review): raises ZeroDivisionError if both log_input_tokens and
        # log_hit_tokens are 0 — presumably a prefill batch always has at
        # least one new token; confirm.
        cache_hit_rate = adder.log_hit_tokens / (
            adder.log_input_tokens + adder.log_hit_tokens
        )
        self.stats.num_running_reqs = running_bs
        self.stats.num_used_tokens = num_used
        self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
        self.stats.num_queue_reqs = len(self.waiting_queue)
        self.stats.cache_hit_rate = cache_hit_rate

        # Average end-to-end queueing latency of the admitted requests.
        total_queue_latency = 0
        for req in can_run_list:
            total_queue_latency += req.queue_time_end - req.queue_time_start
        self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq

        self.metrics_collector.log_stats(self.stats)
    self._publish_kv_events()
2024-11-10 04:39:32 -08:00
2025-05-12 00:17:33 -07:00
def log_decode_stats(
    self, can_run_cuda_graph: bool, running_batch: Optional[ScheduleBatch] = None
):
    """Log a summary line (and optional metrics) for the running decode batch."""
    batch = running_batch or self.running_batch

    # Generation throughput is measured over the interval since the last
    # decode log; the generated-token counter is then reset.
    gap_latency = time.perf_counter() - self.last_decode_stats_tic
    self.last_decode_stats_tic = time.perf_counter()
    self.last_gen_throughput = self.num_generated_tokens / gap_latency
    self.num_generated_tokens = 0
    num_running_reqs = len(batch.reqs)
    num_used = self.max_total_num_tokens - (
        self.token_to_kv_pool_allocator.available_size()
        + self.tree_cache.evictable_size()
    )

    if RECORD_STEP_TIME:
        # Record per-step latency bucketed by the running batch size.
        self.step_time_dict[num_running_reqs].append(
            gap_latency / self.server_args.decode_log_interval
        )

    msg = (
        f"Decode batch. "
        f"#running-req: {num_running_reqs}, "
        f"#token: {num_used}, "
        f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
    )

    if self.spec_algorithm.is_none():
        spec_accept_length = 0
    else:
        # Speculative decoding: average accepted tokens per forward pass,
        # accumulated into lifetime counters and then reset for the window.
        spec_accept_length = (
            self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
        )
        self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
        self.cum_spec_accept_count += self.spec_num_total_forward_ct
        self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
        msg += f"accept len: {spec_accept_length:.2f}, "

    if self.disaggregation_mode == DisaggregationMode.DECODE:
        msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
        msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

    msg += (
        f"cuda graph: {can_run_cuda_graph}, "
        f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
        f"#queue-req: {len(self.waiting_queue)}"
    )

    logger.info(msg)
    if self.enable_metrics:
        self.stats.num_running_reqs = num_running_reqs
        self.stats.num_used_tokens = num_used
        self.stats.token_usage = num_used / self.max_total_num_tokens
        self.stats.cache_hit_rate = 0.0
        self.stats.gen_throughput = self.last_gen_throughput
        self.stats.num_queue_reqs = len(self.waiting_queue)
        self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
        self.stats.spec_accept_length = spec_accept_length
        self.metrics_collector.log_stats(self.stats)
    self._publish_kv_events()
2024-11-10 04:39:32 -08:00
2024-10-06 03:24:04 -07:00
def check_memory(self):
    """Idle-time sanity check for KV-cache and request-slot accounting.

    Raises:
        ValueError: if the token pool or the request pool shows fewer
            reclaimable entries than expected (i.e. a leak).
    """
    # Tokens reclaimable right now: free allocator slots plus evictable
    # entries held by the radix tree cache.
    available_size = (
        self.token_to_kv_pool_allocator.available_size()
        + self.tree_cache.evictable_size()
    )
    protected_size = self.tree_cache.protected_size()
    # With hierarchical cache, protected entries are legitimately withheld.
    memory_leak = available_size != (
        self.max_total_num_tokens
        if not self.enable_hierarchical_cache
        else self.max_total_num_tokens - protected_size
    )
    if memory_leak:
        msg = (
            "token_to_kv_pool_allocator memory leak detected! "
            f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
            f"{self.token_to_kv_pool_allocator.available_size()=}\n"
            f"{self.tree_cache.evictable_size()=}\n"
        )
        raise ValueError(msg)

    # Every request slot must be free when the server is idle.
    if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
        msg = (
            "req_to_token_pool memory leak detected! "
            f"available_size={len(self.req_to_token_pool.free_slots)}, "
            f"total_size={self.req_to_token_pool.size}\n"
        )
        raise ValueError(msg)

    if (
        self.enable_metrics
        and self.attn_tp_rank == 0
        and time.perf_counter() > self.metrics_collector.last_log_time + 30
    ):
        # During idle time, also collect metrics every 30 seconds.
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        num_running_reqs = len(self.running_batch.reqs)
        self.stats.num_running_reqs = num_running_reqs
        self.stats.num_used_tokens = num_used
        self.stats.token_usage = num_used / self.max_total_num_tokens
        self.stats.gen_throughput = 0
        self.stats.num_queue_reqs = len(self.waiting_queue)
        self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
        self.metrics_collector.log_stats(self.stats)
    self._publish_kv_events()
2025-03-03 00:12:04 -08:00
2025-06-17 15:33:28 +08:00
def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
    """Coordinate the DP attention batch."""
    # Each rank contributes one int64 flag: "I have a new batch to run".
    have_batch = torch.tensor([new_batch is not None], dtype=torch.int64)
    gathered = torch.empty(
        (self.server_args.dp_size, self.attn_tp_size, 1),
        dtype=torch.int64,
    )
    # All-gather across the TP CPU group so every rank sees every flag.
    torch.distributed.all_gather_into_tensor(
        gathered.flatten(),
        have_batch,
        group=self.tp_cpu_group,
    )
    # True iff any DP worker has a forward batch (first attn-TP rank of
    # each DP group is consulted).
    return any(gathered[:, 0, 0].tolist())
2024-12-11 12:51:50 -08:00
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
    """Pick the next batch to execute: prefill if possible, else decode.

    Also folds the previous (extend) batch's unfinished requests into
    `self.running_batch` and handles chunked-prefill bookkeeping.
    """
    # Merge the prefill batch into the running batch
    chunked_req_to_exclude = set()
    if self.chunked_req:
        # Move the chunked request out of the batch so that we can merge
        # only finished requests to running_batch.
        chunked_req_to_exclude.add(self.chunked_req)
        self.tree_cache.cache_unfinished_req(self.chunked_req)
        # chunked request keeps its rid but will get a new req_pool_idx
        self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
    if self.last_batch and self.last_batch.forward_mode.is_extend():
        if self.last_batch.chunked_req is not None:
            # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
            # We need to discard it.
            chunked_req_to_exclude.add(self.last_batch.chunked_req)

        # Filter batch
        last_bs = self.last_batch.batch_size()
        self.last_batch.filter_batch(
            chunked_req_to_exclude=list(chunked_req_to_exclude)
        )
        # If filtering removed requests, capacity was freed up.
        if self.last_batch.batch_size() < last_bs:
            self.running_batch.batch_is_full = False

        # Merge the new batch into the running batch
        if not self.last_batch.is_empty():
            if self.running_batch.is_empty():
                self.running_batch = self.last_batch
            else:
                # Merge running_batch with prefill batch
                self.running_batch.merge_batch(self.last_batch)

    new_batch = self.get_new_batch_prefill()

    # TODO(ch-wan): minor refactor is needed here to improve readability
    # With DP attention + speculative decoding, all DP workers must agree
    # on whether a prefill round happens; coordinate via all-gather.
    any_new_batch = (
        self.server_args.enable_dp_attention
        and not self.spec_algorithm.is_none()
        and self.coordinate_spec_dp_attn_batch(new_batch)
    )

    if new_batch is not None or any_new_batch:
        # Run prefill first if possible
        ret = new_batch
    else:
        # Run decode
        if not self.running_batch.is_empty():
            self.running_batch = self.update_running_batch(self.running_batch)
            ret = self.running_batch if not self.running_batch.is_empty() else None
        else:
            ret = None

    # Handle DP attention
    if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
        ret, _ = self.prepare_dp_attn_batch(ret)

    return ret
2024-10-07 13:05:53 -07:00
2025-04-30 18:18:07 -07:00
def get_num_allocatable_reqs ( self , running_bs ) :
res = global_server_args_dict [ " max_micro_batch_size " ] - running_bs
if self . pp_size > 1 :
res = min ( res , self . req_to_token_pool . available_size ( ) )
return res
2024-10-06 03:24:04 -07:00
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Try to assemble a new prefill (extend) batch from the waiting queue.

        Returns:
            A new `ScheduleBatch` ready for an extend forward pass, or None when
            prefill is not possible right now (running batch full, nothing
            waiting, or no waiting request fits in the remaining budget).
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed.
        # A pending chunked request must always be allowed to continue.
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            self.tree_cache.check_hicache_events()

        # Get priority queue
        self.policy.calc_priority(self.waiting_queue)

        # Prefill policy: the adder tracks the remaining token/request budget
        # while requests are tentatively added.
        adder = PrefillAdder(
            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        # An in-flight chunked request continues before anything else is admitted.
        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = set([req.lora_path for req in self.running_batch.reqs])

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                # Adding this request would exceed the LoRA adapter budget.
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, prealloc queue and transfer queue can also take memory,
                # so we need to check if the available size for the actual available size.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

            req.init_next_round_input(self.tree_cache)
            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

        # The adder may have started chunking a request that did not fully fit.
        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        if self.enable_hierarchical_cache:
            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
            new_batch.hicache_consumer_index = (
                self.tree_cache.ready_to_load_host_cache()
            )

        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill: fold the running decode batch into the
        # new prefill batch so both advance in one forward pass.
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            # The merged requests now live in new_batch; keep only the flag.
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch
2024-11-24 04:47:10 -08:00
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch.

        Filters out finished requests, retracts requests when KV cache memory
        is insufficient for the next decode step, and prepares the batch
        tensors for decoding. Returns the (possibly shrunken) batch.
        """
        initial_bs = batch.batch_size()

        # Drop finished/aborted requests before checking memory.
        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        # Check if decode out of memory
        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
            TEST_RETRACT and batch.batch_size() > 10
        ):
            # Retract some requests back to the waiting queue and raise the
            # new-token ratio so future admission is more conservative.
            old_ratio = self.new_token_ratio
            retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
            self.new_token_ratio = new_token_ratio

            logger.info(
                "KV cache pool is full. Retract requests. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
        else:
            # Memory is fine: decay the ratio back toward its minimum.
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        # Some requests left the batch, so there is room again.
        if batch.batch_size() < initial_bs:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
2024-10-06 03:24:04 -07:00
2025-01-16 12:51:11 -08:00
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch.

        Executes one forward pass (normal generation, speculative generation,
        or embedding) and packages the outputs into a result dataclass for
        later output processing.
        """
        self.forward_ct += 1

        # Whether to run the profiler
        self._profile_batch_predicate(batch)
        if self.forward_sleep_time is not None:
            # Debug/testing hook: artificially slow down each forward pass.
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        # Run forward
        if self.is_generation:
            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()

                # update the consumer index of hicache to the running batch
                self.tp_worker.set_hicache_consumer(
                    model_worker_batch.hicache_consumer_index
                )
                if self.pp_group.is_last_rank:
                    logits_output, next_token_ids, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                else:
                    # Non-last PP ranks produce hidden states to forward, not tokens.
                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                bid = model_worker_batch.bid
            else:
                # Speculative decoding path: the draft worker runs draft + verify.
                (
                    logits_output,
                    next_token_ids,
                    bid,
                    num_accepted_tokens,
                    can_run_cuda_graph,
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
                bs = batch.batch_size()
                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
                self.spec_num_total_forward_ct += bs
                self.num_generated_tokens += num_accepted_tokens

            if self.pp_group.is_last_rank:
                batch.output_ids = next_token_ids

            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
            if batch.return_logprob or self.spec_algorithm.is_eagle():
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
            else:
                extend_input_len_per_req = None

            if batch.return_logprob:
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_logprob_start_len_per_req = None

            ret = GenerationBatchResult(
                logits_output=logits_output if self.pp_group.is_last_rank else None,
                pp_hidden_states_proxy_tensors=(
                    pp_hidden_states_proxy_tensors
                    if not self.pp_group.is_last_rank
                    else None
                ),
                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                bid=bid,
                can_run_cuda_graph=can_run_cuda_graph,
            )
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )
        return ret
2024-11-07 15:42:47 -08:00
2025-01-16 12:51:11 -08:00
def process_batch_result (
self ,
batch : ScheduleBatch ,
result : Union [ GenerationBatchResult , EmbeddingBatchResult ] ,
2025-04-28 11:19:16 +08:00
launch_done : Optional [ threading . Event ] = None ,
2025-01-16 12:51:11 -08:00
) :
2024-10-06 03:24:04 -07:00
if batch . forward_mode . is_decode ( ) :
2025-04-28 11:19:16 +08:00
self . process_batch_result_decode ( batch , result , launch_done )
2024-11-19 15:04:43 -08:00
elif batch . forward_mode . is_extend ( ) :
2025-04-28 11:19:16 +08:00
self . process_batch_result_prefill ( batch , result , launch_done )
2025-01-16 11:15:00 -08:00
elif batch . forward_mode . is_idle ( ) :
if self . enable_overlap :
2025-04-28 11:19:16 +08:00
self . tp_worker . resolve_last_batch_result ( launch_done )
2025-05-12 14:33:38 -07:00
self . set_next_batch_sampling_info_done ( batch )
2024-11-19 15:04:43 -08:00
elif batch . forward_mode . is_dummy_first ( ) :
2025-05-12 14:33:38 -07:00
self . set_next_batch_sampling_info_done ( batch )
2024-10-06 03:24:04 -07:00
2025-03-03 00:12:04 -08:00
if self . return_health_check_ct :
# Return some signal for the health check.
# This is used to prevent the health check signal being blocked by long context prefill.
# However, one minor issue is that this code path does not check the status of detokenizer manager.
self . return_health_check_ct - = 1
self . send_to_tokenizer . send_pyobj ( HealthCheckOutput ( ) )
2024-12-06 05:49:29 -08:00
def prepare_dp_attn_batch ( self , local_batch : ScheduleBatch ) :
2025-04-09 14:44:25 +08:00
return self . prepare_dp_attn_batch_raw (
local_batch ,
dp_size = self . server_args . dp_size ,
attn_tp_size = self . attn_tp_size ,
2025-05-13 02:51:39 -04:00
moe_dense_tp_size = self . server_args . moe_dense_tp_size ,
2025-04-09 14:44:25 +08:00
tp_cpu_group = self . tp_cpu_group ,
get_idle_batch = self . get_idle_batch ,
disable_cuda_graph = self . server_args . disable_cuda_graph ,
spec_algorithm = self . spec_algorithm ,
speculative_num_draft_tokens = self . server_args . speculative_num_draft_tokens ,
2025-05-25 08:39:07 +08:00
enable_two_batch_overlap = self . server_args . enable_two_batch_overlap ,
enable_deepep_moe = self . server_args . enable_deepep_moe ,
deepep_mode = DeepEPMode [ self . server_args . deepep_mode ] ,
2025-04-09 14:44:25 +08:00
)
    @staticmethod
    def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        moe_dense_tp_size: Optional[int],
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
    ):
        """Synchronize batch metadata across all DP attention ranks.

        Gathers per-rank token counts and flags via an all-gather, creates an
        idle batch on ranks that have no work while others do, and attaches the
        gathered global metadata to the local batch.

        Returns:
            (local_batch, is_extend_in_any_rank) — the (possibly newly created
            idle) local batch, and whether any rank is running an extend batch.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            num_tokens_for_logprob = sum(
                [
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
                    for logprob_start_len, extend_len in zip(
                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
                    )
                ]
            )

        # Decode/idle batches are CUDA-graph eligible; extend batches are not.
        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        if not spec_algorithm.is_none():
            # TODO(sang): Support cuda graph when idle batch is there.
            if local_batch is None or local_batch.forward_mode.is_idle():
                can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()

        # Per-rank info vector; layout must match the slicing of global_info below:
        # [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch,
        #  <2 TBO fields from prepare_all_gather>] — 6 int64 values in total.
        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                    deepep_mode,
                    enable_deepep_moe,
                    enable_two_batch_overlap,
                ),
            ],
            dtype=torch.int64,
        )
        global_info = torch.empty(
            (dp_size, attn_tp_size, 6),
            dtype=torch.int64,
        )
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(),
            local_info,
            group=tp_cpu_group,
        )
        # Read the first attention-TP rank of each DP rank.
        global_num_tokens = global_info[:, 0, 0].tolist()
        # CUDA graph is only usable if every DP rank can use it.
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

        # If any other rank has work, this rank must run an idle batch so that
        # collective ops stay in sync.
        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
            local_batch.is_extend_in_batch = any(is_extend_in_batch)
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

            # Check forward mode for cuda graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
        return local_batch, any(is_extend_in_batch)
2024-12-06 05:49:29 -08:00
def get_idle_batch ( self ) :
idle_batch = ScheduleBatch . init_new (
[ ] ,
self . req_to_token_pool ,
2025-03-05 08:06:07 -08:00
self . token_to_kv_pool_allocator ,
2024-12-06 05:49:29 -08:00
self . tree_cache ,
self . model_config ,
self . enable_overlap ,
2025-01-02 02:09:08 -08:00
self . spec_algorithm ,
2025-01-20 02:00:35 -08:00
self . server_args . enable_custom_logit_processor ,
2024-12-06 05:49:29 -08:00
)
idle_batch . prepare_for_idle ( )
return idle_batch
2024-11-13 01:49:45 -08:00
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""

        num_ready_reqs = 0
        num_timeout_reqs = 0
        # The grammar queue is FIFO: scan a prefix of ready requests and stop
        # at the first one that is still compiling.
        for req in self.grammar_queue:
            try:
                if req.finished():  # It is aborted by AbortReq
                    num_ready_reqs += 1
                    continue
                req.grammar = req.grammar.result(timeout=0.03)
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
                num_ready_reqs += 1
            except futures._base.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
                # not the waiting time from it enters the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
                    # NOTE(review): plain assignment, not `+= 1` — equivalent here
                    # because the loop breaks immediately, so at most one request
                    # can time out per call.
                    num_timeout_reqs = 1
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            # Ranks behind the maximum must block until the extra requests are
            # ready too, so all ranks move the same prefix.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                req.grammar = req.grammar.result()
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        # Abort the timed-out requests that sit right after the ready prefix.
        for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
        num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

        # Move the processed prefix (ready + timed-out) to the waiting queue.
        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]
2025-05-12 14:33:38 -07:00
def set_next_batch_sampling_info_done ( self , batch : ScheduleBatch ) :
if batch . next_batch_sampling_info :
if batch . next_batch_sampling_info . grammars is not None :
batch . next_batch_sampling_info . update_regex_vocab_mask ( )
self . current_stream . synchronize ( )
batch . next_batch_sampling_info . sampling_info_done . set ( )
2025-03-06 00:13:20 -08:00
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    # No forward progress since the last check; bail out once
                    # the stall exceeds the timeout.
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    # Progress was made; restart the stall timer.
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            time.sleep(self.watchdog_timeout // 2)

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            logger.error(
                f"{self.cur_batch.batch_size()=}, "
                f"{self.cur_batch.reqs=}, "
                f"{self.token_to_kv_pool_allocator.available_size()=}, "
                f"{self.tree_cache.evictable_size()=}, "
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        # Flush both output streams so the error above is not lost.
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)
2025-04-21 09:15:03 +08:00
def flush_cache_wrapped ( self , recv_req : FlushCacheReqInput ) :
success = self . flush_cache ( )
return FlushCacheReqOutput ( success = success )
2025-01-19 12:13:27 +08:00
2024-09-29 17:42:45 -07:00
def flush_cache ( self ) :
2024-10-19 23:19:26 -07:00
""" Flush the memory pool and cache. """
2025-04-30 18:18:07 -07:00
if (
len ( self . waiting_queue ) == 0
and self . running_batch . is_empty ( )
and ( self . pp_size == 1 or all ( x . is_empty ( ) for x in self . running_mbs ) )
) :
2025-03-03 00:12:04 -08:00
self . cur_batch = None
self . last_batch = None
2024-09-29 17:42:45 -07:00
self . tree_cache . reset ( )
2024-11-13 01:49:45 -08:00
if self . grammar_backend :
2024-11-12 21:17:38 -08:00
self . grammar_backend . reset ( )
2024-09-29 17:42:45 -07:00
self . req_to_token_pool . clear ( )
2025-03-05 08:06:07 -08:00
self . token_to_kv_pool_allocator . clear ( )
2025-01-20 20:25:13 -08:00
if not self . spec_algorithm . is_none ( ) :
self . draft_worker . model_runner . req_to_token_pool . clear ( )
2025-03-05 08:06:07 -08:00
self . draft_worker . model_runner . token_to_kv_pool_allocator . clear ( )
2025-01-20 20:25:13 -08:00
self . num_generated_tokens = 0
self . forward_ct_decode = 0
self . spec_num_total_accepted_tokens = 0
self . spec_num_total_forward_ct = 0
2025-03-03 00:12:04 -08:00
self . cum_spec_accept_length = 0
self . cum_spec_accept_count = 0
2024-09-29 17:42:45 -07:00
torch . cuda . empty_cache ( )
logger . info ( " Cache flushed successfully! " )
if_success = True
else :
logging . warning (
f " Cache not flushed because there are pending requests. "
f " #queue-req: { len ( self . waiting_queue ) } , "
2025-03-12 22:22:39 -07:00
f " #running-req: { len ( self . running_batch . reqs ) } "
2024-09-29 17:42:45 -07:00
)
if_success = False
return if_success
2025-05-29 20:50:10 +08:00
def get_load ( self ) :
# TODO(lsyin): use dynamically maintained num_waiting_tokens
load = (
self . max_total_num_tokens
- self . token_to_kv_pool_allocator . available_size ( )
- self . tree_cache . evictable_size ( )
)
load + = sum ( len ( req . origin_input_ids ) for req in self . waiting_queue )
if self . disaggregation_mode == DisaggregationMode . PREFILL :
load + = sum (
len ( req . origin_input_ids )
for req in self . disagg_prefill_bootstrap_queue . queue
)
elif self . disaggregation_mode == DisaggregationMode . DECODE :
load + = sum (
len ( req . req . origin_input_ids )
for req in self . disagg_decode_prealloc_queue . queue
)
return load
2025-03-03 00:12:04 -08:00
def get_internal_state ( self , recv_req : GetInternalStateReq ) :
ret = dict ( global_server_args_dict )
ret [ " last_gen_throughput " ] = self . last_gen_throughput
if not self . spec_algorithm . is_none ( ) and self . cum_spec_accept_count > 0 :
ret [ " avg_spec_accept_length " ] = (
self . cum_spec_accept_length / self . cum_spec_accept_count
)
if RECORD_STEP_TIME :
ret [ " step_time_dict " ] = self . step_time_dict
2025-05-29 20:50:10 +08:00
ret [ " load " ] = self . get_load ( )
return GetInternalStateReqOutput ( internal_state = ret )
2025-03-03 00:12:04 -08:00
def set_internal_state ( self , recv_req : SetInternalStateReq ) :
server_args_dict = recv_req . server_args
args_allow_update = set (
[
2025-04-30 18:18:07 -07:00
" max_micro_batch_size " ,
2025-03-03 00:12:04 -08:00
" speculative_accept_threshold_single " ,
" speculative_accept_threshold_acc " ,
]
)
if_success = True
for k , v in server_args_dict . items ( ) :
if k not in args_allow_update :
logging . warning ( f " Updating { k } is not supported. " )
if_success = False
break
2025-04-30 18:18:07 -07:00
elif k == " max_micro_batch_size " and (
v > self . max_running_requests / / self . pp_size or v < 1
) :
logging . warning (
f " Updating { k } to { v } is rejected because it is out of the valid range [1, { self . max_running_requests / / self . pp_size } ]. "
)
if_success = False
break
2025-03-03 00:12:04 -08:00
if if_success :
if not self . spec_algorithm . is_none ( ) and self . cum_spec_accept_count > 0 :
avg_spec_accept_length = (
self . cum_spec_accept_length / self . cum_spec_accept_count
)
logger . info ( f " { avg_spec_accept_length =} " )
self . cum_spec_accept_length = self . cum_spec_accept_count = 0
for k , v in server_args_dict . items ( ) :
global_server_args_dict [ k ] = v
2025-06-04 15:29:34 -07:00
logger . info ( f " Global server args updated! { global_server_args_dict =} " )
2025-03-03 00:12:04 -08:00
return SetInternalStateReqOutput (
updated = True ,
server_args = global_server_args_dict ,
)
2025-03-14 15:40:44 +08:00
def handle_rpc_request ( self , recv_req : RpcReqInput ) :
# Handle RPC requests
logger . info (
f " handle_rpc_request: { recv_req . method } , param: { recv_req . parameters } "
)
success = True
exec = None
try :
func = getattr ( self , recv_req . method )
func ( recv_req . parameters )
except Exception as e :
success = False
exec = e
logger . error ( f " Failed to call rpc { recv_req . method } : { str ( e ) } " )
barrier ( )
return RpcReqOutput ( success , " " if not exec else str ( exec ) )
def save_remote_model ( self , params ) :
url = params [ " url " ]
2025-03-27 20:28:38 -07:00
worker = self . tp_worker . worker
2025-03-14 15:40:44 +08:00
worker . model_runner . save_remote_model ( url )
def save_sharded_model ( self , params ) :
2025-03-27 20:28:38 -07:00
worker = self . tp_worker . worker
2025-03-14 15:40:44 +08:00
worker . model_runner . save_sharded_model (
path = params [ " path " ] ,
pattern = params [ " pattern " ] ,
max_size = params [ " max_size " ] ,
)
2024-09-29 17:42:45 -07:00
    def abort_request(self, recv_req: AbortReq):
        """Abort every request whose rid starts with ``recv_req.rid``.

        Requests may live in three places, each needing a different abort
        mechanism: the waiting queue (popped directly), the grammar queue
        (finished via `set_finish_with_abort`), and the running batch
        (flagged with `to_abort` so the normal decode path cleans up).
        """
        # Delete requests in the waiting queue
        to_del = []
        for i, req in enumerate(self.waiting_queue):
            if req.rid.startswith(recv_req.rid):
                to_del.append(i)

        # Sort in reverse order to avoid index issues when deleting
        for i in reversed(to_del):
            # Abort method 1: directly pop from the queue
            # This only works for requests that have not started anything.
            # We still need to send something back to TokenizerManager to clean up the state.
            req = self.waiting_queue.pop(i)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete the requests in the grammar queue
        for req in self.grammar_queue:
            # Abort method 2: call `set_finish_with_abort`
            # The request will still run one prefill forward pass.
            # In this case, we change the input_ids to be only one token to make this prefill cheap.
            if req.rid.startswith(recv_req.rid):
                logger.debug(f"Abort grammar queue request. {req.rid=}")
                if req.grammar:
                    req.grammar.cancel()
                req.set_finish_with_abort("Aborted by AbortReq.")

        # Delete requests in the running batch
        # When cur_batch is a separate in-flight batch, scan both it and
        # running_batch; otherwise scanning running_batch alone covers all.
        if self.cur_batch is self.running_batch or self.cur_batch is None:
            reqs = self.running_batch.reqs
        else:
            reqs = self.running_batch.reqs + self.cur_batch.reqs

        for req in reqs:
            if req.rid.startswith(recv_req.rid) and not req.finished():
                # Abort method 3: set `to_abort=True`
                # The request will still run one decode forward pass.
                # Then we reuse all existing code to clean up the KV cache allocation.
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
2025-03-06 00:13:20 -08:00
def _pause_engine ( self ) - > Tuple [ List [ Req ] , int ] :
raise NotImplementedError ( )
2024-11-29 17:17:00 -08:00
def update_weights_from_disk ( self , recv_req : UpdateWeightFromDiskReqInput ) :
""" In-place update of the weights from disk. """
success , message = self . tp_worker . update_weights_from_disk ( recv_req )
2024-09-29 17:42:45 -07:00
if success :
flash_cache_success = self . flush_cache ( )
assert flash_cache_success , " Cache flush failed after updating weights "
else :
logger . error ( message )
2025-03-03 00:12:04 -08:00
return UpdateWeightFromDiskReqOutput ( success , message , 0 )
2024-09-29 17:42:45 -07:00
2024-12-01 23:23:18 -08:00
def init_weights_update_group ( self , recv_req : InitWeightsUpdateGroupReqInput ) :
""" Initialize the online model parameter update group. """
success , message = self . tp_worker . init_weights_update_group ( recv_req )
2025-01-19 12:13:27 +08:00
return InitWeightsUpdateGroupReqOutput ( success , message )
2024-12-01 23:23:18 -08:00
def update_weights_from_distributed (
2025-01-07 02:52:53 -08:00
self ,
recv_req : UpdateWeightsFromDistributedReqInput ,
) - > Tuple [ bool , str ] :
2024-12-01 23:23:18 -08:00
""" Update the online model parameter. """
success , message = self . tp_worker . update_weights_from_distributed ( recv_req )
if success :
flash_cache_success = self . flush_cache ( )
assert flash_cache_success , " Cache flush failed after updating weights "
else :
logger . error ( message )
2025-01-19 12:13:27 +08:00
return UpdateWeightsFromDistributedReqOutput ( success , message )
2024-12-01 23:23:18 -08:00
2024-12-29 05:30:27 +08:00
def update_weights_from_tensor ( self , recv_req : UpdateWeightsFromTensorReqInput ) :
""" Update the online model parameter from tensors. """
success , message = self . tp_worker . update_weights_from_tensor ( recv_req )
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
if success :
2025-03-01 01:53:10 +08:00
if recv_req . flush_cache :
flash_cache_success = self . flush_cache ( )
assert flash_cache_success , " Cache flush failed after updating weights "
2024-12-29 05:30:27 +08:00
else :
logger . error ( message )
2025-01-19 12:13:27 +08:00
return UpdateWeightsFromTensorReqOutput ( success , message )
2024-12-29 05:30:27 +08:00
2024-11-29 23:36:38 -08:00
def get_weights_by_name ( self , recv_req : GetWeightsByNameReqInput ) :
parameter = self . tp_worker . get_weights_by_name ( recv_req )
2025-01-19 12:13:27 +08:00
return GetWeightsByNameReqOutput ( parameter )
2024-11-29 23:36:38 -08:00
2025-03-03 00:12:04 -08:00
def release_memory_occupation ( self , recv_req : ReleaseMemoryOccupationReqInput ) :
2025-06-19 00:56:37 -07:00
tags = recv_req . tags
import subprocess
if tags is None :
tags = [ GPU_MEMORY_TYPE_WEIGHTS , GPU_MEMORY_TYPE_KV_CACHE ]
if GPU_MEMORY_TYPE_KV_CACHE in tags :
self . memory_saver_adapter . pause ( GPU_MEMORY_TYPE_KV_CACHE )
self . flush_cache ( )
if GPU_MEMORY_TYPE_WEIGHTS in tags :
self . stashed_model_static_state = _export_static_state (
self . tp_worker . worker . model_runner . model
)
self . memory_saver_adapter . pause ( GPU_MEMORY_TYPE_WEIGHTS )
2025-01-19 12:13:27 +08:00
return ReleaseMemoryOccupationReqOutput ( )
2025-01-14 03:38:51 +08:00
2025-03-03 00:12:04 -08:00
def resume_memory_occupation ( self , recv_req : ResumeMemoryOccupationReqInput ) :
2025-06-19 00:56:37 -07:00
tags = recv_req . tags
if tags is None or len ( tags ) == 0 :
tags = [ GPU_MEMORY_TYPE_WEIGHTS , GPU_MEMORY_TYPE_KV_CACHE ]
if GPU_MEMORY_TYPE_WEIGHTS in tags :
self . memory_saver_adapter . resume ( GPU_MEMORY_TYPE_WEIGHTS )
_import_static_state (
self . tp_worker . worker . model_runner . model ,
self . stashed_model_static_state ,
)
del self . stashed_model_static_state
if GPU_MEMORY_TYPE_KV_CACHE in tags :
self . memory_saver_adapter . resume ( GPU_MEMORY_TYPE_KV_CACHE )
2025-01-19 12:13:27 +08:00
return ResumeMemoryOccupationReqOutput ( )
2025-05-08 16:03:08 +08:00
def slow_down ( self , recv_req : SlowDownReqInput ) :
t = recv_req . forward_sleep_time
if t is not None and t < = 0 :
t = None
self . forward_sleep_time = t
return SlowDownReqOutput ( )
2025-01-19 12:13:27 +08:00
def profile ( self , recv_req : ProfileReq ) :
2025-03-03 00:12:04 -08:00
if recv_req . type == ProfileReqType . START_PROFILE :
2025-05-31 15:53:55 -07:00
if recv_req . profile_by_stage :
return self . init_profile (
recv_req . output_dir ,
recv_req . num_steps ,
recv_req . activities ,
recv_req . with_stack ,
recv_req . record_shapes ,
recv_req . profile_by_stage ,
2025-06-03 02:17:22 +08:00
recv_req . profile_id ,
2025-05-31 15:53:55 -07:00
)
else :
self . init_profile (
recv_req . output_dir ,
recv_req . num_steps ,
recv_req . activities ,
recv_req . with_stack ,
recv_req . record_shapes ,
recv_req . profile_by_stage ,
2025-06-03 02:17:22 +08:00
recv_req . profile_id ,
2025-05-31 15:53:55 -07:00
)
return self . start_profile ( True )
2025-01-19 12:13:27 +08:00
else :
2025-03-03 00:12:04 -08:00
return self . stop_profile ( )
2025-05-31 15:53:55 -07:00
    def init_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_by_stage: bool,
        profile_id: str,
    ) -> ProfileReqOutput:
        """Arm the profiler: validate state and record the profiling settings.

        Does not start any profiler itself; `start_profile` (or, for
        stage-based profiling, `_profile_batch_predicate`) does that.
        Returns a failure output if a profiling session is already active.
        """
        if self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        self.profile_by_stage = profile_by_stage

        # Fall back to env-var / defaults when the caller left these unset.
        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]

        self.torch_profiler_output_dir = output_dir
        self.torch_profiler_with_stack = with_stack
        self.torch_profiler_record_shapes = record_shapes
        self.profiler_activities = activities
        self.profile_id = profile_id

        if num_steps:
            self.profile_steps = num_steps
            if self.profile_by_stage:
                # Separate prefill/decode step budgets and counters.
                self.profiler_target_prefill_ct = num_steps
                self.profiler_target_decode_ct = num_steps
                self.profiler_prefill_ct = 0
                self.profiler_decode_ct = 0
            else:
                self.profiler_target_forward_ct = self.forward_ct + num_steps
                # The caller will be notified when reaching profiler_target_forward_ct
        else:
            # No step budget: profiling runs until an explicit /stop_profile.
            self.profiler_target_forward_ct = None

        return ProfileReqOutput(success=True, message="Succeeded")
    def start_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Start the profiler(s) configured by `init_profile`.

        Depending on `profiler_activities`, this starts the ROCm RPD tracer,
        the torch profiler, CUDA memory-history recording, and/or the CUDA
        profiler. `stage` is only used to label the log line.
        """
        stage_str = f" for {stage.__str__()}" if stage else ""
        logger.info(
            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
        )

        activities = self.profiler_activities
        with_stack = self.torch_profiler_with_stack
        record_shapes = self.torch_profiler_record_shapes

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if "RPD" in activities:
            # ROCm profiler data tracer; takes precedence over torch profiler.
            from rpdTracerControl import rpdTracerControl

            rpdTracerControl.skipCreate()

            self.rpd_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
            )

            if self.tp_rank == 0:
                # Only rank 0 (re)creates the shared sqlite trace database.
                import sqlite3

                from rocpd.schema import RocpdSchema

                if os.path.exists("trace.rpd"):
                    os.unlink("trace.rpd")
                schema = RocpdSchema()
                connection = sqlite3.connect("trace.rpd")
                schema.writeSchema(connection)
                connection.commit()
                del connection

            # Wait for rank 0 to finish creating the DB before tracing.
            torch.distributed.barrier(self.tp_cpu_group)

            self.rpd_profiler = rpdTracerControl()
            self.rpd_profiler.setPythonTrace(True)
            self.rpd_profiler.start()
            self.rpd_profiler.rangePush("", "rpd profile range", "")
            self.profile_in_progress = True
        elif torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                # Defaults when unset: stack traces on, shape recording off.
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()
            self.profile_in_progress = True

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)
            self.profile_in_progress = True

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()

        return ProfileReqOutput(success=True, message="Succeeded")
2024-10-11 17:34:25 +08:00
2025-05-31 15:53:55 -07:00
def stop_profile (
self , stage : Optional [ ForwardMode ] = None
) - > ProfileReqOutput | None :
if not self . profile_in_progress :
2025-05-18 08:06:15 +08:00
return ProfileReqOutput (
success = False ,
message = " Profiling is not in progress. Call /start_profile first. " ,
)
2025-03-03 00:12:04 -08:00
2025-06-05 15:07:03 +08:00
if not Path ( self . torch_profiler_output_dir ) . exists ( ) :
Path ( self . torch_profiler_output_dir ) . mkdir ( parents = True , exist_ok = True )
2025-05-31 15:53:55 -07:00
stage_suffix = f " - { stage . __str__ ( ) } " if stage else " "
logger . info ( " Stop profiling " + stage_suffix + " ... " )
2025-03-03 00:12:04 -08:00
if self . torch_profiler is not None :
self . torch_profiler . stop ( )
self . torch_profiler . export_chrome_trace (
os . path . join (
self . torch_profiler_output_dir ,
2025-06-03 02:17:22 +08:00
self . profile_id
2025-05-31 15:53:55 -07:00
+ f " -TP- { self . tp_rank } "
+ stage_suffix
+ " .trace.json.gz " ,
2025-03-03 00:12:04 -08:00
)
)
2025-05-31 15:53:55 -07:00
torch . distributed . barrier ( self . tp_cpu_group )
if self . rpd_profiler is not None :
self . rpd_profiler . rangePop ( )
self . rpd_profiler . stop ( )
self . rpd_profiler . flush ( )
2025-03-03 00:12:04 -08:00
2025-05-31 15:53:55 -07:00
torch . distributed . barrier ( self . tp_cpu_group )
if self . tp_rank == 0 :
from sglang . srt . utils import rpd_to_chrome_trace
rpd_to_chrome_trace ( " trace.rpd " , self . rpd_profile_path )
self . rpd_profiler = None
self . rpd_profiler_path = None
if self . profiler_activities is not None and " MEM " in self . profiler_activities :
2025-03-03 00:12:04 -08:00
memory_profile_path = os . path . join (
2025-03-29 01:35:02 +08:00
self . torch_profiler_output_dir ,
2025-05-31 15:53:55 -07:00
str ( time . time ( ) )
+ f " -TP- { self . tp_rank } -memory "
+ stage_suffix
+ " .pickle " ,
2025-03-03 00:12:04 -08:00
)
torch . cuda . memory . _dump_snapshot ( memory_profile_path )
torch . cuda . memory . _record_memory_history ( enabled = None )
2025-03-28 13:21:13 +08:00
if " CUDA_PROFILER " in self . profiler_activities :
torch . cuda . cudart ( ) . cudaProfilerStop ( )
2025-03-03 00:12:04 -08:00
logger . info (
" Profiling done. Traces are saved to: %s " ,
self . torch_profiler_output_dir ,
2024-10-11 17:34:25 +08:00
)
2025-03-03 00:12:04 -08:00
self . torch_profiler = None
2025-05-31 15:53:55 -07:00
self . profile_in_progress = False
return ProfileReqOutput ( success = True , message = " Succeeded. " )
    def _profile_batch_predicate(self, batch):
        """Drive profiler start/stop from per-batch counters.

        In stage-based mode, prefill and decode each get their own step
        budget; the profiler starts on the first batch of a stage and stops
        once that stage's budget is exceeded. In step-count mode, profiling
        simply stops after `profiler_target_forward_ct` forward passes.
        """
        if self.profile_by_stage:
            if batch.forward_mode.is_prefill():
                if self.profiler_prefill_ct == 0:
                    # First prefill batch of the session: start tracing.
                    self.start_profile(batch.forward_mode)
                self.profiler_prefill_ct += 1
                if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.EXTEND)
            elif batch.forward_mode.is_decode():
                if self.profiler_decode_ct == 0:
                    if self.profile_in_progress:
                        # force trace flush
                        self.stop_profile(ForwardMode.EXTEND)
                    self.start_profile(batch.forward_mode)
                self.profiler_decode_ct += 1
                if self.profiler_decode_ct > self.profiler_target_decode_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.DECODE)
            elif batch.forward_mode.is_idle():
                # Idle batches do no model work; nothing to profile.
                pass
            else:
                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
        else:
            # Check profiler
            if (
                self.profiler_target_forward_ct
                and self.profiler_target_forward_ct <= self.forward_ct
            ):
                self.stop_profile()
2025-03-24 21:34:19 -07:00
def expert_distribution_handle ( self , recv_req : ExpertDistributionReq ) :
if recv_req == ExpertDistributionReq . START_RECORD :
2025-05-20 11:07:43 +08:00
get_global_expert_distribution_recorder ( ) . start_record ( )
2025-03-24 21:34:19 -07:00
elif recv_req == ExpertDistributionReq . STOP_RECORD :
2025-05-20 11:07:43 +08:00
get_global_expert_distribution_recorder ( ) . stop_record ( )
2025-03-24 21:34:19 -07:00
elif recv_req == ExpertDistributionReq . DUMP_RECORD :
2025-05-20 11:07:43 +08:00
get_global_expert_distribution_recorder ( ) . dump_record ( )
2025-03-24 21:34:19 -07:00
else :
raise ValueError ( " Unrecognized ExpertDistributionReq value " )
2025-03-25 16:17:03 +08:00
return ExpertDistributionReqOutput ( )
2025-03-24 21:34:19 -07:00
2025-01-19 12:13:27 +08:00
def open_session ( self , recv_req : OpenSessionReqInput ) :
2024-11-20 00:36:53 -08:00
# handle error
session_id = recv_req . session_id
if session_id in self . sessions :
logger . warning ( f " session id { session_id } already exist, cannot open. " )
2025-01-19 12:13:27 +08:00
return OpenSessionReqOutput ( session_id , False )
2024-12-29 02:10:27 -08:00
elif session_id is None :
2025-03-03 00:12:04 -08:00
logger . warning ( " session id is None, cannot open. " )
2025-01-19 12:13:27 +08:00
return OpenSessionReqOutput ( session_id , False )
2024-11-20 00:36:53 -08:00
else :
self . sessions [ session_id ] = Session (
recv_req . capacity_of_str_len , session_id
)
2025-01-19 12:13:27 +08:00
return OpenSessionReqOutput ( session_id , True )
2024-11-20 00:36:53 -08:00
def close_session ( self , recv_req : CloseSessionReqInput ) :
# handle error
session_id = recv_req . session_id
if session_id not in self . sessions :
logger . warning ( f " session id { session_id } does not exist, cannot delete. " )
else :
del self . sessions [ session_id ]
2025-04-30 18:18:07 -07:00
def get_print_prefix ( self ) :
prefix = " "
2025-05-13 02:51:39 -04:00
if self . attn_dp_rank is not None :
prefix + = f " DP { self . attn_dp_rank } "
2025-04-30 18:18:07 -07:00
if self . server_args . tp_size > 1 :
prefix + = f " TP { self . tp_rank } "
if self . pp_size > 1 :
prefix + = f " PP { self . pp_rank } "
return prefix
2025-05-19 14:19:54 -07:00
def _publish_kv_events ( self ) :
if self . enable_kv_cache_events :
events = self . tree_cache . take_events ( )
if events :
batch = KVEventBatch ( ts = time . time ( ) , events = events )
self . kv_event_publisher . publish ( batch )
2024-09-29 02:36:12 -07:00
2025-03-03 00:12:04 -08:00
def is_health_check_generate_req(recv_req):
    """Return True if *recv_req* is an internal health-check generation request."""
    rid = getattr(recv_req, "rid", "")
    return rid.startswith("HEALTH_CHECK")
2025-01-14 03:38:51 +08:00
def _export_static_state ( model ) :
return dict (
buffers = [
( name , buffer . detach ( ) . clone ( ) ) for name , buffer in model . named_buffers ( )
]
)
def _import_static_state ( model , static_params ) :
self_named_buffers = dict ( model . named_buffers ( ) )
for name , tensor in static_params [ " buffers " ] :
self_named_buffers [ name ] [ . . . ] = tensor
2024-09-29 02:36:12 -07:00
def run_scheduler_process (
server_args : ServerArgs ,
port_args : PortArgs ,
gpu_id : int ,
tp_rank : int ,
2025-04-30 18:18:07 -07:00
pp_rank : int ,
2024-10-11 07:22:48 -07:00
dp_rank : Optional [ int ] ,
2024-10-06 00:10:48 -07:00
pipe_writer ,
2024-09-29 02:36:12 -07:00
) :
2025-03-13 10:29:35 +08:00
# Generate the prefix
2025-04-30 18:18:07 -07:00
prefix = " "
if dp_rank is not None :
prefix + = f " DP { dp_rank } "
if server_args . tp_size > 1 :
prefix + = f " TP { tp_rank } "
if server_args . pp_size > 1 :
prefix + = f " PP { pp_rank } "
2025-03-13 10:29:35 +08:00
2025-03-03 00:12:04 -08:00
# Config the process
2025-03-17 05:13:16 -07:00
kill_itself_when_parent_died ( )
2025-03-13 10:29:35 +08:00
setproctitle . setproctitle ( f " sglang::scheduler { prefix . replace ( ' ' , ' _ ' ) } " )
2025-01-13 01:39:14 -08:00
faulthandler . enable ( )
2025-03-03 00:12:04 -08:00
parent_process = psutil . Process ( ) . parent ( )
2024-12-08 01:06:15 -08:00
2024-11-27 09:36:36 -08:00
# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
if dp_rank is None and " SGLANG_DP_RANK " in os . environ :
dp_rank = int ( os . environ [ " SGLANG_DP_RANK " ] )
2024-11-21 17:13:33 -08:00
2025-02-25 08:51:23 +08:00
# Configure the logger
2025-03-03 00:12:04 -08:00
configure_logger ( server_args , prefix = prefix )
2024-12-29 00:45:57 -08:00
suppress_other_loggers ( )
2024-10-11 07:22:48 -07:00
2024-12-29 00:45:57 -08:00
# Set cpu affinity to this gpu process
2024-12-06 05:49:29 -08:00
if get_bool_env_var ( " SGLANG_SET_CPU_AFFINITY " ) :
set_gpu_proc_affinity ( server_args . tp_size , server_args . nnodes , gpu_id )
2025-05-22 20:32:41 -07:00
embedding_cache_size = 100
if " SGLANG_VLM_CACHE_SIZE_MB " in os . environ :
embedding_cache_size = int ( os . environ [ " SGLANG_VLM_CACHE_SIZE_MB " ] )
init_embedding_cache ( embedding_cache_size * 1024 * 1024 )
2024-12-29 00:45:57 -08:00
# Create a scheduler and run the event loop
2024-09-29 02:36:12 -07:00
try :
2025-04-30 18:18:07 -07:00
scheduler = Scheduler ( server_args , port_args , gpu_id , tp_rank , pp_rank , dp_rank )
2024-11-22 15:10:10 -08:00
pipe_writer . send (
2025-01-19 06:14:19 +08:00
{
" status " : " ready " ,
" max_total_num_tokens " : scheduler . max_total_num_tokens ,
" max_req_input_len " : scheduler . max_req_input_len ,
}
2024-11-22 15:10:10 -08:00
)
2025-03-21 14:47:47 -07:00
disaggregation_mode : DisaggregationMode = scheduler . disaggregation_mode
if disaggregation_mode == DisaggregationMode . NULL :
2025-04-30 18:18:07 -07:00
if server_args . pp_size > 1 :
scheduler . event_loop_pp ( )
elif scheduler . enable_overlap :
2025-03-21 14:47:47 -07:00
scheduler . event_loop_overlap ( )
else :
scheduler . event_loop_normal ( )
elif disaggregation_mode == DisaggregationMode . PREFILL :
2025-04-21 12:12:56 -07:00
if scheduler . enable_overlap :
scheduler . event_loop_overlap_disagg_prefill ( )
else :
scheduler . event_loop_normal_disagg_prefill ( )
2025-04-30 18:18:07 -07:00
2025-03-21 14:47:47 -07:00
elif disaggregation_mode == DisaggregationMode . DECODE :
2025-04-21 12:06:16 -07:00
if scheduler . enable_overlap :
scheduler . event_loop_overlap_disagg_decode ( )
else :
scheduler . event_loop_normal_disagg_decode ( )
2025-03-21 14:47:47 -07:00
2024-09-29 02:36:12 -07:00
except Exception :
2024-11-28 00:22:39 -08:00
traceback = get_exception_traceback ( )
logger . error ( f " Scheduler hit an exception: { traceback } " )
parent_process . send_signal ( signal . SIGQUIT )