2024-11-22 22:16:53 +08:00
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
2024-09-29 02:36:12 -07:00
""" A scheduler that manages a tensor parallel GPU worker. """
2025-01-13 01:39:14 -08:00
import faulthandler
2024-09-29 02:36:12 -07:00
import logging
2024-09-29 17:42:45 -07:00
import os
2024-11-28 00:22:39 -08:00
import signal
2025-03-03 00:12:04 -08:00
import sys
2024-10-27 02:00:50 -07:00
import threading
2024-09-29 17:42:45 -07:00
import time
import warnings
2025-03-03 00:12:04 -08:00
from collections import defaultdict , deque
2024-11-12 21:17:38 -08:00
from concurrent import futures
2025-01-16 12:51:11 -08:00
from dataclasses import dataclass
2025-01-18 19:37:30 -08:00
from http import HTTPStatus
2024-10-12 21:35:30 -07:00
from types import SimpleNamespace
2025-01-16 12:51:11 -08:00
from typing import Dict , List , Optional , Tuple , Union
2024-09-29 02:36:12 -07:00
2024-11-28 00:22:39 -08:00
import psutil
2024-12-08 01:06:15 -08:00
import setproctitle
2024-09-29 17:42:45 -07:00
import torch
2024-09-29 02:36:12 -07:00
import zmq
2024-09-29 17:42:45 -07:00
from sglang . global_config import global_config
2024-11-24 04:47:10 -08:00
from sglang . srt . configs . model_config import ModelConfig
2025-01-19 17:10:29 -08:00
from sglang . srt . constrained . base_grammar_backend import create_grammar_backend
2024-09-29 17:42:45 -07:00
from sglang . srt . hf_transformers_utils import get_processor , get_tokenizer
2025-01-16 11:15:00 -08:00
from sglang . srt . layers . dp_attention import compute_dp_attention_world_info
2024-09-29 17:42:45 -07:00
from sglang . srt . layers . logits_processor import LogitsProcessorOutput
from sglang . srt . managers . io_struct import (
AbortReq ,
BatchEmbeddingOut ,
2025-03-03 00:12:04 -08:00
BatchMultimodalDecodeReq ,
2024-09-29 17:42:45 -07:00
BatchTokenIDOut ,
2024-11-20 00:36:53 -08:00
CloseSessionReqInput ,
2024-09-29 17:42:45 -07:00
FlushCacheReq ,
2025-03-03 00:12:04 -08:00
GetInternalStateReq ,
GetInternalStateReqOutput ,
2024-11-29 23:36:38 -08:00
GetWeightsByNameReqInput ,
GetWeightsByNameReqOutput ,
2025-03-03 00:12:04 -08:00
HealthCheckOutput ,
2024-12-01 23:23:18 -08:00
InitWeightsUpdateGroupReqInput ,
InitWeightsUpdateGroupReqOutput ,
2024-11-20 00:36:53 -08:00
OpenSessionReqInput ,
OpenSessionReqOutput ,
2024-10-11 17:34:25 +08:00
ProfileReq ,
2025-03-03 00:12:04 -08:00
ProfileReqOutput ,
ProfileReqType ,
2025-01-14 03:38:51 +08:00
ReleaseMemoryOccupationReqInput ,
ReleaseMemoryOccupationReqOutput ,
ResumeMemoryOccupationReqInput ,
ResumeMemoryOccupationReqOutput ,
2025-03-03 00:12:04 -08:00
SetInternalStateReq ,
SetInternalStateReqOutput ,
2024-09-29 17:42:45 -07:00
TokenizedEmbeddingReqInput ,
TokenizedGenerateReqInput ,
2024-11-29 17:17:00 -08:00
UpdateWeightFromDiskReqInput ,
UpdateWeightFromDiskReqOutput ,
2024-12-01 23:23:18 -08:00
UpdateWeightsFromDistributedReqInput ,
UpdateWeightsFromDistributedReqOutput ,
2024-12-29 05:30:27 +08:00
UpdateWeightsFromTensorReqInput ,
UpdateWeightsFromTensorReqOutput ,
2024-09-29 17:42:45 -07:00
)
from sglang . srt . managers . schedule_batch import (
FINISH_ABORT ,
BaseFinishReason ,
ImageInputs ,
Req ,
ScheduleBatch ,
2024-10-19 23:19:26 -07:00
global_server_args_dict ,
2024-09-29 17:42:45 -07:00
)
2024-10-04 18:00:18 -07:00
from sglang . srt . managers . schedule_policy import (
AddReqResult ,
PrefillAdder ,
SchedulePolicy ,
)
2024-11-20 00:36:53 -08:00
from sglang . srt . managers . session_controller import Session
2024-09-30 02:41:11 -07:00
from sglang . srt . managers . tp_worker import TpModelWorker
2024-10-20 00:29:29 -07:00
from sglang . srt . managers . tp_worker_overlap_thread import TpModelWorkerClient
2025-01-16 14:51:19 -08:00
from sglang . srt . managers . utils import validate_input_length
2024-09-29 17:42:45 -07:00
from sglang . srt . mem_cache . chunk_cache import ChunkCache
2025-02-23 21:56:30 -08:00
from sglang . srt . mem_cache . hiradix_cache import HiRadixCache
2024-09-29 17:42:45 -07:00
from sglang . srt . mem_cache . radix_cache import RadixCache
2024-11-10 04:39:32 -08:00
from sglang . srt . metrics . collector import SchedulerMetricsCollector , SchedulerStats
2024-11-19 15:04:43 -08:00
from sglang . srt . model_executor . forward_batch_info import ForwardMode
2024-09-29 02:36:12 -07:00
from sglang . srt . server_args import PortArgs , ServerArgs
2025-01-02 02:09:08 -08:00
from sglang . srt . speculative . spec_info import SpeculativeAlgorithm
2025-01-13 14:24:00 -08:00
from sglang . srt . torch_memory_saver_adapter import TorchMemorySaverAdapter
2024-09-29 17:42:45 -07:00
from sglang . srt . utils import (
broadcast_pyobj ,
configure_logger ,
2024-11-17 22:18:11 -08:00
crash_on_warnings ,
2024-11-27 02:52:46 -08:00
get_bool_env_var ,
2024-10-25 23:07:07 -07:00
get_zmq_socket ,
2025-03-03 00:12:04 -08:00
kill_itself_when_parent_died ,
pyspy_dump_schedulers ,
2024-11-27 02:52:46 -08:00
set_gpu_proc_affinity ,
2024-09-29 17:42:45 -07:00
set_random_seed ,
suppress_other_loggers ,
)
2025-01-19 12:13:27 +08:00
from sglang . utils import TypeBasedDispatcher , get_exception_traceback
2024-09-29 02:36:12 -07:00
logger = logging.getLogger(__name__)

# Debug knobs, read once at import time.
# Test retract decode for debugging purposes.
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# Record per-step latency into Scheduler.step_time_dict when enabled.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
2024-10-13 20:32:37 -07:00
2024-09-29 02:36:12 -07:00
2025-01-16 12:51:11 -08:00
@dataclass
class GenerationBatchResult:
    """Result of one generation forward pass, paired with its batch for later processing."""

    # Raw logits-processor output for the whole batch
    logits_output: LogitsProcessorOutput
    # The next token id sampled for each request in the batch
    next_token_ids: List[int]
    # Extend (prefill) input length for each request
    # NOTE(review): presumably unused/None for pure decode batches — confirm against run_batch
    extend_input_len_per_req: List[int]
    # Start offset of logprob computation within each request's extend input — verify against caller
    extend_logprob_start_len_per_req: List[int]
    # Batch id, used to match this result back to the launched batch
    bid: int
@dataclass
class EmbeddingBatchResult:
    """Result of one embedding forward pass."""

    # Embedding vectors computed for the batch
    embeddings: torch.Tensor
    # Batch id, used to match this result back to the launched batch
    bid: int
