Files
sglang/python/sglang/srt/managers/scheduler.py

2377 lines
93 KiB
Python
Raw Normal View History

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
2024-09-29 02:36:12 -07:00
"""A scheduler that manages a tensor parallel GPU worker."""
import faulthandler
2024-09-29 02:36:12 -07:00
import logging
import os
import signal
import sys
2024-10-27 02:00:50 -07:00
import threading
import time
from collections import defaultdict, deque
2024-11-12 21:17:38 -08:00
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union
2024-09-29 02:36:12 -07:00
import psutil
import setproctitle
import torch
2024-09-29 02:36:12 -07:00
import zmq
from torch.distributed import barrier
2024-09-29 02:36:12 -07:00
from sglang.global_config import global_config
from sglang.srt import two_batch_overlap
2024-11-24 04:47:10 -08:00
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.disaggregation.decode import (
DecodePreallocQueue,
DecodeTransferQueue,
SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
from sglang.srt.disaggregation.prefill import (
PrefillBootstrapQueue,
SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
MetadataBuffers,
ReqToMetadataIdxAllocator,
TransferBackend,
prepare_abort,
)
2025-04-30 18:18:07 -07:00
from sglang.srt.distributed import get_pp_group, get_world_group
from sglang.srt.hf_transformers_utils import (
get_processor,
get_tokenizer,
get_tokenizer_from_processor,
)
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.expert_distribution import (
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
)
from sglang.srt.managers.io_struct import (
AbortReq,
2024-11-20 00:36:53 -08:00
CloseSessionReqInput,
ExpertDistributionReq,
ExpertDistributionReqOutput,
FlushCacheReqInput,
FlushCacheReqOutput,
GetInternalStateReq,
GetInternalStateReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
2024-11-20 00:36:53 -08:00
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
RpcReqInput,
RpcReqOutput,
SetInternalStateReq,
SetInternalStateReqOutput,
SlowDownReqInput,
SlowDownReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
2024-11-29 17:17:00 -08:00
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
2025-03-25 11:08:40 +08:00
MultimodalInputs,
Req,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
AddReqResult,
PrefillAdder,
SchedulePolicy,
)
from sglang.srt.managers.scheduler_output_processor_mixin import (
SchedulerOutputProcessorMixin,
)
2024-11-20 00:36:53 -08:00
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
2025-05-10 21:54:46 -07:00
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.reasoning_parser import ReasoningParser
2024-09-29 02:36:12 -07:00
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
from sglang.srt.utils import (
DeepEPMode,
DynamicGradMode,
broadcast_pyobj,
configure_logger,
2025-05-10 21:54:46 -07:00
disable_request_logging,
get_bool_env_var,
2024-10-25 23:07:07 -07:00
get_zmq_socket,
2025-03-12 22:22:39 -07:00
kill_itself_when_parent_died,
2025-04-30 18:18:07 -07:00
point_to_point_pyobj,
pyspy_dump_schedulers,
set_gpu_proc_affinity,
set_random_seed,
suppress_other_loggers,
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
2024-09-29 02:36:12 -07:00
logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
2024-09-29 02:36:12 -07:00
@dataclass
class GenerationBatchResult:
    """Result of running one generation forward batch on the TP worker.

    Several fields are Optional because pipeline-parallel ranks fill in
    different subsets (see `event_loop_pp`, which builds instances with
    `logits_output=None` and `pp_hidden_states_proxy_tensors=None`).
    """

    # Logits-processor output of the forward pass; None when this rank
    # only carries proxy tensors (non-last PP rank path).
    logits_output: Optional[LogitsProcessorOutput]
    # Hidden states forwarded between pipeline stages; None outside PP.
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Sampled next-token ids for the batch; None when not produced here.
    next_token_ids: Optional[List[int]]
    # Per-request extend lengths (used for logprob post-processing).
    extend_input_len_per_req: List[int]
    # Per-request logprob start lengths for the extend portion.
    extend_logprob_start_len_per_req: List[int]
    # Batch id used to match this result with the launched batch.
    bid: int
    # Whether the batch ran (or could run) with CUDA graph replay —
    # propagated as-is from the worker; see `event_loop_pp`.
    can_run_cuda_graph: bool
@dataclass
class EmbeddingBatchResult:
    """Result of running one embedding forward batch."""

    # Embedding vectors for the batch.
    embeddings: torch.Tensor
    # Batch id used to match this result with the launched batch.
    bid: int
class Scheduler(
SchedulerOutputProcessorMixin,
SchedulerDisaggregationDecodeMixin,
SchedulerDisaggregationPrefillMixin,
):
2024-09-29 02:36:12 -07:00
"""A scheduler that manages a tensor parallel GPU worker."""
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Initialize the scheduler.

        Sets up, in order: argument/rank bookkeeping, ZMQ inter-process
        sockets, the tokenizer, the tensor-parallel (and optional draft)
        worker, memory pools and prefix cache, the schedule policy, the
        watchdog thread, metrics, the request dispatcher, and the
        prefill/decode disaggregation queues. Order matters: e.g. the
        tokenizer must exist before the grammar backend, and the TP worker
        must exist before memory-pool initialization.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.tp_size = server_args.tp_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size

        # Distributed rank info
        # NOTE(review): dp_size was already assigned above; this repeat is redundant.
        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication.
        # Only the first PP stage's attention-TP rank 0 talks to the
        # tokenizer/detokenizer; all other ranks get no-op stand-ins.
        context = zmq.Context(2)
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            # No-op senders so non-root ranks can call send_pyobj unconditionally.
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            # First token id of the encoded think-end marker.
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            # Imported lazily so non-speculative deployments skip the dependency.
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if global_server_args_dict["max_micro_batch_size"] is None:
            # Default micro-batch size: split running requests evenly across PP stages.
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        if tp_rank == 0:
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size <= 0:  # -1 means disable
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        # Linear decay from init ratio down to min ratio over the configured steps.
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profiler_id: Optional[str] = None
        self.profiler_target_forward_ct: Optional[int] = None
        self.forward_sleep_time = None

        # Init metrics stats
        self.init_metrics()
        self.init_kv_events(server_args.kv_events_config)

        # Init request dispatcher: maps incoming IPC message types to handlers.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()
def init_tokenizer(self):
server_args = self.server_args
2024-10-27 02:00:50 -07:00
self.model_config = ModelConfig.from_server_args(server_args)
self.is_generation = self.model_config.is_generation
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if self.model_config.is_multimodal:
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
use_fast=not server_args.disable_fast_image_processor,
)
self.tokenizer = get_tokenizer_from_processor(self.processor)
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
def init_memory_pool_and_cache(self):
server_args = self.server_args
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
self.tp_worker.get_memory_pool()
)
if (
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
self.tree_cache = ChunkCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
)
else:
if self.enable_hierarchical_cache:
self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_cache_group=self.tp_cpu_group,
page_size=self.page_size,
hicache_ratio=server_args.hicache_ratio,
2025-04-20 23:08:30 -07:00
hicache_size=server_args.hicache_size,
hicache_write_policy=server_args.hicache_write_policy,
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
2025-03-12 22:22:39 -07:00
page_size=self.page_size,
disable=server_args.disable_radix_cache,
enable_kv_cache_events=self.enable_kv_cache_events,
)
self.decode_mem_cache_buf_multiplier = (
1
if self.spec_algorithm.is_none()
else (
server_args.speculative_num_draft_tokens
+ (
server_args.speculative_eagle_topk
* server_args.speculative_num_steps
)
)
)
def init_metrics(self):
self.last_gen_throughput: float = 0.0
2025-03-12 22:22:39 -07:00
self.last_input_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
self.stats = SchedulerStats()
if self.enable_metrics:
engine_type = "unified"
self.metrics_collector = SchedulerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
"engine_type": engine_type,
},
)
2024-10-27 02:00:50 -07:00
def init_kv_events(self, kv_events_config: Optional[str]):
if self.enable_kv_cache_events:
self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
    def init_disaggregation(self):
        """Set up prefill/decode disaggregation state.

        DECODE mode: metadata buffers plus two queues — a transfer queue for
        requests polling incoming KV cache, and a pre-allocation queue for
        requests awaiting KV slots. PREFILL mode: metadata buffers, the
        bootstrap queue, and the in-flight list of requests still sending KV.
        NULL mode: only the transfer backend is configured.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        if (
            self.disaggregation_mode == DisaggregationMode.DECODE
        ):  # *2 for the headroom.
            buffer_size = (self.req_to_token_pool.size) * 2
            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                # Draft-model KV pool only exists when speculative decoding is on.
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                transfer_backend=self.transfer_backend,
            )

            # Metric for pre-allocation
            self.num_tokens_pre_allocated = 0
        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                transfer_backend=self.transfer_backend,
                scheduler=self,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []
@DynamicGradMode()
def event_loop_normal(self):
"""A normal scheduler loop."""
2024-09-29 02:36:12 -07:00
while True:
2024-10-06 03:24:04 -07:00
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
2024-09-29 02:36:12 -07:00
batch = self.get_next_batch_to_run()
2024-10-27 02:00:50 -07:00
self.cur_batch = batch
if batch:
result = self.run_batch(batch)
self.process_batch_result(batch, result)
2024-10-14 05:25:00 -07:00
else:
2025-03-12 22:22:39 -07:00
# When the server is idle, do self-check and re-init some states
2024-10-14 05:25:00 -07:00
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.last_batch = batch
2024-09-29 02:36:12 -07:00
    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        Batches are launched one step ahead of result processing: each
        launched batch is enqueued in `result_queue` and its results are
        consumed in the *next* iteration, so CPU post-processing of batch N
        overlaps GPU execution of batch N+1.
        """
        self.result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                # Signals when this batch's kernel launch has completed.
                batch.launch_done = threading.Event()
                result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None, batch.launch_done)

            if self.last_batch:
                # Process the results of the last batch
                tmp_batch, tmp_result = self.result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                # NOTE: use the *currently launched* batch's launch_done event,
                # not the last batch's — the last batch launched long ago.
                self.process_batch_result(
                    tmp_batch, tmp_result, batch.launch_done if batch else None
                )
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch
2025-04-30 18:18:07 -07:00
    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        Maintains `pp_size` microbatches round-robin. Per microbatch: run the
        forward pass, then exchange outputs with neighboring pipeline stages —
        the last rank sends sampled token ids back around the ring; other
        ranks forward requests and hidden-state proxy tensors downstream.
        Results for microbatch i are post-processed one slot later (i+1).
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                # Swap in this microbatch's running state.
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch
                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                            }
                        )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    output_result = GenerationBatchResult(
                        logits_output=None,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=None,
                        extend_logprob_start_len_per_req=None,
                        bid=bids[next_mb_id],
                        can_run_cuda_graph=result.can_run_cuda_graph,
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.cpu_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
    def recv_requests(self) -> List[Req]:
        """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.

        First PP stage drains the tokenizer/RPC ZMQ sockets (non-blocking);
        later PP stages receive requests point-to-point from the previous
        stage. With DP attention, work requests are broadcast within the
        attention-TP group and control requests within the full TP group;
        otherwise everything is broadcast across TP.
        """
        if self.pp_rank == 0:
            if self.attn_tp_rank == 0:
                recv_reqs = []

                # Drain all pending tokenizer messages without blocking.
                while True:
                    try:
                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(recv_req)

                # Drain all pending RPC messages without blocking.
                while True:
                    try:
                        recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(recv_rpc)
            else:
                recv_reqs = None
        else:
            if self.attn_tp_rank == 0:
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                # Receive requests forwarded by the previous pipeline stage.
                recv_reqs = point_to_point_pyobj(
                    [],
                    self.pp_rank * self.tp_size + dp_offset,
                    self.world_group.cpu_group,
                    (self.pp_rank - 1) * self.tp_size + dp_offset,
                    self.pp_rank * self.tp_size + dp_offset,
                )
            else:
                recv_reqs = None

        if self.server_args.enable_dp_attention:
            if self.attn_tp_rank == 0:
                # Work requests stay within the attention-TP group; control
                # requests must reach every TP rank.
                work_reqs = [
                    req
                    for req in recv_reqs
                    if isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                ]
                control_reqs = [
                    req
                    for req in recv_reqs
                    if not isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                ]
            else:
                work_reqs = None
                control_reqs = None

            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_group.rank,
                    self.attn_tp_cpu_group,
                    src=self.attn_tp_group.ranks[0],
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            recv_reqs = work_reqs + control_reqs
        elif self.tp_size != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return recv_reqs
2024-10-06 03:24:04 -07:00
def process_input_requests(self, recv_reqs: List):
for recv_req in recv_reqs:
# If it is a health check generation request and there are running requests, ignore it.
if is_health_check_generate_req(recv_req) and (
2025-03-12 22:22:39 -07:00
self.chunked_req is not None or not self.running_batch.is_empty()
):
self.return_health_check_ct += 1
continue
output = self._request_dispatcher(recv_req)
if output is not None:
if isinstance(output, RpcReqOutput):
if self.recv_from_rpc is not None:
self.recv_from_rpc.send_pyobj(output)
else:
self.send_to_tokenizer.send_pyobj(output)
def handle_generate_request(
self,
recv_req: TokenizedGenerateReqInput,
):
# Create a new request
if (
recv_req.session_params is None
or recv_req.session_params.id is None
or recv_req.session_params.id not in self.sessions
):
2024-11-25 19:35:04 -05:00
if recv_req.input_embeds is not None:
# Generate fake input_ids based on the length of input_embeds
seq_length = len(recv_req.input_embeds)
fake_input_ids = [1] * seq_length
recv_req.input_ids = fake_input_ids
2025-05-07 01:12:57 +08:00
if recv_req.bootstrap_port is None:
# Use default bootstrap port
recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
2024-11-20 00:36:53 -08:00
req = Req(
recv_req.rid,
recv_req.input_text,
recv_req.input_ids,
recv_req.sampling_params,
2024-12-08 12:27:13 -08:00
return_logprob=recv_req.return_logprob,
top_logprobs_num=recv_req.top_logprobs_num,
token_ids_logprob=recv_req.token_ids_logprob,
2024-12-08 12:27:13 -08:00
stream=recv_req.stream,
2024-11-20 00:36:53 -08:00
lora_path=recv_req.lora_path,
2024-11-25 19:35:04 -05:00
input_embeds=recv_req.input_embeds,
2025-05-10 21:54:46 -07:00
custom_logit_processor=recv_req.custom_logit_processor,
return_hidden_states=recv_req.return_hidden_states,
eos_token_ids=self.model_config.hf_eos_token_id,
bootstrap_host=recv_req.bootstrap_host,
bootstrap_port=recv_req.bootstrap_port,
bootstrap_room=recv_req.bootstrap_room,
2024-11-20 00:36:53 -08:00
)
req.tokenizer = self.tokenizer
2024-11-25 16:38:43 -08:00
if self.disaggregation_mode != DisaggregationMode.NULL:
# Invalid request for disaggregated mode
if recv_req.bootstrap_room is None:
error_message = (
f"Invalid request: Disaggregated request received without "
f"boostrap room id. {req.rid=}"
)
logger.error(error_message)
prepare_abort(req, error_message)
self.stream_output([req], req.return_logprob)
return
if (
recv_req.session_params is not None
and recv_req.session_params.id is not None
):
2024-11-20 00:36:53 -08:00
req.finished_reason = FINISH_ABORT(
f"Invalid request: session id {recv_req.session_params.id} does not exist"
2024-11-20 00:36:53 -08:00
)
self._add_request_to_queue(req)
2024-11-20 00:36:53 -08:00
return
else:
# Create a new request from a previous session
session = self.sessions[recv_req.session_params.id]
req = session.create_req(recv_req, self.tokenizer)
2024-11-20 00:36:53 -08:00
if isinstance(req.finished_reason, FINISH_ABORT):
self._add_request_to_queue(req)
2024-11-20 00:36:53 -08:00
return
# Handle multimodal inputs
2025-03-25 11:08:40 +08:00
if recv_req.mm_inputs is not None:
image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
# Expand a single image token into multiple dummy tokens for receiving image embeddings
req.origin_input_ids = self.pad_input_ids_func(
req.origin_input_ids, image_inputs
)
req.extend_image_inputs(image_inputs)
if len(req.origin_input_ids) >= self.max_req_input_len:
error_msg = (
"Multimodal prompt is too long after expanding multimodal tokens. "
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
)
logger.error(error_msg)
2024-11-28 02:22:15 -08:00
req.origin_input_ids = [0]
2025-03-25 11:08:40 +08:00
req.multimodal_inputs = None
req.sampling_params.max_new_tokens = 0
req.finished_reason = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
)
self._add_request_to_queue(req)
return
# Validate prompts length
error_msg = validate_input_length(
req,
self.max_req_input_len,
self.server_args.allow_auto_truncate,
)
if error_msg:
req.origin_input_ids = [0]
req.sampling_params.max_new_tokens = 0
self._add_request_to_queue(req)
return
# Copy more attributes
if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
# By default, only return the logprobs for output tokens
req.logprob_start_len = len(req.origin_input_ids) - 1
else:
req.logprob_start_len = recv_req.logprob_start_len
if req.logprob_start_len >= len(req.origin_input_ids):
req.finished_reason = FINISH_ABORT(
f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
HTTPStatus.BAD_REQUEST,
"BadRequestError",
)
req.logprob_start_len = len(req.origin_input_ids) - 1
self._add_request_to_queue(req)
return
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_len - len(req.origin_input_ids) - 1,
)
# Init grammar cache for this request
add_to_grammar_queue = False
if (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
or req.sampling_params.ebnf is not None
or req.sampling_params.structural_tag is not None
):
assert self.grammar_backend is not None
if req.sampling_params.json_schema is not None:
key = ("json", req.sampling_params.json_schema)
elif req.sampling_params.regex is not None:
key = ("regex", req.sampling_params.regex)
elif req.sampling_params.ebnf is not None:
key = ("ebnf", req.sampling_params.ebnf)
elif req.sampling_params.structural_tag:
key = ("structural_tag", req.sampling_params.structural_tag)
value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
req.grammar = value
if not cache_hit:
req.grammar_key = key
add_to_grammar_queue = True
if add_to_grammar_queue:
req.queue_time_start = time.perf_counter()
self.grammar_queue.append(req)
else:
self._add_request_to_queue(req)
def _add_request_to_queue(self, req: Req):
req.queue_time_start = time.perf_counter()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
2025-04-25 17:25:45 +08:00
self.disagg_prefill_bootstrap_queue.add(req)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
self.waiting_queue.append(req)
2025-05-23 21:49:00 -07:00
def _extend_requests_to_queue(self, reqs: List[Req]):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.disagg_prefill_bootstrap_queue.extend(reqs)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
# If this is a decode server, we put the request to the decode pending prealloc queue
self.disagg_decode_prealloc_queue.extend(reqs)
else:
self.waiting_queue.extend(reqs)
    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Create a ``Req`` for an embedding request and enqueue it.

        Mirrors the generation admission path but without sampling: expands
        multimodal tokens, validates the prompt length, then routes the
        request to the appropriate waiting queue.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                # Too long after expansion: replace the prompt with a dummy
                # single token and abort, so the request still flows through
                # the normal output path with an error.
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                req.queue_time_start = time.perf_counter()
                self.waiting_queue.append(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            # NOTE(review): presumably validate_input_length marked the request
            # as aborted, so enqueueing is enough — confirm against its impl.
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        # Embedding requests only need the last input token's logprob slot.
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
def log_prefill_stats(
self,
adder: PrefillAdder,
can_run_list: List[Req],
running_bs: int,
):
gap_latency = time.perf_counter() - self.last_prefill_stats_tic
self.last_prefill_stats_tic = time.perf_counter()
2025-03-12 22:22:39 -07:00
self.last_input_throughput = self.num_prefill_tokens / gap_latency
self.num_prefill_tokens = 0
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
2025-04-09 17:19:27 -07:00
num_new_seq = len(can_run_list)
f = (
f"Prefill batch. "
2025-04-09 17:19:27 -07:00
f"#new-seq: {num_new_seq}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
)
2025-04-25 17:25:45 +08:00
if self.disaggregation_mode == DisaggregationMode.PREFILL:
f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
f += f"#queue-req: {len(self.waiting_queue)}, "
f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} "
else:
f += f"#queue-req: {len(self.waiting_queue)}"
logger.info(f)
if self.enable_metrics:
cache_hit_rate = adder.log_hit_tokens / (
adder.log_input_tokens + adder.log_hit_tokens
)
self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.cache_hit_rate = cache_hit_rate
2025-04-09 17:19:27 -07:00
total_queue_latency = 0
for req in can_run_list:
total_queue_latency += req.queue_time_end - req.queue_time_start
self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
self.metrics_collector.log_stats(self.stats)
self._publish_kv_events()
    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
    ):
        """Log throughput/usage statistics for a decode interval.

        Resets the interval counters (generated tokens, speculative accept
        counters) so the next interval starts fresh.
        """
        batch = running_batch or self.running_batch

        gap_latency = time.perf_counter() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.perf_counter()
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            # Record average per-step latency, bucketed by batch size.
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = (
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            # Average accepted draft tokens per forward this interval; fold the
            # interval counters into the cumulative ones, then reset them.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
2024-10-06 03:24:04 -07:00
    def check_memory(self):
        """Sanity-check that all KV tokens and request slots are accounted for.

        Intended to run when the scheduler is idle; raises ValueError if the
        free token count or free request slots do not add back up to the
        configured totals (i.e. a leak).
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        memory_leak = available_size != (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            # The hierarchical cache may legitimately hold some tokens protected.
            else self.max_total_num_tokens - protected_size
        )
        if memory_leak:
            msg = (
                "token_to_kv_pool_allocator memory leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                "req_to_token_pool memory leak detected!"
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
            self._publish_kv_events()
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Pick the next batch: a fresh prefill batch if one can be formed,
        otherwise the (updated) running decode batch, or None when idle."""
        # Merge the prefill batch into the running batch
        chunked_req_to_exclude = set()
        if self.chunked_req:
            # Move the chunked request out of the batch so that we can merge
            # only finished requests to running_batch.
            chunked_req_to_exclude.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
            # chunked request keeps its rid but will get a new req_pool_idx
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.last_batch.chunked_req is not None:
                # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
                # We need to discard it.
                chunked_req_to_exclude.add(self.last_batch.chunked_req)

            # Filter batch
            last_bs = self.last_batch.batch_size()
            self.last_batch.filter_batch(
                chunked_req_to_exclude=list(chunked_req_to_exclude)
            )
            if self.last_batch.batch_size() < last_bs:
                # Some requests finished, so there may be room again.
                self.running_batch.batch_is_full = False

            # Merge the new batch into the running batch
            if not self.last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = self.last_batch
                else:
                    # Merge running_batch with prefill batch
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
            if not self.running_batch.is_empty():
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None
            else:
                ret = None

        # Handle DP attention
        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
            ret, _ = self.prepare_dp_attn_batch(ret)

        return ret
2025-04-30 18:18:07 -07:00
def get_num_allocatable_reqs(self, running_bs):
res = global_server_args_dict["max_micro_batch_size"] - running_bs
if self.pp_size > 1:
res = min(res, self.req_to_token_pool.available_size())
return res
2024-10-06 03:24:04 -07:00
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Try to form a new prefill (extend) batch from the waiting queue.

        Returns None when prefill is not possible (batch full, nothing
        waiting, or no token budget). Also starts/continues chunked prefill
        and, for mixed-chunk mode, folds the running decode batch in.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            # Continue the in-flight chunked prefill before admitting new work.
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = set([req.lora_path for req in self.running_batch.reqs])

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                # Admitting this request would exceed the LoRA adapter budget.
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, prealloc queue and transfer queue can also take memory,
                # so we need to check if the available size for the actual available size.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

        if self.enable_hierarchical_cache:
            self.tree_cache.ready_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            # Decode requests now live inside new_batch; reset running_batch.
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch
2024-11-24 04:47:10 -08:00
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch."""
        initial_bs = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        # Check if decode out of memory
        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
            TEST_RETRACT and batch.batch_size() > 10
        ):
            # Not enough KV space: retract some requests back to the queue and
            # become more conservative about admitting new tokens.
            old_ratio = self.new_token_ratio
            retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
            self.new_token_ratio = new_token_ratio

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted_reqs)
        else:
            # Gradually relax the token ratio back toward its minimum.
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if batch.batch_size() < initial_bs:
            # Some requests finished or were retracted: there is room again.
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
2024-10-06 03:24:04 -07:00
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch."""
        self.forward_ct += 1

        # Check profiler
        if (
            self.profiler_target_forward_ct
            and self.profiler_target_forward_ct <= self.forward_ct
        ):
            self.send_to_tokenizer.send_pyobj(self.stop_profile())

        if self.forward_sleep_time is not None:
            # Throttling hook (e.g. used by pause/testing paths).
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        # Run forward
        if self.is_generation:
            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()
                if self.pp_group.is_last_rank:
                    # The last PP rank produces logits and sampled tokens.
                    logits_output, next_token_ids, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                else:
                    # Intermediate PP ranks only forward hidden-state proxies.
                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                bid = model_worker_batch.bid
            else:
                (
                    logits_output,
                    next_token_ids,
                    bid,
                    num_accepted_tokens,
                    can_run_cuda_graph,
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
                # batch_size() accounts for the guaranteed one token per request.
                self.spec_num_total_accepted_tokens += (
                    num_accepted_tokens + batch.batch_size()
                )
                self.spec_num_total_forward_ct += batch.batch_size()
                self.num_generated_tokens += num_accepted_tokens

            if self.pp_group.is_last_rank:
                batch.output_ids = next_token_ids

            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
            if batch.return_logprob:
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_input_len_per_req = None
                extend_logprob_start_len_per_req = None

            ret = GenerationBatchResult(
                logits_output=logits_output if self.pp_group.is_last_rank else None,
                pp_hidden_states_proxy_tensors=(
                    pp_hidden_states_proxy_tensors
                    if not self.pp_group.is_last_rank
                    else None
                ),
                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                bid=bid,
                can_run_cuda_graph=can_run_cuda_graph,
            )
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )
        return ret
2024-11-07 15:42:47 -08:00
def process_batch_result(
self,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
2025-04-28 11:19:16 +08:00
launch_done: Optional[threading.Event] = None,
):
2024-10-06 03:24:04 -07:00
if batch.forward_mode.is_decode():
2025-04-28 11:19:16 +08:00
self.process_batch_result_decode(batch, result, launch_done)
elif batch.forward_mode.is_extend():
2025-04-28 11:19:16 +08:00
self.process_batch_result_prefill(batch, result, launch_done)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
2025-04-28 11:19:16 +08:00
self.tp_worker.resolve_last_batch_result(launch_done)
self.set_next_batch_sampling_info_done(batch)
elif batch.forward_mode.is_dummy_first():
self.set_next_batch_sampling_info_done(batch)
2024-10-06 03:24:04 -07:00
if self.return_health_check_ct:
# Return some signal for the health check.
# This is used to prevent the health check signal being blocked by long context prefill.
# However, one minor issue is that this code path does not check the status of detokenizer manager.
self.return_health_check_ct -= 1
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
2024-12-06 05:49:29 -08:00
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
return self.prepare_dp_attn_batch_raw(
local_batch,
dp_size=self.server_args.dp_size,
attn_tp_size=self.attn_tp_size,
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
tp_cpu_group=self.tp_cpu_group,
get_idle_batch=self.get_idle_batch,
disable_cuda_graph=self.server_args.disable_cuda_graph,
spec_algorithm=self.spec_algorithm,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
enable_deepep_moe=self.server_args.enable_deepep_moe,
deepep_mode=DeepEPMode[self.server_args.deepep_mode],
)
    @staticmethod
    def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        moe_dense_tp_size: Optional[int],
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
    ):
        """Synchronize batch metadata across DP-attention ranks.

        All ranks must step forward in lockstep, so each rank all-gathers its
        token counts / cuda-graph eligibility / forward mode; a rank with no
        work receives an idle batch when any other rank has work.
        Returns ``(local_batch, any_rank_is_extending)``.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                # Each EAGLE decode step processes all draft tokens per request.
                num_tokens = num_tokens * speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            num_tokens_for_logprob = sum(
                [
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
                    for logprob_start_len, extend_len in zip(
                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
                    )
                ]
            )

        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        if not spec_algorithm.is_none():
            # TODO(sang): Support cuda graph when idle batch is there.
            if local_batch is None or local_batch.forward_mode.is_idle():
                can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()

        # Pack this rank's metadata into a fixed-width int64 row (6 entries:
        # 4 scalars above + 2 from the two-batch-overlap preparer).
        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                    deepep_mode,
                    enable_deepep_moe,
                    enable_two_batch_overlap,
                ),
            ],
            dtype=torch.int64,
        )
        global_info = torch.empty(
            (dp_size, attn_tp_size, 6),
            dtype=torch.int64,
        )
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(),
            local_info,
            group=tp_cpu_group,
        )
        # Only attn-TP rank 0's row is consulted for each DP rank.
        global_num_tokens = global_info[:, 0, 0].tolist()
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

        if local_batch is None and max(global_num_tokens) > 0:
            # Another rank has work: run an idle batch to stay in lockstep.
            local_batch = get_idle_batch()

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

            # Check forward mode for cuda graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch, any(is_extend_in_batch)
2024-12-06 05:49:29 -08:00
def get_idle_batch(self):
idle_batch = ScheduleBatch.init_new(
[],
self.req_to_token_pool,
self.token_to_kv_pool_allocator,
2024-12-06 05:49:29 -08:00
self.tree_cache,
self.model_config,
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
2024-12-06 05:49:29 -08:00
)
idle_batch.prepare_for_idle()
return idle_batch
def move_ready_grammar_requests(self):
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
2025-04-30 18:18:07 -07:00
num_ready_reqs = 0
num_abort_reqs = 0
for req in self.grammar_queue:
try:
req.grammar = req.grammar.result(timeout=0.03)
if req.grammar:
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
num_ready_reqs += 1
except futures._base.TimeoutError:
req.grammar_wait_ct += 1
if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
num_abort_reqs = 1
break
if self.server_args.enable_dp_attention:
tp_size = self.attn_tp_size
tp_group = self.attn_tp_cpu_group
else:
tp_size = self.tp_size
tp_group = self.tp_cpu_group
if tp_size > 1:
# Sync across TP ranks to make sure they have the same number of ready requests
tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
torch.distributed.all_reduce(
tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
)
num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
for i in range(num_ready_reqs, num_ready_reqs_max):
req = self.grammar_queue[i]
req.grammar = req.grammar.result()
if req.grammar:
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
req = self.grammar_queue[i]
req.grammar.cancel()
req.grammar = None
error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
logger.error(error_msg)
req.finished_reason = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
)
num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
if batch.next_batch_sampling_info:
if batch.next_batch_sampling_info.grammars is not None:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
self.watchdog_last_forward_ct = 0
self.watchdog_last_time = time.perf_counter()
while True:
current = time.perf_counter()
if self.cur_batch is not None:
if self.watchdog_last_forward_ct == self.forward_ct:
if current > self.watchdog_last_time + self.watchdog_timeout:
break
else:
self.watchdog_last_forward_ct = self.forward_ct
self.watchdog_last_time = current
time.sleep(self.watchdog_timeout // 2)
2025-05-10 21:54:46 -07:00
if not disable_request_logging():
# Print batch size and memory pool info to check whether there are de-sync issues.
logger.error(
f"{self.cur_batch.batch_size()=}, "
f"{self.cur_batch.reqs=}, "
f"{self.token_to_kv_pool_allocator.available_size()=}, "
f"{self.tree_cache.evictable_size()=}, "
)
pyspy_dump_schedulers()
2025-05-10 21:54:46 -07:00
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
print(file=sys.stderr, flush=True)
print(file=sys.stdout, flush=True)
2025-05-10 21:54:46 -07:00
# Wait for some time so that the parent process can print the error.
time.sleep(5)
self.parent_process.send_signal(signal.SIGQUIT)
def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
success = self.flush_cache()
return FlushCacheReqOutput(success=success)
def flush_cache(self):
"""Flush the memory pool and cache."""
2025-04-30 18:18:07 -07:00
if (
len(self.waiting_queue) == 0
and self.running_batch.is_empty()
and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
):
self.cur_batch = None
self.last_batch = None
self.tree_cache.reset()
if self.grammar_backend:
2024-11-12 21:17:38 -08:00
self.grammar_backend.reset()
self.req_to_token_pool.clear()
self.token_to_kv_pool_allocator.clear()
if not self.spec_algorithm.is_none():
self.draft_worker.model_runner.req_to_token_pool.clear()
self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
self.num_generated_tokens = 0
self.forward_ct_decode = 0
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
torch.cuda.empty_cache()
logger.info("Cache flushed successfully!")
if_success = True
else:
logging.warning(
f"Cache not flushed because there are pending requests. "
f"#queue-req: {len(self.waiting_queue)}, "
2025-03-12 22:22:39 -07:00
f"#running-req: {len(self.running_batch.reqs)}"
)
if_success = False
return if_success
def get_load(self):
# TODO(lsyin): use dynamically maintained num_waiting_tokens
load = (
self.max_total_num_tokens
- self.token_to_kv_pool_allocator.available_size()
- self.tree_cache.evictable_size()
)
load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
load += sum(
len(req.origin_input_ids)
for req in self.disagg_prefill_bootstrap_queue.queue
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
load += sum(
len(req.req.origin_input_ids)
for req in self.disagg_decode_prealloc_queue.queue
)
return load
def get_internal_state(self, recv_req: GetInternalStateReq):
ret = dict(global_server_args_dict)
ret["last_gen_throughput"] = self.last_gen_throughput
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
ret["avg_spec_accept_length"] = (
self.cum_spec_accept_length / self.cum_spec_accept_count
)
if RECORD_STEP_TIME:
ret["step_time_dict"] = self.step_time_dict
ret["load"] = self.get_load()
return GetInternalStateReqOutput(internal_state=ret)
def set_internal_state(self, recv_req: SetInternalStateReq):
server_args_dict = recv_req.server_args
args_allow_update = set(
[
2025-04-30 18:18:07 -07:00
"max_micro_batch_size",
"speculative_accept_threshold_single",
"speculative_accept_threshold_acc",
]
)
if_success = True
for k, v in server_args_dict.items():
if k not in args_allow_update:
logging.warning(f"Updating {k} is not supported.")
if_success = False
break
2025-04-30 18:18:07 -07:00
elif k == "max_micro_batch_size" and (
v > self.max_running_requests // self.pp_size or v < 1
):
logging.warning(
f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
)
if_success = False
break
if if_success:
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
avg_spec_accept_length = (
self.cum_spec_accept_length / self.cum_spec_accept_count
)
logger.info(f"{avg_spec_accept_length=}")
self.cum_spec_accept_length = self.cum_spec_accept_count = 0
for k, v in server_args_dict.items():
global_server_args_dict[k] = v
logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
return SetInternalStateReqOutput(
updated=True,
server_args=global_server_args_dict,
)
def handle_rpc_request(self, recv_req: RpcReqInput):
    """Dispatch an RPC request to the scheduler method named in the request.

    All ranks synchronize on a barrier regardless of success so they stay in
    lockstep. Returns RpcReqOutput(success, error_message).
    """
    logger.info(
        f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
    )

    success = True
    # Renamed from `exec`, which shadowed the Python builtin.
    error: Optional[Exception] = None
    try:
        func = getattr(self, recv_req.method)
        func(recv_req.parameters)
    except Exception as e:
        success = False
        error = e
        logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

    barrier()
    return RpcReqOutput(success, "" if not error else str(error))
def save_remote_model(self, params):
    """Persist the current model weights to the remote location in params["url"]."""
    self.tp_worker.worker.model_runner.save_remote_model(params["url"])
def save_sharded_model(self, params):
    """Write the model weights to disk as sharded checkpoint files."""
    runner = self.tp_worker.worker.model_runner
    runner.save_sharded_model(
        path=params["path"],
        pattern=params["pattern"],
        max_size=params["max_size"],
    )
def abort_request(self, recv_req: AbortReq):
    """Abort every request whose rid starts with the given rid prefix."""
    # TODO(lmzheng): abort the requests in the grammar queue.

    # Remove matching requests from the waiting queue. Pop from the back so
    # earlier indices stay valid, notifying the tokenizer for each one.
    matched = [
        i
        for i, queued in enumerate(self.waiting_queue)
        if queued.rid.startswith(recv_req.rid)
    ]
    for i in reversed(matched):
        req = self.waiting_queue.pop(i)
        self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
        logger.debug(f"Abort queued request. {req.rid=}")

    # Flag matching in-flight requests; they are torn down by the event loop.
    if self.cur_batch is self.running_batch or self.cur_batch is None:
        reqs = self.running_batch.reqs
    else:
        reqs = self.running_batch.reqs + self.cur_batch.reqs
    for req in reqs:
        if not req.finished() and req.rid.startswith(recv_req.rid):
            logger.debug(f"Abort running request. {req.rid=}")
            req.to_abort = True
def _pause_engine(self) -> Tuple[List[Req], int]:
    """Pause the engine; intended to return (paused requests, count). Not implemented here."""
    raise NotImplementedError()
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
    """In-place update of the weights from disk."""
    success, message = self.tp_worker.update_weights_from_disk(recv_req)
    if not success:
        logger.error(message)
        return UpdateWeightFromDiskReqOutput(success, message, 0)
    # KV-cache entries produced by the old weights must not be reused.
    flush_ok = self.flush_cache()
    assert flush_ok, "Cache flush failed after updating weights"
    return UpdateWeightFromDiskReqOutput(success, message, 0)
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
    """Initialize the distributed process group used for online weight updates."""
    success, message = self.tp_worker.init_weights_update_group(recv_req)
    result = InitWeightsUpdateGroupReqOutput(success, message)
    return result
def update_weights_from_distributed(
    self,
    recv_req: UpdateWeightsFromDistributedReqInput,
) -> Tuple[bool, str]:
    """Update the online model parameters broadcast over the distributed group."""
    success, message = self.tp_worker.update_weights_from_distributed(recv_req)
    if success:
        # The cache was produced by the old weights; drop it entirely.
        assert self.flush_cache(), "Cache flush failed after updating weights"
    else:
        logger.error(message)
    return UpdateWeightsFromDistributedReqOutput(success, message)
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Update the online model parameters from in-memory tensors."""
    success, message = self.tp_worker.update_weights_from_tensor(recv_req)
    # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
    if not success:
        logger.error(message)
    elif recv_req.flush_cache:
        # Stale KV entries must not be reused with the new weights.
        assert self.flush_cache(), "Cache flush failed after updating weights"
    return UpdateWeightsFromTensorReqOutput(success, message)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Fetch a named parameter from the TP worker and wrap it in a response."""
    return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
    """Release memory held by the model/KV cache via the memory saver adapter.

    Static buffers are stashed as detached copies first so they can be
    restored later by resume_memory_occupation().
    """
    self.memory_saver_adapter.check_validity(
        caller_name="release_memory_occupation"
    )
    # Snapshot model buffers BEFORE pausing the memory saver.
    self.stashed_model_static_state = _export_static_state(
        self.tp_worker.worker.model_runner.model
    )
    self.memory_saver_adapter.pause()
    self.flush_cache()
    return ReleaseMemoryOccupationReqOutput()
def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
    """Re-acquire memory and restore the buffers stashed by release_memory_occupation()."""
    self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
    # Resume allocations first, then copy the stashed buffer values back in.
    self.memory_saver_adapter.resume()
    _import_static_state(
        self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
    )
    # Drop the stash so it cannot be reused and its memory is freed.
    del self.stashed_model_static_state
    return ResumeMemoryOccupationReqOutput()
def slow_down(self, recv_req: SlowDownReqInput):
    """Set a per-forward sleep to throttle the scheduler; non-positive values disable it."""
    sleep_time = recv_req.forward_sleep_time
    # Keep the value only when it is None (disabled) or strictly positive.
    self.forward_sleep_time = (
        sleep_time if sleep_time is None or sleep_time > 0 else None
    )
    return SlowDownReqOutput()
def profile(self, recv_req: ProfileReq):
    """Route a profiling request to start_profile or stop_profile."""
    if recv_req.type != ProfileReqType.START_PROFILE:
        return self.stop_profile()
    return self.start_profile(
        recv_req.output_dir,
        recv_req.num_steps,
        recv_req.activities,
        recv_req.with_stack,
        recv_req.record_shapes,
        recv_req.profile_id,
    )
def start_profile(
    self,
    output_dir: Optional[str],
    num_steps: Optional[int],
    activities: Optional[List[str]],
    with_stack: Optional[bool],
    record_shapes: Optional[bool],
    profile_id: Optional[str],
) -> ProfileReqOutput:
    """Start profiling for the requested activities.

    Args:
        output_dir: Trace destination; defaults to $SGLANG_TORCH_PROFILER_DIR or /tmp.
        num_steps: If set, the caller is notified after this many forward steps.
        activities: Any of "CPU", "GPU", "MEM", "CUDA_PROFILER"; defaults to
            ["CPU", "GPU"]. "CPU"/"GPU" go through torch.profiler; "MEM" records
            a CUDA memory history; "CUDA_PROFILER" toggles the external profiler.
        with_stack: Forwarded to torch.profiler (defaults to True).
        record_shapes: Forwarded to torch.profiler (defaults to False).
        profile_id: Identifier embedded in the output file names.

    Returns:
        ProfileReqOutput indicating whether profiling was started.
        (Fixed: the return annotation previously said None.)
    """
    if self.profiler_activities:
        return ProfileReqOutput(
            success=False,
            message="Profiling is already in progress. Call /stop_profile first.",
        )

    if output_dir is None:
        output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
    if activities is None:
        activities = ["CPU", "GPU"]

    self.torch_profiler_output_dir = output_dir
    self.profiler_activities = activities
    self.profiler_id = profile_id
    logger.info(
        "Profiling starts. Traces will be saved to: %s (with id %s)",
        self.torch_profiler_output_dir,
        self.profiler_id,
    )

    # Only "CPU" and "GPU" map to torch.profiler activities; the rest are
    # handled by dedicated mechanisms below.
    activity_map = {
        "CPU": torch.profiler.ProfilerActivity.CPU,
        "GPU": torch.profiler.ProfilerActivity.CUDA,
    }
    torchprof_activities = [
        activity_map[a] for a in activities if a in activity_map
    ]

    if torchprof_activities:
        self.torch_profiler = torch.profiler.profile(
            activities=torchprof_activities,
            with_stack=with_stack if with_stack is not None else True,
            record_shapes=record_shapes if record_shapes is not None else False,
        )
        self.torch_profiler.start()

    if "MEM" in activities:
        torch.cuda.memory._record_memory_history(max_entries=100000)

    if "CUDA_PROFILER" in activities:
        torch.cuda.cudart().cudaProfilerStart()

    if num_steps:
        self.profiler_target_forward_ct = self.forward_ct + num_steps
        # The caller will be notified when reaching profiler_target_forward_ct
    else:
        self.profiler_target_forward_ct = None

    return ProfileReqOutput(success=True, message="Succeeded")
def stop_profile(self) -> ProfileReqOutput:
    """Stop profiling and flush traces/snapshots to the configured output dir.

    Returns:
        ProfileReqOutput indicating whether profiling was actually stopped.
        (Fixed: the return annotation previously said None although this
        method always returns a ProfileReqOutput.)
    """
    if self.profiler_activities is None:
        return ProfileReqOutput(
            success=False,
            message="Profiling is not in progress. Call /start_profile first.",
        )

    logger.info("Stop profiling...")
    if self.torch_profiler is not None:
        self.torch_profiler.stop()
        self.torch_profiler.export_chrome_trace(
            os.path.join(
                self.torch_profiler_output_dir,
                self.profiler_id + f"-TP-{self.tp_rank}" + ".trace.json.gz",
            )
        )

    if "MEM" in self.profiler_activities:
        memory_profile_path = os.path.join(
            self.torch_profiler_output_dir,
            self.profiler_id + f"-TP-{self.tp_rank}-memory" + ".pickle",
        )
        torch.cuda.memory._dump_snapshot(memory_profile_path)
        # Passing enabled=None stops CUDA memory-history recording.
        torch.cuda.memory._record_memory_history(enabled=None)

    if "CUDA_PROFILER" in self.profiler_activities:
        torch.cuda.cudart().cudaProfilerStop()

    logger.info(
        "Profiling done. Traces are saved to: %s",
        self.torch_profiler_output_dir,
    )
    # Reset all profiler state so a new session can start cleanly.
    self.torch_profiler = None
    self.torch_profiler_output_dir = None
    self.profiler_activities = None
    self.profiler_target_forward_ct = None

    return ProfileReqOutput(success=True, message="Succeeded")
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
    """Control expert-distribution recording: start, stop, or dump."""
    dispatch = {
        ExpertDistributionReq.START_RECORD: "start_record",
        ExpertDistributionReq.STOP_RECORD: "stop_record",
        ExpertDistributionReq.DUMP_RECORD: "dump_record",
    }
    method_name = dispatch.get(recv_req)
    if method_name is None:
        raise ValueError("Unrecognized ExpertDistributionReq value")
    getattr(get_global_expert_distribution_recorder(), method_name)()
    return ExpertDistributionReqOutput()
def open_session(self, recv_req: OpenSessionReqInput):
    """Create a new session; reject duplicates and missing session ids."""
    session_id = recv_req.session_id

    # Guard clauses: duplicate id, then missing id (same check order as before).
    if session_id in self.sessions:
        logger.warning(f"session id {session_id} already exist, cannot open.")
        return OpenSessionReqOutput(session_id, False)
    if session_id is None:
        logger.warning("session id is None, cannot open.")
        return OpenSessionReqOutput(session_id, False)

    self.sessions[session_id] = Session(
        recv_req.capacity_of_str_len, session_id
    )
    return OpenSessionReqOutput(session_id, True)
def close_session(self, recv_req: CloseSessionReqInput):
    """Delete an existing session; warn (don't raise) if the id is unknown."""
    session_id = recv_req.session_id
    if session_id in self.sessions:
        del self.sessions[session_id]
    else:
        logger.warning(f"session id {session_id} does not exist, cannot delete.")
def get_print_prefix(self):
    """Build the log prefix (" DP.. TP.. PP..") for ranks that are active."""
    parts = []
    if self.attn_dp_rank is not None:
        parts.append(f" DP{self.attn_dp_rank}")
    if self.server_args.tp_size > 1:
        parts.append(f" TP{self.tp_rank}")
    if self.pp_size > 1:
        parts.append(f" PP{self.pp_rank}")
    return "".join(parts)
def _publish_kv_events(self):
    """Drain KV-cache events from the tree cache and publish them, if enabled."""
    if not self.enable_kv_cache_events:
        return
    events = self.tree_cache.take_events()
    if events:
        self.kv_event_publisher.publish(
            KVEventBatch(ts=time.time(), events=events)
        )
def is_health_check_generate_req(recv_req):
    """Return True when the request id marks an internal health-check request."""
    rid = getattr(recv_req, "rid", "")
    return rid.startswith("HEALTH_CHECK")
def _export_static_state(model):
return dict(
buffers=[
(name, buffer.detach().clone()) for name, buffer in model.named_buffers()
]
)
def _import_static_state(model, static_params):
self_named_buffers = dict(model.named_buffers())
for name, tensor in static_params["buffers"]:
self_named_buffers[name][...] = tensor
2024-09-29 02:36:12 -07:00
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Builds a Scheduler, reports readiness through `pipe_writer`, then blocks
    forever in the event loop matching the parallelism / disaggregation mode.
    On any exception, logs the traceback and signals the parent with SIGQUIT.
    """
    # [For Router] if env var "SGLANG_DP_RANK" exists, resolve dp_rank from it
    # BEFORE building the prefix, so DP{n} appears in the process title and
    # log lines. (Bug fix: this was previously read after the prefix was
    # generated, so the DP rank was missing from the prefix.)
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Generate the prefix
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Size the multimodal embedding cache (MB), overridable via env var.
    embedding_cache_size = 100
    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
    init_embedding_cache(embedding_cache_size * 1024 * 1024)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)