2024-09-29 02:36:12 -07:00
class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Initialize the scheduler process state.

        Args:
            server_args: Global server configuration.
            port_args: IPC/port configuration for tokenizer/detokenizer/nccl.
            gpu_id: GPU device index this scheduler drives.
            tp_rank: Tensor-parallel rank of this process.
            dp_rank: Data-parallel rank, or None when DP is not used.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        # With speculative decoding, each decode step may commit multiple tokens
        # (draft tokens + topk * steps), so the decode memory buffer is scaled up.
        self.decode_mem_cache_buf_multiplier = (
            (
                self.server_args.speculative_num_draft_tokens
                + (
                    self.server_args.speculative_eagle_topk
                    * self.server_args.speculative_num_steps
                )
            )
            if not self.spec_algorithm.is_none()
            else 1
        )

        # Distributed rank info
        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication
        context = zmq.Context(2)
        if self.attn_tp_rank == 0:
            # Only the attention-TP rank 0 talks to the tokenizer/detokenizer directly;
            # other ranks receive requests via broadcast in recv_requests().
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )
        else:
            # Non-zero ranks never send; use no-op stand-ins with the same interface.
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            # Imported lazily so non-speculative deployments avoid the dependency.
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
            self.prefill_only_one_req = True
        else:
            self.draft_worker = None
            self.prefill_only_one_req = False

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            # Chunked prefill without radix cache only needs chunk bookkeeping.
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            if self.enable_hierarchical_cache:
                self.tree_cache = HiRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool=self.token_to_kv_pool,
                )
            else:
                self.tree_cache = RadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool=self.token_to_kv_pool,
                    disable=server_args.disable_radix_cache,
                )

        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        self.staging_reqs = {}
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.last_decode_stats_tic = time.time()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # For metrics only.
        # The largest prefill length of a single request
        self._largest_prefill_len: int = 0
        # The largest context length (prefill + generation) of a single request
        self._largest_prefill_decode_len: int = 0
        self.last_gen_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]

        # Session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size <= 0:  # -1 means disable
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Tell whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.torch_profiler_activities: Optional[List[str]] = None
        self.profiler_target_forward_ct: Optional[int] = None

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

        # Init request dispatcher: maps each incoming IPC message type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReq, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
            ]
        )
2024-10-27 02:00:50 -07:00
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            current = time.time()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    # forward_ct has not advanced since the last check; if it has
                    # been stuck for longer than the timeout, declare a hang.
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    # Progress was made; reset the reference point.
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            # Poll at half the timeout so a hang is detected within ~1.5x timeout.
            time.sleep(self.watchdog_timeout // 2)

        # Print batch size and memory pool info to check whether there are de-sync issues.
        logger.error(
            f"{self.cur_batch.batch_size()=}, "
            f"{self.cur_batch.reqs=}, "
            f"{self.token_to_kv_pool.available_size()=}, "
            f"{self.tree_cache.evictable_size()=}, "
        )
        # Wait for some time so that the parent process can print the error.
        pyspy_dump_schedulers()
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)
        time.sleep(5)
        # Ask the parent process to shut the whole server down.
        self.parent_process.send_signal(signal.SIGQUIT)
2024-10-27 02:00:50 -07:00
2024-11-18 08:29:20 +08:00
    @torch.no_grad()
    def event_loop_normal(self):
        """A normal scheduler loop: receive, schedule, run, process — one batch at a time."""
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # The server is idle: self-check and re-init some states.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch
2024-09-29 02:36:12 -07:00
2024-11-18 08:29:20 +08:00
    @torch.no_grad()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        Batch N's results are processed while batch N+1 is running, via a
        one-slot-deep pipeline held in ``self.result_queue``.
        """
        self.result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                # Queue the (batch, result) pair; it is processed on the next
                # iteration while the following batch runs on the GPU.
                self.result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None)

            if self.last_batch:
                # Process the results of the last batch.
                tmp_batch, tmp_result = self.result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                # The server is idle: self-check and re-init some states.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch
2024-12-29 00:45:57 -08:00
    def recv_requests(self) -> List[Req]:
        """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
        if self.attn_tp_rank == 0:
            # Drain everything currently queued on the ZMQ socket without blocking.
            recv_reqs = []

            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    # No more pending messages.
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None

        if self.server_args.enable_dp_attention:
            # With DP attention, "work" requests (generate/embedding) are scoped to
            # the attention-TP group while "control" requests go to the whole TP group.
            if self.attn_tp_rank == 0:
                work_reqs = [
                    req
                    for req in recv_reqs
                    if isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                ]
                control_reqs = [
                    req
                    for req in recv_reqs
                    if not isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                ]
            else:
                work_reqs = None
                control_reqs = None

            if self.attn_tp_size != 1:
                # Rank 0 of this attention-TP group within the global TP layout.
                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_rank,
                    self.attn_tp_cpu_group,
                    src=attn_tp_rank_0,
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs, self.tp_rank, self.tp_cpu_group
                )
            recv_reqs = work_reqs + control_reqs
        elif self.tp_size != 1:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs
2024-10-06 03:24:04 -07:00
def process_input_requests ( self , recv_reqs : List ) :
2024-09-29 17:42:45 -07:00
for recv_req in recv_reqs :
2025-03-03 00:12:04 -08:00
# If it is a health check generation request and there are running requests, ignore it.
if is_health_check_generate_req ( recv_req ) and (
self . chunked_req is not None or self . running_batch is not None
) :
self . return_health_check_ct + = 1
continue
2025-01-19 17:10:29 -08:00
output = self . _request_dispatcher ( recv_req )
2025-01-19 12:13:27 +08:00
if output is not None :
self . send_to_tokenizer . send_pyobj ( output )
2024-09-29 17:42:45 -07:00
    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Turn an incoming tokenized generate request into a Req and enqueue it.

        Invalid requests are still enqueued with ``finished_reason`` set so the
        error is reported back through the normal output path.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            # Handle custom logit processor passed to the request
            custom_logit_processor = recv_req.custom_logit_processor
            if (
                not self.server_args.enable_custom_logit_processor
                and custom_logit_processor is not None
            ):
                # NOTE(review): the message says "--enable-custom-logits-processor"
                # but the checked flag is enable_custom_logit_processor — confirm the CLI name.
                logger.warning(
                    "The SGLang server is not configured to enable custom logit processor."
                    "The custom logit processor passed in will be ignored."
                    "Please set --enable-custom-logits-processor to enable this feature."
                )
                custom_logit_processor = None

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
            )
            req.tokenizer = self.tokenizer

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was given but does not exist: abort via the queue.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                # Neutralize the request so it flows through as an aborted no-op.
                req.origin_input_ids = [0]
                req.image_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

        # Cap max_new_tokens so prompt + generation fit within max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            elif req.sampling_params.structural_tag:
                # NOTE(review): this branch tests truthiness while the guard above
                # tests "is not None"; an empty-string structural_tag would enter the
                # outer if but match no branch, leaving `key` unbound — confirm upstream
                # validation rules this out.
                key = ("structural_tag", req.sampling_params.structural_tag)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                # Grammar compilation is pending; park the request in the grammar queue.
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)
def _add_request_to_queue ( self , req : Req ) :
self . waiting_queue . append ( req )
def _extend_requests_to_queue ( self , reqs : List [ Req ] ) :
self . waiting_queue . extend ( reqs )
2024-09-29 17:42:45 -07:00
def handle_embedding_request (
self ,
2024-11-03 08:38:26 -08:00
recv_req : TokenizedEmbeddingReqInput ,
2024-09-29 17:42:45 -07:00
) :
req = Req (
recv_req . rid ,
recv_req . input_text ,
recv_req . input_ids ,
recv_req . sampling_params ,
)
req . tokenizer = self . tokenizer
2025-01-16 14:51:19 -08:00
# Validate prompts length
2025-01-26 01:39:28 -08:00
error_msg = validate_input_length (
2025-01-16 14:51:19 -08:00
req ,
self . max_req_input_len ,
self . server_args . allow_auto_truncate ,
)
2025-01-26 01:39:28 -08:00
if error_msg :
2025-03-03 00:12:04 -08:00
self . _add_request_to_queue ( req )
2025-01-26 01:39:28 -08:00
return
2024-09-29 17:42:45 -07:00
2025-01-26 01:39:28 -08:00
# Copy more attributes
req . logprob_start_len = len ( req . origin_input_ids ) - 1
2025-03-03 00:12:04 -08:00
self . _add_request_to_queue ( req )
2024-09-29 17:42:45 -07:00
2025-01-27 03:00:41 -08:00
def log_prefill_stats (
self ,
adder : PrefillAdder ,
can_run_list : List [ Req ] ,
2025-03-03 00:12:04 -08:00
running_bs : int ,
2025-01-27 03:00:41 -08:00
) :
2024-11-10 04:39:32 -08:00
num_used = self . max_total_num_tokens - (
self . token_to_kv_pool . available_size ( ) + self . tree_cache . evictable_size ( )
)
2025-03-03 00:12:04 -08:00
self . _largest_prefill_len = max (
self . _largest_prefill_len , adder . log_input_tokens
)
2024-11-10 04:39:32 -08:00
2025-03-03 00:12:04 -08:00
f = (
2024-11-10 04:39:32 -08:00
f " Prefill batch. "
f " #new-seq: { len ( can_run_list ) } , "
f " #new-token: { adder . log_input_tokens } , "
f " #cached-token: { adder . log_hit_tokens } , "
f " token usage: { num_used / self . max_total_num_tokens : .2f } , "
f " #running-req: { running_bs } , "
2025-03-03 00:12:04 -08:00
f " #queue-req: { len ( self . waiting_queue ) } , "
2024-11-10 04:39:32 -08:00
)
2025-03-03 00:12:04 -08:00
logger . info ( f )
2024-11-10 04:39:32 -08:00
if self . enable_metrics :
2025-03-03 00:12:04 -08:00
cache_hit_rate = adder . log_hit_tokens / (
adder . log_input_tokens + adder . log_hit_tokens
)
2024-11-10 04:39:32 -08:00
self . stats . num_running_reqs = running_bs
self . stats . num_used_tokens = num_used
self . stats . token_usage = round ( num_used / self . max_total_num_tokens , 2 )
2025-03-03 00:12:04 -08:00
self . stats . num_queue_reqs = len ( self . waiting_queue )
self . stats . cache_hit_rate = cache_hit_rate
2024-11-10 04:39:32 -08:00
self . metrics_collector . log_stats ( self . stats )
    def log_decode_stats(self):
        """Log throughput/usage statistics for the decode phase.

        Invoked every `decode_log_interval` decode steps (see
        process_batch_result_decode). Resets the per-interval generated-token
        counter and, in speculative mode, the acceptance counters.
        """
        # Throughput is measured over the wall-clock gap since the last log.
        gap_latency = time.time() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.time()
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            # Record average per-step latency, bucketed by running batch size.
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        if self.spec_algorithm.is_none():
            msg = (
                f"Decode batch. "
                f"#running-req: {num_running_reqs}, "
                f"#token: {num_used}, "
                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
                f"largest-len: {self._largest_prefill_decode_len}, "
                f"#queue-req: {len(self.waiting_queue)}, "
            )
            spec_accept_length = 0
        else:
            # Average number of tokens accepted per speculative forward pass
            # over this logging interval; counters are reset afterwards and
            # folded into the cumulative totals.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg = (
                f"Decode batch. "
                f"#running-req: {num_running_reqs}, "
                f"#token: {num_used}, "
                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                f"accept len: {spec_accept_length:.2f}, "
                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
                f"largest-len: {self._largest_prefill_decode_len}, "
                f"#queue-req: {len(self.waiting_queue)}, "
            )
        logger.info(msg)

        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)
2024-10-06 03:24:04 -07:00
def check_memory ( self ) :
available_size = (
self . token_to_kv_pool . available_size ( ) + self . tree_cache . evictable_size ( )
)
2025-01-27 12:28:17 -08:00
protected_size = self . tree_cache . protected_size ( )
memory_leak = available_size != (
self . max_total_num_tokens
if not self . enable_hierarchical_cache
else self . max_total_num_tokens - protected_size
)
if memory_leak :
2024-11-17 22:18:11 -08:00
msg = (
2024-10-06 03:24:04 -07:00
" KV cache pool leak detected! "
2025-01-27 12:28:17 -08:00
f " { available_size =} , { protected_size =} , { self . max_total_num_tokens =} \n "
2024-10-06 03:24:04 -07:00
)
2024-11-17 22:18:11 -08:00
warnings . warn ( msg )
if crash_on_warnings ( ) :
raise ValueError ( msg )
2024-10-06 03:24:04 -07:00
if len ( self . req_to_token_pool . free_slots ) != self . req_to_token_pool . size :
2024-11-17 22:18:11 -08:00
msg = (
2024-10-06 03:24:04 -07:00
" Memory pool leak detected! "
2024-11-17 22:18:11 -08:00
f " available_size= { len ( self . req_to_token_pool . free_slots ) } , "
f " total_size= { self . req_to_token_pool . size } \n "
2024-10-06 03:24:04 -07:00
)
2024-11-17 22:18:11 -08:00
warnings . warn ( msg )
if crash_on_warnings ( ) :
raise ValueError ( msg )
2024-10-06 03:24:04 -07:00
2025-03-03 00:12:04 -08:00
if (
self . enable_metrics
and self . attn_tp_rank == 0
and time . time ( ) > self . metrics_collector . last_log_time + 30
) :
# During idle time, also collect metrics every 30 seconds.
num_used = self . max_total_num_tokens - (
self . token_to_kv_pool . available_size ( )
+ self . tree_cache . evictable_size ( )
)
num_running_reqs = len ( self . running_batch . reqs ) if self . running_batch else 0
self . stats . num_running_reqs = num_running_reqs
self . stats . num_used_tokens = num_used
self . stats . token_usage = num_used / self . max_total_num_tokens
self . stats . gen_throughput = 0
self . stats . num_queue_reqs = len ( self . waiting_queue )
self . metrics_collector . log_stats ( self . stats )
2024-12-11 12:51:50 -08:00
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Pick the next batch to forward: a fresh prefill batch if one can be
        formed, otherwise the existing decode (running) batch.

        Returns:
            The batch to run, or None when there is nothing runnable.
        """
        # Merge the prefill batch into the running batch
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.chunked_req:
                # Move the chunked request out of the batch so that we can merge
                # only finished requests to running_batch.
                self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
                self.tree_cache.cache_unfinished_req(self.chunked_req)
                # chunked request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
                self.batch_is_full = False

            self.last_batch.filter_batch()
            if not self.last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = self.last_batch
                else:
                    # merge running_batch with prefill batch
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
            if self.running_batch is None:
                ret = None
            else:
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch

        # Handle DP attention
        if self.server_args.enable_dp_attention:
            ret = self.prepare_dp_attn_batch(ret)

        return ret
2024-10-07 13:05:53 -07:00
2024-10-06 03:24:04 -07:00
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Try to form a new prefill (extend) batch from the waiting queue.

        Admission is governed by a PrefillAdder (token/memory budget), the
        max-running-requests cap, and (if enabled) LoRA and hierarchical-cache
        constraints. Returns None when prefill is not allowed or no request
        can be admitted.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        # A chunked request (prefill split across steps) resumes first.
        is_chunked = self.chunked_req is not None
        if is_chunked:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            # Cap the number of distinct LoRA adapters in flight.
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)

            if self.enable_hierarchical_cache and req.last_node is not None:
                if req.last_node.evicted:
                    # loading KV cache for the request
                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
                        req.last_node,
                        req.prefix_indices,
                        adder.rem_total_tokens,
                    )
                    if req.last_node.loading:
                        # to prevent frequent cache invalidation
                        if req.rid in self.staging_reqs:
                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
                        self.tree_cache.inc_lock_ref(req.last_node)
                        self.staging_reqs[req.rid] = req.last_node
                        continue
                elif req.last_node.loading:
                    if not self.tree_cache.loading_complete(req.last_node):
                        continue

                if req.rid in self.staging_reqs:
                    self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
                    del self.staging_reqs[req.rid]

            res = adder.add_one_req(req, self.chunked_req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.batch_is_full = len(adder.can_run_list) > 0 or (
                            self.running_batch is not None
                            and not self.running_batch.is_empty()
                        )
                    else:
                        self.batch_is_full = True
                break
            if self.prefill_only_one_req:
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch
2024-11-24 04:47:10 -08:00
def update_running_batch ( self , batch : ScheduleBatch ) - > Optional [ ScheduleBatch ] :
2024-10-19 23:19:26 -07:00
""" Update the current running decoding batch. """
2024-11-24 04:47:10 -08:00
initial_bs = batch . batch_size ( )
2024-10-06 03:24:04 -07:00
2024-10-14 01:15:34 -07:00
batch . filter_batch ( )
if batch . is_empty ( ) :
2024-11-24 04:47:10 -08:00
self . batch_is_full = False
return None
2024-10-14 01:15:34 -07:00
2024-10-06 03:24:04 -07:00
# Check if decode out of memory
2025-01-02 02:09:08 -08:00
if not batch . check_decode_mem ( self . decode_mem_cache_buf_multiplier ) or (
2025-03-03 00:12:04 -08:00
TEST_RETRACT and batch . batch_size ( ) > 10
2025-01-02 02:09:08 -08:00
) :
2024-10-06 03:24:04 -07:00
old_ratio = self . new_token_ratio
2025-03-03 00:12:04 -08:00
retracted_reqs , new_token_ratio = batch . retract_decode ( self . server_args )
2024-10-06 03:24:04 -07:00
self . new_token_ratio = new_token_ratio
2025-01-02 02:09:08 -08:00
if self . draft_worker :
self . draft_worker . finish_request ( retracted_reqs )
2024-09-29 17:42:45 -07:00
2024-10-06 03:24:04 -07:00
logger . info (
" Decode out of memory happened. "
f " #retracted_reqs: { len ( retracted_reqs ) } , "
f " #new_token_ratio: { old_ratio : .4f } -> { self . new_token_ratio : .4f } "
)
2025-03-03 00:12:04 -08:00
self . _extend_requests_to_queue ( retracted_reqs )
2024-10-06 03:24:04 -07:00
else :
self . new_token_ratio = max (
2024-10-25 10:24:44 -07:00
self . new_token_ratio - self . new_token_ratio_decay ,
2024-10-06 03:24:04 -07:00
self . min_new_token_ratio ,
)
2024-11-24 04:47:10 -08:00
if batch . batch_size ( ) < initial_bs :
self . batch_is_full = False
2024-10-06 03:24:04 -07:00
# Update batch tensors
2024-11-24 06:29:38 -08:00
batch . prepare_for_decode ( )
2024-11-24 04:47:10 -08:00
return batch
2024-10-06 03:24:04 -07:00
2025-01-16 12:51:11 -08:00
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch.

        Forwards the batch through the TP worker (or the speculative draft
        worker) and packages the result for output processing.
        """
        self.forward_ct += 1

        # Check profiler
        if (
            self.profiler_target_forward_ct
            and self.profiler_target_forward_ct <= self.forward_ct
        ):
            self.stop_profile()

        if self.is_generation:
            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
            else:
                # Speculative decoding: the draft worker runs draft + verify
                # and reports how many draft tokens were accepted.
                (
                    logits_output,
                    next_token_ids,
                    model_worker_batch,
                    num_accepted_tokens,
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
                self.spec_num_total_accepted_tokens += (
                    num_accepted_tokens + batch.batch_size()
                )
                self.spec_num_total_forward_ct += batch.batch_size()
                self.num_generated_tokens += num_accepted_tokens
            batch.output_ids = next_token_ids

            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
            if batch.return_logprob:
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_input_len_per_req = None
                extend_logprob_start_len_per_req = None

            ret = GenerationBatchResult(
                logits_output=logits_output,
                next_token_ids=next_token_ids,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                bid=model_worker_batch.bid,
            )
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )
        return ret
2024-11-07 15:42:47 -08:00
2025-01-16 12:51:11 -08:00
def process_batch_result (
self ,
batch : ScheduleBatch ,
result : Union [ GenerationBatchResult , EmbeddingBatchResult ] ,
) :
2024-10-06 03:24:04 -07:00
if batch . forward_mode . is_decode ( ) :
self . process_batch_result_decode ( batch , result )
2024-10-12 21:35:30 -07:00
if batch . is_empty ( ) :
self . running_batch = None
2024-11-19 15:04:43 -08:00
elif batch . forward_mode . is_extend ( ) :
2024-10-06 03:24:04 -07:00
self . process_batch_result_prefill ( batch , result )
2025-01-16 11:15:00 -08:00
elif batch . forward_mode . is_idle ( ) :
if self . enable_overlap :
2025-01-16 12:51:11 -08:00
self . tp_worker . resolve_batch_result ( result . bid )
2025-02-26 01:32:05 +08:00
if batch . next_batch_sampling_info :
batch . next_batch_sampling_info . update_regex_vocab_mask ( )
self . current_stream . synchronize ( )
batch . next_batch_sampling_info . sampling_info_done . set ( )
2024-11-19 15:04:43 -08:00
elif batch . forward_mode . is_dummy_first ( ) :
batch . next_batch_sampling_info . update_regex_vocab_mask ( )
2024-12-06 05:49:29 -08:00
self . current_stream . synchronize ( )
2024-11-19 15:04:43 -08:00
batch . next_batch_sampling_info . sampling_info_done . set ( )
2024-10-06 03:24:04 -07:00
2025-03-03 00:12:04 -08:00
if self . return_health_check_ct :
# Return some signal for the health check.
# This is used to prevent the health check signal being blocked by long context prefill.
# However, one minor issue is that this code path does not check the status of detokenizer manager.
self . return_health_check_ct - = 1
self . send_to_tokenizer . send_pyobj ( HealthCheckOutput ( ) )
2025-01-16 12:51:11 -08:00
    def process_batch_result_prefill(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Process outputs of a prefill (extend) pass.

        For generation: appends the sampled token per request, updates the
        radix cache, collects logprobs/hidden states, and advances grammar
        state. For embedding models: stores the embedding. Finally streams
        finished outputs (skipping the request still being chunked).
        """
        skip_stream_req = None

        if self.is_generation:
            (
                logits_output,
                next_token_ids,
                extend_input_len_per_req,
                extend_logprob_start_len_per_req,
                bid,
            ) = (
                result.logits_output,
                result.next_token_ids,
                result.extend_input_len_per_req,
                result.extend_logprob_start_len_per_req,
                result.bid,
            )

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                next_token_ids = next_token_ids.tolist()
                if batch.return_logprob:
                    if logits_output.next_token_logprobs is not None:
                        logits_output.next_token_logprobs = (
                            logits_output.next_token_logprobs.tolist()
                        )
                    if logits_output.input_token_logprobs is not None:
                        logits_output.input_token_logprobs = tuple(
                            logits_output.input_token_logprobs.tolist()
                        )

            # Running offset into logits_output.hidden_states; each request
            # consumes len(req.origin_input_ids) rows.
            hidden_state_offset = 0

            # Check finish conditions
            logprob_pt = 0
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                if req.is_retracted:
                    continue

                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
                    continue

                if req.is_chunked <= 0:
                    # req output_ids are set here
                    req.output_ids.append(next_token_id)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                        # This updates radix so others can match
                        self.tree_cache.cache_unfinished_req(req)

                    if req.return_logprob:
                        assert extend_logprob_start_len_per_req is not None
                        assert extend_input_len_per_req is not None
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]
                        num_input_logprobs = extend_input_len - extend_logprob_start_len
                        self.add_logprob_return_values(
                            i,
                            req,
                            logprob_pt,
                            next_token_ids,
                            num_input_logprobs,
                            logits_output,
                        )
                        logprob_pt += num_input_logprobs

                    if (
                        req.return_hidden_states
                        and logits_output.hidden_states is not None
                    ):
                        # The walrus in the slice bound both advances the
                        # offset and closes this request's row range.
                        req.hidden_states.append(
                            logits_output.hidden_states[
                                hidden_state_offset : (
                                    hidden_state_offset := hidden_state_offset
                                    + len(req.origin_input_ids)
                                )
                            ]
                            .cpu()
                            .clone()
                        )

                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_id)
                        req.grammar.finished = req.finished()
                else:
                    # being chunked reqs' prefill is not finished
                    req.is_chunked -= 1
                    # There is only at most one request being currently chunked.
                    # Because this request does not finish prefill,
                    # we don't want to stream the request currently being chunked.
                    skip_stream_req = req

                    # Incrementally update input logprobs.
                    if req.return_logprob:
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]
                        if extend_logprob_start_len < extend_input_len:
                            # Update input logprobs.
                            num_input_logprobs = (
                                extend_input_len - extend_logprob_start_len
                            )
                            self.add_input_logprob_return_values(
                                i,
                                req,
                                logits_output,
                                logprob_pt,
                                num_input_logprobs,
                                last_prefill_chunk=False,
                            )
                            logprob_pt += num_input_logprobs

            if batch.next_batch_sampling_info:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
                self.current_stream.synchronize()
                batch.next_batch_sampling_info.sampling_info_done.set()

        else:  # embedding or reward model
            embeddings, bid = result.embeddings, result.bid
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]
                if req.is_chunked <= 0:
                    # Dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()
                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    else:
                        self.tree_cache.cache_unfinished_req(req)
                else:
                    # being chunked reqs' prefill is not finished
                    req.is_chunked -= 1

        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
2024-09-29 17:42:45 -07:00
2025-01-16 12:51:11 -08:00
    def process_batch_result_decode(
        self,
        batch: ScheduleBatch,
        result: GenerationBatchResult,
    ):
        """Process outputs of a decode pass: append next tokens, cache/free
        finished requests, record logprobs and hidden states, advance grammar
        state, and stream outputs. Periodically triggers log_decode_stats."""
        logits_output, next_token_ids, bid = (
            result.logits_output,
            result.next_token_ids,
            result.bid,
        )
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            # Overlap schedule: fetch the asynchronously-copied results.
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()

        # Batch the KV-cache frees inside this group.
        self.token_to_kv_pool.free_group_begin()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            if batch.spec_algorithm.is_none():
                # speculative worker will solve the output_ids in speculative decoding
                req.output_ids.append(next_token_id)

            req.check_finished()
            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob and batch.spec_algorithm.is_none():
                # speculative worker handles logprob in speculative decoding
                req.output_token_logprobs_val.append(next_token_logprobs[i])
                req.output_token_logprobs_idx.append(next_token_id)
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs_val.append(
                        logits_output.next_token_top_logprobs_val[i]
                    )
                    req.output_top_logprobs_idx.append(
                        logits_output.next_token_top_logprobs_idx[i]
                    )
                if req.token_ids_logprob is not None:
                    req.output_token_ids_logprobs_val.append(
                        logits_output.next_token_token_ids_logprobs_val[i]
                    )
                    req.output_token_ids_logprobs_idx.append(
                        logits_output.next_token_token_ids_logprobs_idx[i]
                    )

            if req.return_hidden_states and logits_output.hidden_states is not None:
                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())

            if req.grammar is not None and batch.spec_algorithm.is_none():
                req.grammar.accept_token(next_token_id)
                req.grammar.finished = req.finished()

        if batch.next_batch_sampling_info:
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs, batch.return_logprob)

        self.token_to_kv_pool.free_group_end()

        # Wrap the decode step counter to avoid unbounded growth.
        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.attn_tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.log_decode_stats()
2024-10-07 13:05:53 -07:00
2025-03-03 00:12:04 -08:00
    def add_input_logprob_return_values(
        self,
        i: int,
        req: Req,
        output: LogitsProcessorOutput,
        logprob_pt: int,
        num_input_logprobs: int,
        last_prefill_chunk: bool,  # If True, it means prefill is finished.
    ):
        """Incrementally add input logprobs to `req`.

        Args:
            i: The request index in a batch.
            req: The request. Input logprob fields inside `req` are modified
                as a consequence of this call.
            output: Logits processor output that is used to compute input
                logprobs.
            logprob_pt: Offset into `output.input_token_logprobs` where this
                request's logprobs begin.
            num_input_logprobs: Number of input logprobs to consume for this
                request starting at `logprob_pt`.
            last_prefill_chunk: True if it is the last prefill (when chunked).
                Some input logprob operations should only happen at the last
                prefill chunk (e.g., finalizing input token logprobs).
        """
        assert output.input_token_logprobs is not None
        # Lazily create the per-request accumulators that collect values
        # across (possibly multiple) prefill chunks.
        if req.input_token_logprobs is None:
            req.input_token_logprobs = []
        if req.temp_input_top_logprobs_val is None:
            req.temp_input_top_logprobs_val = []
        if req.temp_input_top_logprobs_idx is None:
            req.temp_input_top_logprobs_idx = []
        if req.temp_input_token_ids_logprobs_val is None:
            req.temp_input_token_ids_logprobs_val = []
        if req.temp_input_token_ids_logprobs_idx is None:
            req.temp_input_token_ids_logprobs_idx = []

        if req.input_token_logprobs_val is not None:
            # The input logprob has been already computed. It only happens
            # upon retract.
            if req.top_logprobs_num > 0:
                assert req.input_token_logprobs_val is not None
            return

        # Important for the performance.
        assert isinstance(output.input_token_logprobs, tuple)
        input_token_logprobs: Tuple[int] = output.input_token_logprobs
        # Slice out this request's portion of the flat per-batch tuple.
        input_token_logprobs = input_token_logprobs[
            logprob_pt : logprob_pt + num_input_logprobs
        ]
        req.input_token_logprobs.extend(input_token_logprobs)

        if req.top_logprobs_num > 0:
            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])

        if req.token_ids_logprob is not None:
            req.temp_input_token_ids_logprobs_val.append(
                output.input_token_ids_logprobs_val[i]
            )
            req.temp_input_token_ids_logprobs_idx.append(
                output.input_token_ids_logprobs_idx[i]
            )

        if last_prefill_chunk:
            # Finalize: move the accumulated values into their public fields.
            input_token_logprobs = req.input_token_logprobs
            req.input_token_logprobs = None
            assert req.input_token_logprobs_val is None
            assert req.input_token_logprobs_idx is None
            assert req.input_top_logprobs_val is None
            assert req.input_top_logprobs_idx is None

            # Compute input_token_logprobs_val
            # Always pad the first one with None.
            req.input_token_logprobs_val = [None]
            req.input_token_logprobs_val.extend(input_token_logprobs)
            # The last input logprob is for sampling, so just pop it out.
            req.input_token_logprobs_val.pop()

            # Compute input_token_logprobs_idx
            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
            # Clip the padded hash values from image tokens.
            # Otherwise, it will lead to detokenization errors.
            input_token_logprobs_idx = [
                x if x < self.model_config.vocab_size - 1 else 0
                for x in input_token_logprobs_idx
            ]
            req.input_token_logprobs_idx = input_token_logprobs_idx

            if req.top_logprobs_num > 0:
                req.input_top_logprobs_val = [None]
                req.input_top_logprobs_idx = [None]

                for val, idx in zip(
                    req.temp_input_top_logprobs_val,
                    req.temp_input_top_logprobs_idx,
                    strict=True,
                ):
                    req.input_top_logprobs_val.extend(val)
                    req.input_top_logprobs_idx.extend(idx)

                # Last token is a sample token.
                req.input_top_logprobs_val.pop()
                req.input_top_logprobs_idx.pop()
                req.temp_input_top_logprobs_idx = None
                req.temp_input_top_logprobs_val = None

            if req.token_ids_logprob is not None:
                req.input_token_ids_logprobs_val = [None]
                req.input_token_ids_logprobs_idx = [None]

                for val, idx in zip(
                    req.temp_input_token_ids_logprobs_val,
                    req.temp_input_token_ids_logprobs_idx,
                    strict=True,
                ):
                    req.input_token_ids_logprobs_val.extend(val)
                    req.input_token_ids_logprobs_idx.extend(idx)

                # Last token is a sample token.
                req.input_token_ids_logprobs_val.pop()
                req.input_token_ids_logprobs_idx.pop()
                req.temp_input_token_ids_logprobs_idx = None
                req.temp_input_token_ids_logprobs_val = None

            if req.return_logprob:
                # Sanity check: one logprob per prompt token past logprob_start_len.
                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
                assert len(req.input_token_logprobs_val) == relevant_tokens_len
                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
                if req.top_logprobs_num > 0:
                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
                if req.token_ids_logprob is not None:
                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        num_input_logprobs: int,
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values.

        Records the sampled token's output logprob, finalizes the input
        logprobs (this is treated as the last prefill chunk), and collects
        top-k / requested-token-id logprobs when enabled on the request.

        Returns:
            num_input_logprobs, unchanged (consumed-count bookkeeping for the caller).
        """
        # Output-side logprob of the token that was just sampled.
        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
        req.output_token_logprobs_idx.append(next_token_ids[i])

        # Input-side logprobs; `pt` is this request's offset in the flat batch output.
        self.add_input_logprob_return_values(
            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
        )

        if req.top_logprobs_num > 0:
            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

        if req.token_ids_logprob is not None:
            req.output_token_ids_logprobs_val.append(
                output.next_token_token_ids_logprobs_val[i]
            )
            req.output_token_ids_logprobs_idx.append(
                output.next_token_token_ids_logprobs_idx[i]
            )

        return num_input_logprobs
2024-12-08 12:27:13 -08:00
    def stream_output(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
    ):
        """Stream the output to detokenizer.

        Gathers the requests in `reqs` that are finished or due for a stream
        update into column-wise lists and sends one BatchTokenIDOut
        (generation models) or BatchEmbeddingOut (embedding/reward models)
        message to the detokenizer process.

        Args:
            reqs: Requests to inspect for streamable output.
            return_logprob: Whether logprob columns should be collected.
            skip_req: Optional request to exclude from this round.
        """
        rids = []
        finished_reasons: List[BaseFinishReason] = []

        if self.is_generation:
            decoded_texts = []
            decode_ids_list = []
            read_offsets = []
            output_ids = []

            skip_special_tokens = []
            spaces_between_special_tokens = []
            no_stop_trim = []
            prompt_tokens = []
            completion_tokens = []
            cached_tokens = []
            spec_verify_ct = []
            output_hidden_states = None

            if return_logprob:
                input_token_logprobs_val = []
                input_token_logprobs_idx = []
                output_token_logprobs_val = []
                output_token_logprobs_idx = []
                input_top_logprobs_val = []
                input_top_logprobs_idx = []
                output_top_logprobs_val = []
                output_top_logprobs_idx = []
                input_token_ids_logprobs_val = []
                input_token_ids_logprobs_idx = []
                output_token_ids_logprobs_val = []
                output_token_ids_logprobs_idx = []
            else:
                # No logprobs requested: set every logprob column to None.
                input_token_logprobs_val = input_token_logprobs_idx = (
                    output_token_logprobs_val
                ) = output_token_logprobs_idx = input_top_logprobs_val = (
                    input_top_logprobs_idx
                ) = output_top_logprobs_val = output_top_logprobs_idx = (
                    input_token_ids_logprobs_val
                ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
                    output_token_ids_logprobs_idx
                ) = None

            for req in reqs:
                if req is skip_req:
                    continue

                # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
                if self.model_config.is_multimodal_gen and req.to_abort:
                    continue

                if (
                    req.finished()
                    # If stream, follow the given stream_interval
                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                    # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
                    # always increase one-by-one.
                    or (
                        not req.stream
                        and len(req.output_ids) % 50 == 0
                        and not self.model_config.is_multimodal_gen
                    )
                ):
                    if self.draft_worker and req.finished():
                        self.draft_worker.finish_request(req)

                    rids.append(req.rid)
                    finished_reasons.append(
                        req.finished_reason.to_json() if req.finished_reason else None
                    )
                    decoded_texts.append(req.decoded_text)
                    decode_ids, read_offset = req.init_incremental_detokenize()
                    decode_ids_list.append(decode_ids)
                    read_offsets.append(read_offset)
                    if self.skip_tokenizer_init:
                        output_ids.append(req.output_ids)
                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                    spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
                    no_stop_trim.append(req.sampling_params.no_stop_trim)
                    prompt_tokens.append(len(req.origin_input_ids))
                    completion_tokens.append(len(req.output_ids))
                    cached_tokens.append(req.cached_tokens)

                    if not self.spec_algorithm.is_none():
                        spec_verify_ct.append(req.spec_verify_ct)

                    if return_logprob:
                        input_token_logprobs_val.append(req.input_token_logprobs_val)
                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
                        output_token_logprobs_val.append(req.output_token_logprobs_val)
                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
                        input_top_logprobs_val.append(req.input_top_logprobs_val)
                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                        output_top_logprobs_val.append(req.output_top_logprobs_val)
                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
                        input_token_ids_logprobs_val.append(
                            req.input_token_ids_logprobs_val
                        )
                        input_token_ids_logprobs_idx.append(
                            req.input_token_ids_logprobs_idx
                        )
                        output_token_ids_logprobs_val.append(
                            req.output_token_ids_logprobs_val
                        )
                        output_token_ids_logprobs_idx.append(
                            req.output_token_ids_logprobs_idx
                        )

                    if req.return_hidden_states:
                        # Lazily created: stays None when no request asks for hidden states.
                        if output_hidden_states is None:
                            output_hidden_states = []
                        output_hidden_states.append(req.hidden_states)

            # Send to detokenizer
            if rids:
                if self.model_config.is_multimodal_gen:
                    raise NotImplementedError()
                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
                        rids,
                        finished_reasons,
                        decoded_texts,
                        decode_ids_list,
                        read_offsets,
                        output_ids,
                        skip_special_tokens,
                        spaces_between_special_tokens,
                        no_stop_trim,
                        prompt_tokens,
                        completion_tokens,
                        cached_tokens,
                        spec_verify_ct,
                        input_token_logprobs_val,
                        input_token_logprobs_idx,
                        output_token_logprobs_val,
                        output_token_logprobs_idx,
                        input_top_logprobs_val,
                        input_top_logprobs_idx,
                        output_top_logprobs_val,
                        output_top_logprobs_idx,
                        input_token_ids_logprobs_val,
                        input_token_ids_logprobs_idx,
                        output_token_ids_logprobs_val,
                        output_token_ids_logprobs_idx,
                        output_hidden_states,
                    )
                )
        else:  # embedding or reward model
            embeddings = []
            prompt_tokens = []
            for req in reqs:
                if req.finished():
                    rids.append(req.rid)
                    finished_reasons.append(req.finished_reason.to_json())
                    embeddings.append(req.embedding)
                    prompt_tokens.append(len(req.origin_input_ids))
            self.send_to_detokenizer.send_pyobj(
                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
            )
2024-09-29 17:42:45 -07:00
2024-12-06 05:49:29 -08:00
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Synchronize the local batch with other DP-attention workers.

        All workers must step together: token counts are all-gathered, and a
        worker with no work runs an idle batch so collectives stay in lockstep.

        Returns:
            The (possibly newly-created idle) batch, or None.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
        else:
            num_tokens = local_batch.extend_num_tokens

        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            group=self.tp_cpu_group,
        )

        # Some other worker has tokens to process but we do not: run idle.
        if local_batch is None and global_num_tokens.max().item() > 0:
            local_batch = self.get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens.tolist()

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
                forward_mode_state = torch.tensor(
                    (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
                    dtype=torch.int32,
                )
                # MIN-reduce: cuda graph is usable only if every rank is decode/idle.
                torch.distributed.all_reduce(
                    forward_mode_state,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.tp_cpu_group,
                )
                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

        return local_batch
    def get_idle_batch(self):
        """Create an empty IDLE batch (used by DP attention to keep ranks in lockstep)."""
        idle_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        idle_batch.prepare_for_idle()
        return idle_batch
2024-11-13 01:49:45 -08:00
def move_ready_grammar_requests ( self ) :
""" Move requests whose grammar objects are ready from grammar_queue to waiting_queue. """
num_ready_reqs = 0
for req in self . grammar_queue :
try :
req . grammar = req . grammar . result ( timeout = 0.05 )
num_ready_reqs + = 1
except futures . _base . TimeoutError :
break
2025-02-26 01:32:05 +08:00
if self . server_args . enable_dp_attention :
if self . attn_tp_size > 1 :
# Sync across attn TP ranks to make sure they have the same number of ready requests
tensor = torch . tensor ( num_ready_reqs , dtype = torch . int32 )
torch . distributed . all_reduce (
tensor ,
op = torch . distributed . ReduceOp . MAX ,
group = self . attn_tp_cpu_group ,
)
num_ready_reqs_max = tensor . item ( )
for i in range ( num_ready_reqs , num_ready_reqs_max ) :
self . grammar_queue [ i ] . grammar = self . grammar_queue [
i
] . grammar . result ( )
num_ready_reqs = num_ready_reqs_max
else :
if self . tp_size > 1 :
# Sync across TP ranks to make sure they have the same number of ready requests
tensor = torch . tensor ( num_ready_reqs , dtype = torch . int32 )
torch . distributed . all_reduce (
tensor , op = torch . distributed . ReduceOp . MAX , group = self . tp_cpu_group
)
num_ready_reqs_max = tensor . item ( )
for i in range ( num_ready_reqs , num_ready_reqs_max ) :
self . grammar_queue [ i ] . grammar = self . grammar_queue [
i
] . grammar . result ( )
num_ready_reqs = num_ready_reqs_max
2024-11-13 01:49:45 -08:00
2025-03-03 00:12:04 -08:00
self . _extend_requests_to_queue ( self . grammar_queue [ : num_ready_reqs ] )
2024-11-13 01:49:45 -08:00
self . grammar_queue = self . grammar_queue [ num_ready_reqs : ]
2025-01-19 12:13:27 +08:00
    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
        """Handle a FlushCacheReq message; the success flag is intentionally ignored."""
        self.flush_cache()
2024-09-29 17:42:45 -07:00
def flush_cache ( self ) :
2024-10-19 23:19:26 -07:00
""" Flush the memory pool and cache. """
2024-09-29 17:42:45 -07:00
if len ( self . waiting_queue ) == 0 and (
self . running_batch is None or len ( self . running_batch . reqs ) == 0
) :
2025-03-03 00:12:04 -08:00
self . cur_batch = None
self . last_batch = None
2024-09-29 17:42:45 -07:00
self . tree_cache . reset ( )
self . tree_cache_metrics = { " total " : 0 , " hit " : 0 }
2024-11-13 01:49:45 -08:00
if self . grammar_backend :
2024-11-12 21:17:38 -08:00
self . grammar_backend . reset ( )
2024-09-29 17:42:45 -07:00
self . req_to_token_pool . clear ( )
self . token_to_kv_pool . clear ( )
2025-01-20 20:25:13 -08:00
if not self . spec_algorithm . is_none ( ) :
self . draft_worker . model_runner . req_to_token_pool . clear ( )
self . draft_worker . model_runner . token_to_kv_pool . clear ( )
self . num_generated_tokens = 0
self . forward_ct_decode = 0
self . spec_num_total_accepted_tokens = 0
self . spec_num_total_forward_ct = 0
2025-03-03 00:12:04 -08:00
self . cum_spec_accept_length = 0
self . cum_spec_accept_count = 0
2024-09-29 17:42:45 -07:00
torch . cuda . empty_cache ( )
logger . info ( " Cache flushed successfully! " )
if_success = True
else :
logging . warning (
f " Cache not flushed because there are pending requests. "
f " #queue-req: { len ( self . waiting_queue ) } , "
f " #running-req: { 0 if self . running_batch is None else len ( self . running_batch . reqs ) } "
)
if_success = False
return if_success
2025-03-03 00:12:04 -08:00
def get_internal_state ( self , recv_req : GetInternalStateReq ) :
ret = dict ( global_server_args_dict )
ret [ " last_gen_throughput " ] = self . last_gen_throughput
if not self . spec_algorithm . is_none ( ) and self . cum_spec_accept_count > 0 :
ret [ " avg_spec_accept_length " ] = (
self . cum_spec_accept_length / self . cum_spec_accept_count
)
if RECORD_STEP_TIME :
ret [ " step_time_dict " ] = self . step_time_dict
return GetInternalStateReqOutput (
internal_state = ret ,
)
def set_internal_state ( self , recv_req : SetInternalStateReq ) :
server_args_dict = recv_req . server_args
args_allow_update = set (
[
" speculative_accept_threshold_single " ,
" speculative_accept_threshold_acc " ,
]
)
if_success = True
for k , v in server_args_dict . items ( ) :
if k not in args_allow_update :
logging . warning ( f " Updating { k } is not supported. " )
if_success = False
break
if if_success :
if not self . spec_algorithm . is_none ( ) and self . cum_spec_accept_count > 0 :
avg_spec_accept_length = (
self . cum_spec_accept_length / self . cum_spec_accept_count
)
logger . info ( f " { avg_spec_accept_length =} " )
self . cum_spec_accept_length = self . cum_spec_accept_count = 0
for k , v in server_args_dict . items ( ) :
global_server_args_dict [ k ] = v
logger . info ( f " Global server args updated! " f " { global_server_args_dict =} " )
return SetInternalStateReqOutput (
updated = True ,
server_args = global_server_args_dict ,
)
2024-09-29 17:42:45 -07:00
def abort_request ( self , recv_req : AbortReq ) :
# Delete requests in the waiting queue
to_del = None
for i , req in enumerate ( self . waiting_queue ) :
if req . rid == recv_req . rid :
to_del = i
break
if to_del is not None :
del self . waiting_queue [ to_del ]
2024-11-28 02:22:15 -08:00
logger . debug ( f " Abort queued request. { req . rid =} " )
return
2024-09-29 17:42:45 -07:00
# Delete requests in the running batch
if self . running_batch :
for req in self . running_batch . reqs :
2024-10-15 08:15:08 -07:00
if req . rid == recv_req . rid and not req . finished ( ) :
2024-11-28 02:22:15 -08:00
logger . debug ( f " Abort running request. { req . rid =} " )
req . to_abort = True
2024-09-29 17:42:45 -07:00
break
2024-11-29 17:17:00 -08:00
def update_weights_from_disk ( self , recv_req : UpdateWeightFromDiskReqInput ) :
""" In-place update of the weights from disk. """
success , message = self . tp_worker . update_weights_from_disk ( recv_req )
2024-09-29 17:42:45 -07:00
if success :
flash_cache_success = self . flush_cache ( )
assert flash_cache_success , " Cache flush failed after updating weights "
else :
logger . error ( message )
2025-03-03 00:12:04 -08:00
return UpdateWeightFromDiskReqOutput ( success , message , 0 )
2024-09-29 17:42:45 -07:00
2024-12-01 23:23:18 -08:00
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group."""
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(success, message)
2024-12-01 23:23:18 -08:00
def update_weights_from_distributed (
2025-01-07 02:52:53 -08:00
self ,
recv_req : UpdateWeightsFromDistributedReqInput ,
) - > Tuple [ bool , str ] :
2024-12-01 23:23:18 -08:00
""" Update the online model parameter. """
success , message = self . tp_worker . update_weights_from_distributed ( recv_req )
if success :
flash_cache_success = self . flush_cache ( )
assert flash_cache_success , " Cache flush failed after updating weights "
else :
logger . error ( message )
2025-01-19 12:13:27 +08:00
return UpdateWeightsFromDistributedReqOutput ( success , message )
2024-12-01 23:23:18 -08:00
2024-12-29 05:30:27 +08:00
def update_weights_from_tensor ( self , recv_req : UpdateWeightsFromTensorReqInput ) :
""" Update the online model parameter from tensors. """
success , message = self . tp_worker . update_weights_from_tensor ( recv_req )
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
if success :
2025-03-01 01:53:10 +08:00
if recv_req . flush_cache :
flash_cache_success = self . flush_cache ( )
assert flash_cache_success , " Cache flush failed after updating weights "
2024-12-29 05:30:27 +08:00
else :
logger . error ( message )
2025-01-19 12:13:27 +08:00
return UpdateWeightsFromTensorReqOutput ( success , message )
2024-12-29 05:30:27 +08:00
2024-11-29 23:36:38 -08:00
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a model weight by name from the TP worker and return it."""
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return GetWeightsByNameReqOutput(parameter)
2024-11-29 23:36:38 -08:00
2025-03-03 00:12:04 -08:00
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Release GPU memory occupied by the model/runtime.

        Snapshots the model's named buffers (so they can be restored by
        `resume_memory_occupation`), pauses the memory saver adapter, and
        flushes the memory pools and caches.
        """
        self.stashed_model_static_state = _export_static_state(
            self.tp_worker.worker.model_runner.model
        )
        self.memory_saver_adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()
2025-01-14 03:38:51 +08:00
2025-03-03 00:12:04 -08:00
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Re-acquire memory and restore the buffer state stashed by
        `release_memory_occupation`."""
        self.memory_saver_adapter.resume()
        _import_static_state(
            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
        )
        # The stash is single-use; drop it so the tensors can be freed.
        del self.stashed_model_static_state
        return ResumeMemoryOccupationReqOutput()
def profile ( self , recv_req : ProfileReq ) :
2025-03-03 00:12:04 -08:00
if recv_req . type == ProfileReqType . START_PROFILE :
return self . start_profile (
recv_req . output_dir , recv_req . num_steps , recv_req . activities
)
2025-01-19 12:13:27 +08:00
else :
2025-03-03 00:12:04 -08:00
return self . stop_profile ( )
def start_profile (
self ,
output_dir : Optional [ str ] ,
num_steps : Optional [ int ] ,
activities : Optional [ List [ str ] ] ,
) - > None :
if self . torch_profiler_activities :
return ProfileReqOutput (
success = False ,
message = " Profiling is already in progress. Call /stop_profile first. " ,
)
if output_dir is None :
output_dir = os . getenv ( " SGLANG_TORCH_PROFILER_DIR " , " /tmp " )
if activities is None :
activities = [ " CPU " , " GPU " ]
self . torch_profiler_output_dir = output_dir
self . torch_profiler_activities = activities
logger . info (
" Profiling starts. Traces will be saved to: %s " ,
self . torch_profiler_output_dir ,
)
activity_map = {
" CPU " : torch . profiler . ProfilerActivity . CPU ,
" GPU " : torch . profiler . ProfilerActivity . CUDA ,
}
torchprof_activities = [
activity_map [ a ] for a in activities if a in activity_map
]
if torchprof_activities :
self . torch_profiler = torch . profiler . profile (
activities = torchprof_activities ,
with_stack = True ,
)
self . torch_profiler . start ( )
if " MEM " in activities :
torch . cuda . memory . _record_memory_history ( max_entries = 100000 )
2025-01-14 03:38:51 +08:00
2025-03-03 00:12:04 -08:00
if num_steps :
self . profiler_target_forward_ct = self . forward_ct + num_steps
# The caller will be notified when reaching profiler_target_forward_ct
else :
self . profiler_target_forward_ct = None
return ProfileReqOutput ( success = True , message = " Succeeded " )
2024-10-11 17:34:25 +08:00
def stop_profile ( self ) - > None :
2025-03-03 00:12:04 -08:00
if self . torch_profiler_activities is None :
return
logger . info ( " Stop profiling... " )
if self . torch_profiler is not None :
self . torch_profiler . stop ( )
self . torch_profiler . export_chrome_trace (
os . path . join (
self . torch_profiler_output_dir ,
str ( time . time ( ) ) + f " -TP- { self . tp_rank } " + " .trace.json.gz " ,
)
)
if " MEM " in self . torch_profiler_activities :
memory_profile_path = os . path . join (
self . torch_profiler_trace_dir ,
str ( time . time ( ) ) + f " -TP- { self . tp_rank } -memory " + " .pickle " ,
)
torch . cuda . memory . _dump_snapshot ( memory_profile_path )
torch . cuda . memory . _record_memory_history ( enabled = None )
logger . info (
" Profiling done. Traces are saved to: %s " ,
self . torch_profiler_output_dir ,
2024-10-11 17:34:25 +08:00
)
2025-03-03 00:12:04 -08:00
self . torch_profiler = None
self . torch_profiler_output_dir = None
self . torch_profiler_activities = None
if self . profiler_target_forward_ct :
self . send_to_tokenizer . send_pyobj (
ProfileReqOutput ( success = True , message = " Succeeded. " )
)
2024-10-11 17:34:25 +08:00
2025-01-19 12:13:27 +08:00
def open_session ( self , recv_req : OpenSessionReqInput ) :
2024-11-20 00:36:53 -08:00
# handle error
session_id = recv_req . session_id
if session_id in self . sessions :
logger . warning ( f " session id { session_id } already exist, cannot open. " )
2025-01-19 12:13:27 +08:00
return OpenSessionReqOutput ( session_id , False )
2024-12-29 02:10:27 -08:00
elif session_id is None :
2025-03-03 00:12:04 -08:00
logger . warning ( " session id is None, cannot open. " )
2025-01-19 12:13:27 +08:00
return OpenSessionReqOutput ( session_id , False )
2024-11-20 00:36:53 -08:00
else :
self . sessions [ session_id ] = Session (
recv_req . capacity_of_str_len , session_id
)
2025-01-19 12:13:27 +08:00
return OpenSessionReqOutput ( session_id , True )
2024-11-20 00:36:53 -08:00
def close_session ( self , recv_req : CloseSessionReqInput ) :
# handle error
session_id = recv_req . session_id
if session_id not in self . sessions :
logger . warning ( f " session id { session_id } does not exist, cannot delete. " )
else :
del self . sessions [ session_id ]
2024-09-29 02:36:12 -07:00
2025-03-03 00:12:04 -08:00
def is_health_check_generate_req(recv_req):
    """Return True when `recv_req` is a health-check generation request.

    Identified by a request id starting with "HEALTH_CHECK"; objects without
    a `rid` attribute are never health checks.
    """
    rid = getattr(recv_req, "rid", "")
    return rid.startswith("HEALTH_CHECK")
2025-01-14 03:38:51 +08:00
def _export_static_state(model):
    """Snapshot the model's named buffers as detached clones.

    Returns a dict with a "buffers" list of (name, tensor) pairs, suitable
    for later restoration via `_import_static_state`.
    """
    snapshot = []
    for name, buffer in model.named_buffers():
        snapshot.append((name, buffer.detach().clone()))
    return dict(buffers=snapshot)
def _import_static_state(model, static_params):
    """Copy buffer values saved by `_export_static_state` back into `model` in place."""
    named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        # In-place copy keeps the existing buffer object registered on the model.
        named_buffers[name][...] = tensor
2024-09-29 02:36:12 -07:00
def run_scheduler_process (
server_args : ServerArgs ,
port_args : PortArgs ,
gpu_id : int ,
tp_rank : int ,
2024-10-11 07:22:48 -07:00
dp_rank : Optional [ int ] ,
2024-10-06 00:10:48 -07:00
pipe_writer ,
2024-09-29 02:36:12 -07:00
) :
2025-03-03 00:12:04 -08:00
# Config the process
# kill_itself_when_parent_died() # This is disabled because it does not work for `--dp 2`
setproctitle . setproctitle ( f " sglang::scheduler_ { dp_rank } " )
2025-01-13 01:39:14 -08:00
faulthandler . enable ( )
2025-03-03 00:12:04 -08:00
parent_process = psutil . Process ( ) . parent ( )
2024-12-08 01:06:15 -08:00
2024-11-27 09:36:36 -08:00
# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
if dp_rank is None and " SGLANG_DP_RANK " in os . environ :
dp_rank = int ( os . environ [ " SGLANG_DP_RANK " ] )
2024-11-21 17:13:33 -08:00
2025-02-25 08:51:23 +08:00
# Configure the logger
2024-10-11 07:22:48 -07:00
if dp_rank is None :
2025-03-03 00:12:04 -08:00
prefix = f " TP { tp_rank } "
2024-10-11 07:22:48 -07:00
else :
2025-03-03 00:12:04 -08:00
prefix = f " DP { dp_rank } TP { tp_rank } "
configure_logger ( server_args , prefix = prefix )
2024-12-29 00:45:57 -08:00
suppress_other_loggers ( )
2024-10-11 07:22:48 -07:00
2024-12-29 00:45:57 -08:00
# Set cpu affinity to this gpu process
2024-12-06 05:49:29 -08:00
if get_bool_env_var ( " SGLANG_SET_CPU_AFFINITY " ) :
set_gpu_proc_affinity ( server_args . tp_size , server_args . nnodes , gpu_id )
2024-12-29 00:45:57 -08:00
# Create a scheduler and run the event loop
2024-09-29 02:36:12 -07:00
try :
2024-10-19 17:39:38 -07:00
scheduler = Scheduler ( server_args , port_args , gpu_id , tp_rank , dp_rank )
2024-11-22 15:10:10 -08:00
pipe_writer . send (
2025-01-19 06:14:19 +08:00
{
" status " : " ready " ,
" max_total_num_tokens " : scheduler . max_total_num_tokens ,
" max_req_input_len " : scheduler . max_req_input_len ,
}
2024-11-22 15:10:10 -08:00
)
2024-11-17 19:49:20 -08:00
if scheduler . enable_overlap :
2024-10-16 01:33:20 -07:00
scheduler . event_loop_overlap ( )
else :
scheduler . event_loop_normal ( )
2024-09-29 02:36:12 -07:00
except Exception :
2024-11-28 00:22:39 -08:00
traceback = get_exception_traceback ( )
logger . error ( f " Scheduler hit an exception: { traceback } " )
parent_process . send_signal ( signal . SIGQUIT )