# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""

import faulthandler
import logging
import os
import signal
import sys
import threading
import time
from collections import deque
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from typing import Deque, Dict, List, Optional, Tuple, Union

import psutil
import setproctitle
import torch
import zmq
from torch.cuda import Stream as CudaStream
from torch.cuda import StreamContext as CudaStreamContext
from torch.distributed import barrier

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
    DecodeKVCacheOffloadManager,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    MetadataBuffers,
    ReqToMetadataIdxAllocator,
    TransferBackend,
    prepare_abort,
)
from sglang.srt.distributed import get_pp_group, get_world_group
from sglang.srt.environ import envs
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.io_struct import (
    AbortReq,
    BaseBatchReq,
    BaseReq,
    BatchTokenizedEmbeddingReqInput,
    BatchTokenizedGenerateReqInput,
    ClearHiCacheReqInput,
    ClearHiCacheReqOutput,
    CloseSessionReqInput,
    DestroyWeightsUpdateGroupReqInput,
    ExpertDistributionReq,
    ExpertDistributionReqOutput,
    ExpertDistributionReqType,
    FlushCacheReqInput,
    FlushCacheReqOutput,
    FreezeGCReq,
    GetInternalStateReq,
    GetInternalStateReqOutput,
    GetLoadReqInput,
    GetLoadReqOutput,
    GetWeightsByNameReqInput,
    HealthCheckOutput,
    InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsSendGroupForRemoteInstanceReqOutput,
    InitWeightsUpdateGroupReqInput,
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
    RpcReqInput,
    RpcReqOutput,
    SendWeightsToRemoteInstanceReqInput,
    SendWeightsToRemoteInstanceReqOutput,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    SlowDownReqInput,
    SlowDownReqOutput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    ModelWorkerBatch,
    MultimodalInputs,
    Req,
    RequestStage,
    ScheduleBatch,
)
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
from sglang.srt.managers.scheduler_metrics_mixin import (
    RECORD_STEP_TIME,
    SchedulerMetricsMixin,
)
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
from sglang.srt.managers.scheduler_runtime_checker_mixin import (
    SchedulerRuntimeCheckerMixin,
)
from sglang.srt.managers.scheduler_update_weights_mixin import (
    SchedulerUpdateWeightsMixin,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.utils import GenerationBatchResult, validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.tracing.trace import (
    process_tracing_init,
    trace_set_proc_propagate_context,
    trace_set_thread_info,
    trace_slice_batch,
    trace_slice_end,
    trace_slice_start,
)
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
from sglang.srt.utils import (
    DynamicGradMode,
    broadcast_pyobj,
    configure_gc_logger,
    configure_logger,
    disable_request_logging,
    freeze_gc,
    get_available_gpu_memory,
    get_bool_env_var,
    get_int_env_var,
    get_zmq_socket,
    kill_itself_when_parent_died,
    numa_bind_to_node,
    point_to_point_pyobj,
    pyspy_dump_schedulers,
    require_mlp_sync,
    require_mlp_tp_gather,
    set_gpu_proc_affinity,
    set_random_seed,
    suppress_other_loggers,
)
from sglang.srt.utils.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.utils import TypeBasedDispatcher, get_exception_traceback

logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


@dataclass
class EmbeddingBatchResult:
    embeddings: torch.Tensor


class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerUpdateWeightsMixin,
    SchedulerProfilerMixin,
    SchedulerMetricsMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
    SchedulerRuntimeCheckerMixin,
    SchedulerPPMixin,
):
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        moe_ep_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.moe_ep_rank = moe_ep_rank
        self.pp_rank = pp_rank
        self.dp_rank = dp_rank
        self.tp_size = server_args.tp_size
        self.moe_ep_size = server_args.ep_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.enable_priority_scheduling = server_args.enable_priority_scheduling
        self.abort_on_priority_when_disabled = (
            server_args.abort_on_priority_when_disabled
        )
        self.schedule_low_priority_values_first = (
            server_args.schedule_low_priority_values_first
        )
        self.priority_scheduling_preemption_threshold = (
            server_args.priority_scheduling_preemption_threshold
        )
        self.enable_lora = server_args.enable_lora
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_metrics_for_all_schedulers = (
            server_args.enable_metrics_for_all_schedulers
        )
        self.enable_kv_cache_events = bool(
            server_args.kv_events_config and tp_rank == 0
        )
        self.enable_trace = server_args.enable_trace
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.enable_hicache_storage = server_args.hicache_storage_backend is not None
        self.page_size = server_args.page_size

        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init model config
        self.model_config = ModelConfig.from_server_args(server_args)

        # Init inter-process communication
        self.init_sockets(server_args, port_args)

        # Init tokenizer
        self.init_tokenizer()

        # Init moe config
        self.init_moe_config()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker

        from sglang.srt.managers.tp_worker import TpModelWorker

        self.tp_worker = TpModelWorker(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            moe_ep_rank=moe_ep_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding

        self.launch_draft_worker(
            gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
        )

        # Dispatch the model worker
        if self.spec_algorithm.is_none():
            self.model_worker = self.tp_worker
        else:
            self.model_worker = self.draft_worker

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_queued_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if get_global_server_args().pp_max_micro_batch_size is None:
            get_global_server_args().pp_max_micro_batch_size = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        set_random_seed(self.random_seed)

        # Hybrid memory pool
        self.is_hybrid = self.tp_worker.is_hybrid
        self.is_hybrid_gdn = self.tp_worker.model_runner.hybrid_gdn_config is not None

        if self.is_hybrid:
            self.sliding_window_size = self.tp_worker.sliding_window_size
            self.full_tokens_per_layer, self.swa_tokens_per_layer = (
                self.tp_worker.get_tokens_per_layer_info()
            )

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.num_retracted_reqs: int = 0
        self.num_paused_reqs: int = 0
        self.kv_transfer_speed_gb_s: float = 0.0
        self.kv_transfer_latency_ms: float = 0.0
        self.sessions: Dict[str, Session] = {}
        self.default_stream: CudaStream = torch.get_device_module(
            self.device
        ).current_stream()
        if self.device == "cpu":
            self.default_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size <= 0:  # -1 means disable
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args,
                self.tokenizer,
                self.model_config.vocab_size,
                self.model_config.hf_eos_token_id,
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
            self.enable_priority_scheduling,
            self.schedule_low_priority_values_first,
        )
        # Enable preemption for priority scheduling.
        self.try_preemption = self.enable_priority_scheduling

        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.init_new_token_ratio = min(
            envs.SGLANG_INIT_NEW_TOKEN_RATIO.get()
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio * envs.SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR.get(),
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / envs.SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS.get()
        self.new_token_ratio = self.init_new_token_ratio
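        # NOTE (explanatory comment): new_token_ratio estimates, per running request,
        # the fraction of its remaining max_new_tokens budget the scheduler reserves
        # when admitting new prefills. It starts at init_new_token_ratio (scaled by
        # schedule_conservativeness and capped at 1.0) and can decay linearly by
        # new_token_ratio_decay per step down to min_new_token_ratio.
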
        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init memory saver, profiler and metric stats
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        self.offload_tags = set()
        self.init_profiler()

        self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
        self.input_blocker = (
            SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
            if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
            else None
        )

        # Init metrics stats
        self.init_metrics(tp_rank, pp_rank, dp_rank)

        if self.enable_kv_cache_events:
            self.init_kv_events(server_args.kv_events_config)

        # Init disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

        if get_bool_env_var("SGLANG_GC_LOG"):
            configure_gc_logger()

        # Init prefill kv split size when deterministic inference is enabled with various attention backends
        self.init_deterministic_inference_config()

        # Init overlap
        self.init_overlap()

        # Init request dispatcher
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
                (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
                (
                    InitWeightsSendGroupForRemoteInstanceReqInput,
                    self.init_weights_send_group_for_remote_instance,
                ),
                (
                    SendWeightsToRemoteInstanceReqInput,
                    self.send_weights_to_remote_instance,
                ),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (FreezeGCReq, self.handle_freeze_gc),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                (GetLoadReqInput, self.get_load),
            ]
        )

    def launch_draft_worker(
        self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
    ):
        if server_args.speculative_draft_load_format is not None:
            server_args.load_format = server_args.speculative_draft_load_format
            logger.info(
                f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
            )

        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker
            from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2

            WorkerClass = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker

            self.draft_worker = WorkerClass(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        elif self.spec_algorithm.is_standalone():
            from sglang.srt.speculative.standalone_worker import StandaloneWorker

            self.draft_worker = StandaloneWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        elif self.spec_algorithm.is_ngram():
            from sglang.srt.speculative.ngram_worker import NGRAMWorker

            self.draft_worker = NGRAMWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

    def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
        context = zmq.Context(2)
        self.idle_sleeper = None
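
        # A small wrapper so callers can send unconditionally: it no-ops on ranks
        # that own no socket and forwards http_worker_ipc in the multi-HTTP-worker
        # setup (explanatory note inferred from the code below).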
        class SenderWrapper:
            def __init__(self, socket: zmq.Socket):
                self.socket = socket

            def send_output(
                self,
                output: Union[BaseReq, BaseBatchReq],
                recv_obj: Optional[Union[BaseReq, BaseBatchReq]] = None,
            ):
                if self.socket is None:
                    return

                if (
                    isinstance(recv_obj, BaseReq)
                    and recv_obj.http_worker_ipc is not None
                    and output.http_worker_ipc is None
                ):
                    # Handle communicator requests for the multi-HTTP-worker case
                    output.http_worker_ipc = recv_obj.http_worker_ipc

                self.socket.send_pyobj(output)

        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )

            send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )
            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.send_to_tokenizer = SenderWrapper(send_to_tokenizer)
            self.send_to_detokenizer = SenderWrapper(send_to_detokenizer)

            if self.server_args.sleep_on_idle:
                self.idle_sleeper = IdleSleeper(
                    [
                        self.recv_from_tokenizer,
                        self.recv_from_rpc,
                    ]
                )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SenderWrapper(None)
            self.send_to_detokenizer = SenderWrapper(None)

        if self.current_scheduler_metrics_enabled():
            self.send_metrics_from_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.metrics_ipc_name, False
            )

    def init_deterministic_inference_config(self):
        """Initialize deterministic inference configuration for different attention backends."""
        if not self.server_args.enable_deterministic_inference:
            self.truncation_align_size = None
            return
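
        # NOTE (explanatory comment): maps each attention backend to the env var and
        # default tile size used to split/align prefill work for deterministic
        # kernels; backends not listed here get truncation_align_size = None, i.e.
        # no alignment is applied.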
        backend_sizes = {
            "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
            "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
        }
        env_var, default_size = backend_sizes.get(
            self.server_args.attention_backend, (None, None)
        )
        self.truncation_align_size = (
            get_int_env_var(env_var, default_size) if env_var else None
        )

    def init_tokenizer(self):
        server_args = self.server_args
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                    use_fast=not server_args.disable_fast_image_processor,
                )
                self.tokenizer = get_tokenizer_from_processor(self.processor)
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

    def init_memory_pool_and_cache(self):
        server_args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            if self.is_hybrid:
                ChunkCacheClass = SWAChunkCache
            else:
                ChunkCacheClass = ChunkCache
            self.tree_cache = ChunkCacheClass(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
            )
        else:
            if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
                # Lazy import to avoid JIT overhead
                from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp

                self.tree_cache = RadixCacheCpp(
                    disable=False,
                    use_hicache=self.enable_hierarchical_cache,
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool=self.token_to_kv_pool_allocator,
                    tp_cache_group=self.tp_cpu_group,
                    page_size=self.page_size,
                    hicache_ratio=server_args.hicache_ratio,
                    hicache_size=server_args.hicache_size,
                    hicache_write_policy=server_args.hicache_write_policy,
                    enable_kv_cache_events=self.enable_kv_cache_events,
                )
            elif self.enable_hierarchical_cache:
                self.tree_cache = HiRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    tp_cache_group=(
                        self.attn_tp_cpu_group
                        if self.server_args.enable_dp_attention
                        else self.tp_cpu_group
                    ),
                    page_size=self.page_size,
                    eviction_policy=server_args.radix_eviction_policy,
                    hicache_ratio=server_args.hicache_ratio,
                    hicache_size=server_args.hicache_size,
                    hicache_write_policy=server_args.hicache_write_policy,
                    hicache_io_backend=server_args.hicache_io_backend,
                    hicache_mem_layout=server_args.hicache_mem_layout,
                    enable_metrics=self.enable_metrics,
                    hicache_storage_backend=server_args.hicache_storage_backend,
                    hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
                    model_name=server_args.served_model_name,
                    storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
                    is_eagle=self.spec_algorithm.is_eagle(),
                )
                self.tp_worker.register_hicache_layer_transfer_counter(
                    self.tree_cache.cache_controller.layer_done_counter
                )
            elif self.is_hybrid:
                self.tree_cache = SWARadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    sliding_window_size=self.sliding_window_size,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                    is_eagle=self.spec_algorithm.is_eagle(),
                )
            elif self.is_hybrid_gdn:
                self.tree_cache = MambaRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                )
            elif server_args.enable_lmcache:
                from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
                    LMCRadixCache,
                )

                self.tree_cache = LMCRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                    model_config=self.model_config,
                    tp_size=self.tp_size,
                    rank=self.tp_rank,
                    tp_group=self.tp_group,
                    eviction_policy=server_args.radix_eviction_policy,
                )
            else:
                self.tree_cache = RadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                    enable_kv_cache_events=self.enable_kv_cache_events,
                    eviction_policy=server_args.radix_eviction_policy,
                    is_eagle=self.spec_algorithm.is_eagle(),
                )

        if (
            server_args.disaggregation_mode == "decode"
            and server_args.disaggregation_decode_enable_offload_kvcache
        ):
            self.decode_offload_manager = DecodeKVCacheOffloadManager(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                tp_group=(
                    self.attn_tp_cpu_group
                    if self.server_args.enable_dp_attention
                    else self.tp_cpu_group
                ),
                tree_cache=self.tree_cache,
                server_args=self.server_args,
            )
        else:
            self.decode_offload_manager = None

        self.decode_mem_cache_buf_multiplier = (
            1
            if self.spec_algorithm.is_none()
            else (
                server_args.speculative_num_draft_tokens
                + (
                    (server_args.speculative_eagle_topk or 1)
                    * (server_args.speculative_num_steps or 1)
                )
            )
        )
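        # NOTE (explanatory comment, illustrative numbers): reading the expression
        # above, speculative decoding needs extra KV headroom per decode step for the
        # draft tokens being verified plus the staged draft tree, e.g. with
        # speculative_num_draft_tokens=4, eagle_topk=2, and num_steps=3 the
        # multiplier is 4 + 2 * 3 = 10 instead of 1.
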
        embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
        init_embedding_cache(embedding_cache_size * 1024 * 1024)

    def init_disaggregation(self):
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        if (
            self.disaggregation_mode == DisaggregationMode.DECODE
        ):  # *2 for the headroom.
            buffer_size = (self.req_to_token_pool.size) * 2
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
                hidden_size=self.model_config.hf_text_config.hidden_size,
                hidden_states_dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )

            # The decode requests polling for their kv cache transfers
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                tp_rank=self.tp_rank,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                max_total_num_tokens=self.max_total_num_tokens,
                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
                num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
                transfer_backend=self.transfer_backend,
            )

        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
                hidden_size=self.model_config.hf_text_config.hidden_size,
                hidden_states_dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                max_total_num_tokens=self.max_total_num_tokens,
                decode_tp_size=self.server_args.disaggregation_decode_tp,
                decode_dp_size=self.server_args.disaggregation_decode_dp,
                scheduler=self,
                pp_rank=self.pp_rank,
                pp_size=self.pp_size,
                transfer_backend=self.transfer_backend,
            )
            # The prefill requests that are in the middle of sending their kv cache
            self.disagg_prefill_inflight_queue: List[Req] = []

    def init_overlap(self):
        if not self.enable_overlap:
            return

        self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
        self.forward_stream_ctx: CudaStreamContext = torch.get_device_module(
            self.device
        ).stream(self.forward_stream)
        self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream()
        self.copy_stream_ctx: CudaStreamContext = torch.get_device_module(
            self.device
        ).stream(self.copy_stream)

        self.future_map = FutureMap(
            self.max_running_requests, self.device, self.spec_algorithm
        )
        self.batch_record_buf = [None] * 2
        self.batch_record_ct = 0

    def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
        # FIXME(lsyin): hacky way to keep a reference so GPU tensors are not freed by torch GC
        # NOTE: More reliable: record all tensors into the forward stream
        # NOTE: - for all future tensors, we shall always read from the future map
        #       - for all non-future tensors (produced only by the schedule stream),
        #         we shall keep their references alive for the entire forward pass
        self.batch_record_ct = (self.batch_record_ct + 1) % 2
        self.batch_record_buf[self.batch_record_ct] = model_worker_batch

    def init_moe_config(self):
        if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
            initialize_moe_config(self.server_args)

    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop."""
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()

            self.last_batch = batch

    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
        self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()

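        # NOTE (explanatory comment): results are deliberately consumed one
        # iteration late: run_batch() launches batch i on the GPU, and while it
        # executes, the CPU side of the loop processes the result of batch i-1
        # popped from this queue, overlapping scheduling with device computation.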
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            batch_result = None
            if batch:
                batch_result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), batch_result))

            if self.last_batch:
                # Process the results of the last batch
                tmp_batch, tmp_result = self.result_queue.popleft()
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()

            self.launch_batch_sample_if_needed(batch_result)
            self.last_batch = batch

            if envs.SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK.get():
                self._check_runtime_mem_leak()

    def recv_requests(self) -> List[Req]:
        """Receive requests at tp_rank = 0 and broadcast them to all other TP ranks."""

        if self.recv_skipper is not None:
            last_forward_mode = (
                self.last_batch.forward_mode if self.last_batch is not None else None
            )
            if not self.recv_skipper.handle(last_forward_mode):
                return []

        if self.pp_rank == 0:
            if self.attn_tp_rank == 0:
                recv_reqs = []

                while True:
                    try:
                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(recv_req)

                while True:
                    try:
                        recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(recv_rpc)
            else:
                recv_reqs = None
        else:
            if self.attn_tp_rank == 0:
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                recv_reqs = point_to_point_pyobj(
                    [],
                    self.pp_rank * self.tp_size + dp_offset,
                    self.world_group.device_group,
                    (self.pp_rank - 1) * self.tp_size + dp_offset,
                    self.pp_rank * self.tp_size + dp_offset,
                )
            else:
                recv_reqs = None

        if self.input_blocker is not None:
            recv_reqs = self.input_blocker.handle(recv_reqs)

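        # NOTE (explanatory comment): under DP attention, generation/embedding
        # ("work") requests only need to reach the ranks of this attention-TP group,
        # so they are broadcast within attn_tp_cpu_group, while control requests
        # must be seen by every rank and are broadcast over the full TP group.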
        if self.server_args.enable_dp_attention:
            if self.attn_tp_rank == 0:
                work_reqs = [
                    req
                    for req in recv_reqs
                    if isinstance(
                        req,
                        (
                            TokenizedGenerateReqInput,
                            TokenizedEmbeddingReqInput,
                            BatchTokenizedGenerateReqInput,
                            BatchTokenizedEmbeddingReqInput,
                        ),
                    )
                ]
                control_reqs = [
                    req
                    for req in recv_reqs
                    if not isinstance(
                        req,
                        (
                            TokenizedGenerateReqInput,
                            TokenizedEmbeddingReqInput,
                            BatchTokenizedGenerateReqInput,
                            BatchTokenizedEmbeddingReqInput,
                        ),
                    )
                ]
            else:
                work_reqs = None
                control_reqs = None

            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_group.rank,
                    self.attn_tp_cpu_group,
                    src=self.attn_tp_group.ranks[0],
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            recv_reqs = work_reqs + control_reqs
        elif self.tp_size != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )

        if self.enable_trace:
            for req in recv_reqs:
                if isinstance(
                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                ):
                    trace_set_proc_propagate_context(req.rid, req.trace_context)
                    trace_slice_start("", req.rid, anonymous=True)

        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
        for recv_req in recv_reqs:
            # If it is a health check generation request and there are running requests, ignore it.
            if is_health_check_generate_req(recv_req) and (
                self.chunked_req is not None
                or not self.running_batch.is_empty()
                or len(self.offload_tags) > 0
            ):
                self.return_health_check_ct += 1
                continue

            output = self._request_dispatcher(recv_req)
            if output is not None:
                if isinstance(output, RpcReqOutput):
                    if self.recv_from_rpc is not None:
                        self.recv_from_rpc.send_pyobj(output)
                else:
                    self.send_to_tokenizer.send_output(output, recv_req)

    def init_req_max_new_tokens(self, req):
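        # Cap max_new_tokens at max_req_len - len(input) - 1 so the total sequence
        # stays within max_req_len; 1 << 30 acts as "unlimited" when the request
        # did not set a cap.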
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_id=recv_req.lora_id,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
                disagg_mode=self.disaggregation_mode,
                data_parallel_rank=recv_req.data_parallel_rank,
                vocab_size=self.model_config.vocab_size,
                priority=recv_req.priority,
                metrics_collector=(
                    self.metrics_collector if self.enable_metrics else None
                ),
                http_worker_ipc=recv_req.http_worker_ipc,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_msg = (
                        f"Invalid request: Disaggregated request received without "
                        f"bootstrap room id. {req.rid=}"
                    )
                    logger.error(error_msg)
                    prepare_abort(req, error_msg, status_code=HTTPStatus.BAD_REQUEST)
                    self.stream_output([req], req.return_logprob)
                    return

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                req.set_finish_with_abort(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self.init_req_max_new_tokens(req)
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self.init_req_max_new_tokens(req)
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self.init_req_max_new_tokens(req)
                self._add_request_to_queue(req)
                return

        # Initialize before returning
        self.init_req_max_new_tokens(req)

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond the input sequence
            # to skip input logprob computation entirely
            if req.is_prefill_only:
                req.logprob_start_len = len(req.origin_input_ids)
            else:
                # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
                req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if not req.is_prefill_only and req.logprob_start_len >= len(
            req.origin_input_ids
        ):
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            if self.grammar_backend is None:
                error_msg = "Grammar-based generation (json_schema, regex, ebnf, structural_tag) is not supported when the server is launched with --grammar-backend none"
                req.set_finish_with_abort(error_msg)
            else:
                if req.sampling_params.json_schema is not None:
                    key = ("json", req.sampling_params.json_schema)
                elif req.sampling_params.regex is not None:
                    key = ("regex", req.sampling_params.regex)
                elif req.sampling_params.ebnf is not None:
                    key = ("ebnf", req.sampling_params.ebnf)
                elif req.sampling_params.structural_tag:
                    key = ("structural_tag", req.sampling_params.structural_tag)

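                # get_cached_or_future_value returns the compiled grammar on a
                # cache hit, or a future on a miss; requests holding a future wait
                # in grammar_queue until compilation finishes (drained in
                # get_new_batch_prefill via move_ready_grammar_requests).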
                value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
                req.grammar = value

                if not cache_hit:
                    req.grammar_key = key
                    add_to_grammar_queue = True
                else:
                    if value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                        error_msg = f"Invalid grammar request with cache hit: {key=}"
                        req.set_finish_with_abort(error_msg)

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def handle_batch_generate_request(
        self,
        recv_req: BatchTokenizedGenerateReqInput,
    ):
        """Handle optimized batch generate request."""
        logger.debug(f"Processing batch generate request with {len(recv_req)} requests")

        # Process each request in the batch
        for tokenized_req in recv_req:
            self.handle_generate_request(tokenized_req)

    def _prefetch_kvcache(self, req: Req):
        if self.enable_hicache_storage:
            req.init_next_round_input(self.tree_cache)
            if req.last_node.backuped:
                # Only initiate the prefetch if the last node is backed up;
                # otherwise, the allocated GPU memory must be locked for integrity
                last_hash = req.last_host_node.get_last_hash_value()
                matched_len = len(req.prefix_indices) + req.host_hit_length
                new_input_tokens = req.fill_ids[matched_len:]

                prefix_keys = (
                    req.last_node.get_prefix_hash_values(req.last_node.parent)
                    if self.tree_cache.hicache_storage_pass_prefix_keys
                    else None
                )
                self.tree_cache.prefetch_from_storage(
                    req.rid,
                    req.last_host_node,
                    new_input_tokens,
                    last_hash,
                    prefix_keys,
                )

    def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
        if self.disaggregation_mode == DisaggregationMode.NULL:
            self._set_or_validate_priority(req)
            if self._abort_on_queued_limit(req):
                return
            self._prefetch_kvcache(req)
            self.waiting_queue.append(req)
            req.time_stats.wait_queue_entry_time = time.perf_counter()
            trace_slice_end("process req", req.rid, auto_next_anon=True)
        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            self._prefetch_kvcache(req)
            self.disagg_prefill_bootstrap_queue.add(
                req, self.model_config.num_key_value_heads
            )
            req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
            if not is_retracted:
                req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
        else:
            raise ValueError(f"Invalid {self.disaggregation_mode=}")

    def _set_or_validate_priority(self, req: Req):
        """Set the default priority value, or abort the request based on the priority scheduling mode."""
        if self.enable_priority_scheduling and req.priority is None:
            if self.schedule_low_priority_values_first:
                req.priority = sys.maxsize
            else:
                req.priority = -sys.maxsize - 1
        elif (
            not self.enable_priority_scheduling
            and req.priority is not None
            and self.abort_on_priority_when_disabled
        ):
            abort_req = AbortReq(
                finished_reason={
                    "type": "abort",
                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                    "message": "Using priority is disabled for this server. Please send a new request without a priority.",
                },
                rid=req.rid,
            )
            self.send_to_tokenizer.send_output(abort_req, req)

    def _abort_on_queued_limit(self, recv_req: Req) -> bool:
        """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
        if (
            self.max_queued_requests is None
            or len(self.waiting_queue) + 1 <= self.max_queued_requests
        ):
            return False

        # Reject the incoming request by default.
        req_to_abort = recv_req
        message = "The request queue is full."
        if self.enable_priority_scheduling:
            # With priority scheduling, consider aborting an existing request based on the priority.
            # direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
            # max(...) + (direction * priority, queue_time_start) picks the least-preferred request.
            # Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better.
            direction = 1 if self.schedule_low_priority_values_first else -1
            key_fn = lambda item: (
                direction * item[1].priority,
                item[1].time_stats.wait_queue_entry_time,
            )
            idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
            abort_existing_req = (
                direction * recv_req.priority < direction * candidate_req.priority
            )
            if abort_existing_req:
                self.waiting_queue.pop(idx)
                req_to_abort = candidate_req
                message = "The request is aborted by a higher priority request."

        self.send_to_tokenizer.send_output(
            AbortReq(
                finished_reason={
                    "type": "abort",
                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                    "message": message,
                },
                rid=req_to_abort.rid,
            ),
            req_to_abort,
        )
        return req_to_abort.rid == recv_req.rid

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            token_type_ids=recv_req.token_type_ids,
            priority=recv_req.priority,
            http_worker_ipc=recv_req.http_worker_ipc,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)

    def handle_batch_embedding_request(
        self,
        recv_req: BatchTokenizedEmbeddingReqInput,
    ):
        """Handle optimized batch embedding request."""
        logger.debug(
            f"Processing batch embedding request with {len(recv_req)} requests"
        )

        # Process each request in the batch
        for tokenized_req in recv_req:
            self.handle_embedding_request(tokenized_req)

    def _get_token_info(self):
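        # Tokens that are neither free in the allocator nor evictable from the
        # radix cache count as in use; token_usage = num_used / max_total_num_tokens.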
available_size = self.token_to_kv_pool_allocator.available_size()
|
|
evictable_size = self.tree_cache.evictable_size()
|
|
num_used = self.max_total_num_tokens - (available_size + evictable_size)
|
|
token_usage = num_used / self.max_total_num_tokens
|
|
return num_used, token_usage, available_size, evictable_size
|
|
|
|
def _get_mamba_token_info(self):
|
|
is_radix_tree = isinstance(self.tree_cache, MambaRadixCache)
|
|
full_available_size = self.token_to_kv_pool_allocator.available_size()
|
|
full_evictable_size = (
|
|
self.tree_cache.full_evictable_size() if is_radix_tree else 0
|
|
)
|
|
mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
|
|
mamba_evictable_size = (
|
|
self.tree_cache.mamba_evictable_size() if is_radix_tree else 0
|
|
)
|
|
full_num_used = self.token_to_kv_pool_allocator.size - (
|
|
full_available_size + full_evictable_size
|
|
)
|
|
mamba_num_used = self.req_to_token_pool.mamba_pool.size - (
|
|
mamba_available_size + mamba_evictable_size
|
|
)
|
|
full_token_usage = full_num_used / self.token_to_kv_pool_allocator.size
|
|
mamba_usage = mamba_num_used / self.req_to_token_pool.mamba_pool.size
|
|
return (
|
|
full_num_used,
|
|
mamba_num_used,
|
|
full_token_usage,
|
|
mamba_usage,
|
|
full_available_size,
|
|
full_evictable_size,
|
|
mamba_available_size,
|
|
mamba_evictable_size,
|
|
)
|
|
|
|
def _get_swa_token_info(self):
|
|
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
|
|
full_evictable_size = self.tree_cache.full_evictable_size()
|
|
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
|
|
swa_evictable_size = self.tree_cache.swa_evictable_size()
|
|
full_num_used = self.full_tokens_per_layer - (
|
|
full_available_size + full_evictable_size
|
|
)
|
|
swa_num_used = self.swa_tokens_per_layer - (
|
|
swa_available_size + swa_evictable_size
|
|
)
|
|
full_token_usage = full_num_used / self.full_tokens_per_layer
|
|
swa_token_usage = swa_num_used / self.swa_tokens_per_layer
|
|
return (
|
|
full_num_used,
|
|
swa_num_used,
|
|
full_token_usage,
|
|
swa_token_usage,
|
|
full_available_size,
|
|
full_evictable_size,
|
|
swa_available_size,
|
|
swa_evictable_size,
|
|
)
|
|
|
|
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        # Merge the prefill batch into the running batch
        chunked_req_to_exclude = set()
        if self.chunked_req:
            # Move the chunked request out of the batch so that we can merge
            # only finished requests to running_batch.
            chunked_req_to_exclude.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
            # The chunked request keeps its rid but will get a new req_pool_idx.
            if self.tp_worker.model_runner.mambaish_config is not None:
                self.req_to_token_pool.free(
                    self.chunked_req.req_pool_idx, free_mamba_cache=False
                )
            else:
                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.last_batch.chunked_req is not None:
                # In pipeline parallelism, after the last chunk, the current
                # microbatch may still track an outdated chunked_req.
                # We need to discard it.
                chunked_req_to_exclude.add(self.last_batch.chunked_req)

            # Filter batch
            last_bs = self.last_batch.batch_size()
            self.last_batch.filter_batch(
                chunked_req_to_exclude=list(chunked_req_to_exclude)
            )
            if self.last_batch.batch_size() < last_bs:
                self.running_batch.batch_is_full = False

            # Merge the new batch into the running batch.
            # A prefill-only batch can skip the decoding step entirely.
            if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
                if self.running_batch.is_empty():
                    self.running_batch = self.last_batch
                else:
                    # Merge running_batch with prefill batch
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()

        need_dp_attn_preparation = require_mlp_sync(self.server_args)

        if need_dp_attn_preparation and not self.spec_algorithm.is_none():
            # In speculative decoding, prefill batches and decode batches cannot be
            # processed in the same DP attention group. We prepare idle batches in
            # advance to skip preparing decode batches when there are prefill
            # batches in the group.
            new_batch = self.prepare_mlp_sync_batch(new_batch)
            need_dp_attn_preparation = new_batch is None

        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
            if not self.running_batch.is_empty():
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None
            else:
                ret = None

        # Handle DP attention
        if need_dp_attn_preparation:
            ret = self.prepare_mlp_sync_batch(ret)

        return ret

    def get_num_allocatable_reqs(self, running_bs):
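        """Number of additional requests that can be scheduled, bounded by
        pp_max_micro_batch_size and, under pipeline parallelism, by the
        req_to_token_pool capacity."""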
        res = get_global_server_args().pp_max_micro_batch_size - running_bs
        if self.pp_size > 1:
            res = min(res, self.req_to_token_pool.available_size())
        return res

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
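        """Try to form a new prefill (extend) batch from the waiting queue,
        honoring memory limits, LoRA constraints, and chunked-prefill state."""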
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        if self.try_preemption:
            # Reset batch_is_full to try preemption with a prefill adder.
            self.running_batch.batch_is_full = False

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None,
        # num_allocatable_reqs should always be greater than 0, as the space for
        # the chunked request has just been released.
        # In the PP case, a chunked request can start in one microbatch and end in
        # another, so max_running_requests should not be strictly enforced per
        # microbatch. Instead, we should always allow a chunked request to be
        # added; otherwise, there will be a memory leak.
        if (
            self.get_num_allocatable_reqs(running_bs) <= 0
            and not self.chunked_req
            and not self.try_preemption
        ):
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            self.tree_cache.check_hicache_events()

        # Get priority queue
        self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
            self.priority_scheduling_preemption_threshold,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.enable_lora:
            lora_set = set([req.lora_id for req in self.running_batch.reqs])

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if self.enable_lora and not self.tp_worker.can_run_lora_batch(
                lora_set
                | set([req.lora_id for req in adder.can_run_list])
                | set([req.lora_id])
            ):
                self.running_batch.batch_is_full = True
                break

            running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, the prealloc queue and transfer queue can also
                # take memory, so we need to check against the actual available
                # size of the req_to_token_pool.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True

            if self.running_batch.batch_is_full:
                if not self.try_preemption:
                    break
                if not adder.preempt_to_schedule(req, self.server_args):
                    break

            if self.enable_hicache_storage:
                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
                if not prefetch_done:
                    # Skip staging requests whose prefetch is still in progress.
                    continue

            req.init_next_round_input(self.tree_cache)
            res = adder.add_one_req(
                req,
                has_chunked_req=(self.chunked_req is not None),
                truncation_align_size=self.truncation_align_size,
            )

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests
                        # that can be served.
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # Only record queue time when enable_metrics is True to avoid overhead.
            for req in can_run_list:
                req.add_latency(RequestStage.PREFILL_WAITING)

        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]
        if adder.preempt_list:
            for req in adder.preempt_list:
                self._add_request_to_queue(req)

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.current_scheduler_metrics_enabled():
            self.log_prefill_stats(adder, can_run_list, running_bs, 0)

        for req in can_run_list:
            if req.time_stats.forward_entry_time == 0:
                # Avoid updating the chunked request multiple times.
                req.time_stats.forward_entry_time = time.perf_counter()
                if self.enable_metrics:
                    self.metrics_collector.observe_queue_time(
                        req.time_stats.get_queueing_time(),
                    )

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            chunked_req=self.chunked_req,
        )
        if self.enable_hierarchical_cache:
            # TODO (zhiqiang): disable cuda graph execution if hicache loading triggered
            new_batch.hicache_consumer_index = (
                self.tree_cache.ready_to_load_host_cache()
            )

        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch."""
        initial_bs = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        # Check if decode is out of memory
        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
            TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
        ):
            old_ratio = self.new_token_ratio
            retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
                self.server_args
            )
            self.num_retracted_reqs = len(retracted_reqs)
            self.new_token_ratio = new_token_ratio
            for req in reqs_to_abort:
                self.send_to_tokenizer.send_output(
                    AbortReq(abort_reason=req.to_abort_message, rid=req.rid), req
                )

            logger.info(
                "KV cache pool is full. Retract requests. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
            )

            for req in retracted_reqs:
                self._add_request_to_queue(req, is_retracted=True)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if batch.batch_size() < initial_bs:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch

    # Placeholder for override
    def update_cache_from_scheduler(
        self, schedule_batch: ScheduleBatch, batch_result: GenerationBatchResult
    ):
        pass

    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch."""
        self.forward_ct += 1

        # Whether to run the profiler
        self._profile_batch_predicate(batch)
        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        # Run forward
        if self.is_generation:
            batch_or_worker_batch = batch

            if self.enable_overlap or self.spec_algorithm.is_none():
                # FIXME(lsyin): remove this if and finally unify the abstraction
                batch_or_worker_batch = batch.get_model_worker_batch()

            if self.enable_overlap:
                # FIXME: remove this assert
                assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
                model_worker_batch = batch_or_worker_batch
                self.record_batch_in_overlap(model_worker_batch)

                # Sampling info will be modified during forward
                model_worker_batch.sampling_info = (
                    model_worker_batch.sampling_info.copy_for_forward()
                )

                bs = len(model_worker_batch.seq_lens)
                future_indices = self.future_map.alloc_future_indices(bs)

                with self.forward_stream_ctx:
                    self.forward_stream.wait_stream(self.default_stream)
                    self.future_map.resolve_future(model_worker_batch)
                    batch_result = self.model_worker.forward_batch_generation(
                        model_worker_batch
                    )
                    # FIXME(lsyin): maybe move this to forward_batch_generation
                    batch_result.copy_done = torch.get_device_module(
                        self.device
                    ).Event()
                    if batch_result.delay_sample_func is None:
                        self.future_map.store_to_map(future_indices, batch_result)
                        batch_result.copy_to_cpu()
                    else:
                        batch_result.future_indices = future_indices

                # FIXME(lsyin): move this assignment elsewhere
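                # The negated future indices act as placeholder output ids:
                # negative values distinguish unresolved future-map slots from
                # real token ids until resolve_future fills them in.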
                future_indices_or_next_token_ids = -future_indices.indices

                if batch.is_v2_eagle:
                    # FIXME(lsyin): tmp code for eagle v2
                    # We only keep future indices for the next draft input.
                    batch.spec_info = batch_result.next_draft_input
                    batch.spec_info.future_indices = future_indices

                    # batch.spec_info = EagleDraftInput(
                    #     future_indices=future_indices,
                    #     verify_done=batch_result.next_draft_input.verify_done,
                    #     # FIXME(lsyin): remove the allocate_lens in EagleDraftInput
                    #     allocate_lens=batch_result.next_draft_input.allocate_lens,
                    # )

                    # The future value, usually for next-batch preparation.
                    # The current implementation strictly synchronizes the seq_lens.
                    batch.seq_lens = batch_result.next_draft_input.new_seq_lens
            else:
                batch_result = self.model_worker.forward_batch_generation(
                    batch_or_worker_batch
                )
                future_indices_or_next_token_ids = batch_result.next_token_ids
                self.update_cache_from_scheduler(batch, batch_result)

            # NOTE: future_indices_or_next_token_ids is used in ScheduleBatch,
            # which can probably be replaced by future_indices later [TODO(lsyin)].
            # We shall still keep the original outputs, e.g., next_token_ids,
            # in the GenerationBatchOutput for processing after copy_done.
            batch.output_ids = future_indices_or_next_token_ids

            # These two values are needed for processing the output, but they can
            # be modified by the overlap schedule, so we copy them here to ensure
            # the correct values are used in output processing.
            if batch.return_logprob or self.spec_algorithm.is_eagle():
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
            else:
                extend_input_len_per_req = None

            if batch.return_logprob:
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_logprob_start_len_per_req = None

            batch_result.extend_input_len_per_req = extend_input_len_per_req
            batch_result.extend_logprob_start_len_per_req = (
                extend_logprob_start_len_per_req
            )
            return batch_result
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(embeddings=embeddings)
            return ret

    def launch_batch_sample_if_needed(
        self, batch_result: GenerationBatchResult
    ) -> None:
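        """Run a deferred sampling step (if any) on the forward stream and
        publish the results to the future map."""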
        # TODO(lsyin): make the delayed sample a default behavior after
        # unifying the forward_batch_generation interface (related to spec V2).
        if batch_result is None or batch_result.delay_sample_func is None:
            return

        with self.forward_stream_ctx:
            self.forward_stream.wait_stream(self.default_stream)
            _batch_result = batch_result.delay_sample_func()
            assert _batch_result is batch_result
            self.future_map.store_to_map(batch_result.future_indices, batch_result)
            batch_result.copy_to_cpu()

    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if self.enable_trace:
                trace_slice_batch("decode loop", batch.reqs)

        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
            if self.enable_trace:
                trace_slice_batch("prefill", batch.reqs)

        elif batch.forward_mode.is_idle():
            if self.enable_overlap:
                if result.copy_done is not None:
                    result.copy_done.synchronize()

        self.maybe_send_health_check_signal()

    def maybe_send_health_check_signal(self):
        if self.return_health_check_ct:
            # Return a signal for the health check.
            # This is used to prevent the health check signal from being blocked
            # by long-context prefill.
            # However, one minor issue is that this code path does not check the
            # status of the detokenizer manager.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_output(HealthCheckOutput())

    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
        return self.prepare_mlp_sync_batch_raw(
            local_batch,
            dp_size=self.server_args.dp_size,
            attn_tp_size=self.attn_tp_size,
            tp_group=self.tp_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=self.server_args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
            require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
            disable_overlap_schedule=self.server_args.disable_overlap_schedule,
            offload_tags=self.offload_tags,
        )

    @staticmethod
    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        tp_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        require_mlp_tp_gather: bool,
        disable_overlap_schedule: bool,
        offload_tags: set[str],
    ):
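        """All-gather per-rank batch info (token counts, forward mode,
        cuda-graph eligibility) so that all DP attention ranks stay in sync;
        ranks without work receive an idle batch when any rank has tokens."""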
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            num_tokens_for_logprob = sum(
                [
                    # We should have at least 1 token to sample in every case.
                    max(extend_len - logprob_start_len, 1)
                    for logprob_start_len, extend_len in zip(
                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
                    )
                ]
            )

        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()
        if len(offload_tags) == 0 and disable_overlap_schedule:
            group = tp_group.device_group
            device = tp_group.device
        else:
            group = tp_group.cpu_group
            device = "cpu"

        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                ),
            ],
            dtype=torch.int64,
            device=device,
        )
        global_info = torch.empty(
            (dp_size, attn_tp_size, 6),
            dtype=torch.int64,
            device=device,
        )
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(),
            local_info,
            group=group,
        )
        global_num_tokens = global_info[:, 0, 0].tolist()
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
            if not require_mlp_tp_gather:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
            local_batch.is_extend_in_batch = any(is_extend_in_batch)
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

            # Check the forward mode for cuda graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch

    def get_idle_batch(self):
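        """Create an empty batch in idle forward mode, used to keep DP
        attention ranks in lockstep when this rank has no work."""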
        idle_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
        )
        idle_batch.prepare_for_idle()
        return idle_batch

    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""

        num_ready_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            try:
                if req.finished():  # It is aborted by AbortReq
                    num_ready_reqs += 1
                    continue

                req.grammar = req.grammar.result(timeout=0.03)
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
                    req.set_finish_with_abort(error_msg)

                num_ready_reqs += 1
            except futures._base.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the above line,
                # not the total time since the request entered the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
                    num_timeout_reqs = 1
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                req.grammar = req.grammar.result()
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
                    req.set_finish_with_abort(error_msg)
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)

        num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

        for req in self.grammar_queue[:num_ready_reqs]:
            self._add_request_to_queue(req)
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

    def watchdog_thread(self):
        """A watchdog thread that tries to kill the server if one forward batch takes too long."""
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            time.sleep(self.watchdog_timeout // 2)

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            if self.is_hybrid:
                (
                    _,
                    _,
                    _,
                    _,
                    full_available_size,
                    full_evictable_size,
                    swa_available_size,
                    swa_evictable_size,
                ) = self._get_swa_token_info()
                info_msg = (
                    f"{full_available_size=}, "
                    f"{full_evictable_size=}, "
                    f"{swa_available_size=}, "
                    f"{swa_evictable_size=}, "
                )
            else:
                _, _, available_size, evictable_size = self._get_token_info()
                info_msg = f"{available_size=}, {evictable_size=}, "
            logger.error(
                f"{self.cur_batch.batch_size()=}, "
                f"{self.cur_batch.reqs=}, "
                f"{info_msg}"
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        success = self.flush_cache()
        return FlushCacheReqOutput(success=success)

    def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput):
        if self.enable_hierarchical_cache:
            self.tree_cache.clear_storage_backend()
            logger.info("Hierarchical cache cleared successfully!")
            if_success = True
        else:
            logging.warning("Hierarchical cache is not enabled.")
            if_success = False
        return ClearHiCacheReqOutput(success=if_success)

    def flush_cache(self):
        """Flush the memory pool and cache."""
        if (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        ):
            self.cur_batch = None
            self.last_batch = None
            self.tree_cache.reset()
            if self.grammar_backend:
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool_allocator.clear()

            if self.draft_worker:
                self.draft_worker.clear_cache_pool()

            self.num_generated_tokens = 0
            self.forward_ct_decode = 0
            self.spec_num_total_accepted_tokens = 0
            self.spec_num_total_forward_ct = 0
            self.cum_spec_accept_length = 0
            self.cum_spec_accept_count = 0
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logging.warning(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

    def get_load(self, recv_req: Optional[GetLoadReqInput] = None) -> GetLoadReqOutput:
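        """Report this scheduler's load: request counts and an estimate of the
        tokens held in the KV cache plus tokens queued for prefill."""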
        # TODO(lsyin): use a dynamically maintained num_waiting_tokens

        if self.is_hybrid:
            num_tokens_full = (
                self.full_tokens_per_layer
                - self.token_to_kv_pool_allocator.full_available_size()
                - self.tree_cache.full_evictable_size()
            )
            num_tokens_swa = (
                self.swa_tokens_per_layer
                - self.token_to_kv_pool_allocator.swa_available_size()
                - self.tree_cache.swa_evictable_size()
            )
            num_tokens = max(num_tokens_full, num_tokens_swa)
        else:
            num_tokens = (
                self.max_total_num_tokens
                - self.token_to_kv_pool_allocator.available_size()
                - self.tree_cache.evictable_size()
            )

        # Tokens in the waiting queue, bootstrap queue, and prealloc queue
        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
        num_waiting_reqs = len(self.waiting_queue)
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            num_tokens += sum(
                len(req.origin_input_ids)
                for req in self.disagg_prefill_bootstrap_queue.queue
            )
            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            num_tokens += sum(
                len(req.req.origin_input_ids)
                for req in self.disagg_decode_prealloc_queue.queue
            )
            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)

        return GetLoadReqOutput(
            dp_rank=self.dp_rank,
            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
            num_waiting_reqs=num_waiting_reqs,
            num_tokens=num_tokens,
        )

    def get_internal_state(self, recv_req: GetInternalStateReq):
        ret = vars(get_global_server_args())
        ret["last_gen_throughput"] = self.last_gen_throughput
        ret["memory_usage"] = {
            "weight": round(self.tp_worker.model_runner.weight_load_mem_usage, 2),
            "kvcache": round(
                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
            ),
            "token_capacity": int(self.max_total_num_tokens),
        }

        ret["memory_usage"]["graph"] = round(
            self.tp_worker.model_runner.graph_mem_usage, 2
        )

        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
            ret["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            ret["step_time_dict"] = self.step_time_dict

        return GetInternalStateReqOutput(internal_state=ret)

    def set_internal_state(self, recv_req: SetInternalStateReq):
        server_args_dict = recv_req.server_args
        args_allow_update = set(
            [
                "pp_max_micro_batch_size",
                "speculative_accept_threshold_single",
                "speculative_accept_threshold_acc",
            ]
        )
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logging.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            elif k == "pp_max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logging.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid "
                    f"range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break
        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
                self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                setattr(get_global_server_args(), k, v)
            logger.info(f"Global server args updated! {get_global_server_args()=}")
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=vars(get_global_server_args()),
        )

    def handle_rpc_request(self, recv_req: RpcReqInput):
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        success = True
        exc = None
        try:
            func = getattr(self, recv_req.method)
            func(recv_req.parameters)
        except Exception as e:
            success = False
            exc = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        return RpcReqOutput(success, "" if not exc else str(exc))

    def abort_request(self, recv_req: AbortReq):
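        """Abort requests matching recv_req.rid (or all requests) wherever
        they currently live: the waiting/grammar queues, the disaggregation
        queues, or the running batch."""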
        # Delete requests in the waiting queue
        to_del = []
        for i, req in enumerate(self.waiting_queue):
            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                to_del.append(i)

        # Sort in reverse order to avoid index issues when deleting
        for i in reversed(to_del):
            # Abort method 1: directly pop from the queue.
            # This only works for requests that have not started anything.
            # We still need to send something back to TokenizerManager to clean up the state.
            req = self.waiting_queue.pop(i)
            if self.enable_hicache_storage:
                # Release prefetch events associated with the request
                self.tree_cache.release_aborted_request(req.rid)
            self.send_to_tokenizer.send_output(AbortReq(rid=req.rid), req)
            # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
            if self.disaggregation_mode == DisaggregationMode.DECODE:
                self.tree_cache.cache_finished_req(req)

            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete the requests in the grammar queue
        for req in self.grammar_queue:
            # Abort method 2: call `set_finish_with_abort`.
            # The request will still run one prefill forward pass.
            # In this case, we change the input_ids to a single token to make the prefill cheap.
            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                logger.debug(f"Abort grammar queue request. {req.rid=}")
                if req.grammar:
                    req.grammar.cancel()
                req.set_finish_with_abort("Aborted by AbortReq.")

        # Delete requests not in the waiting queue when PD disaggregation is enabled
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Abort requests that have not yet been bootstrapped
            for req in self.disagg_prefill_bootstrap_queue.queue:
                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

            # Abort in-flight requests
            for req in self.disagg_prefill_inflight_queue:
                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # Abort requests that have not yet finished preallocation
            for decode_req in self.disagg_decode_prealloc_queue.queue:
                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

            # Abort requests waiting for kvcache to release the tree cache
            for decode_req in self.disagg_decode_transfer_queue.queue:
                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

        # Delete requests in the running batch
        if self.cur_batch is self.running_batch or self.cur_batch is None:
            reqs = self.running_batch.reqs
        else:
            reqs = self.running_batch.reqs + self.cur_batch.reqs

        for req in reqs:
            if not req.finished() and (
                recv_req.abort_all or req.rid.startswith(recv_req.rid)
            ):
                # Abort method 3: set `to_abort=True`.
                # The request will still run one decode forward pass.
                # Then we reuse all existing code to clean up the KV cache allocation.
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True

    def _pause_engine(self) -> Tuple[List[Req], int]:
        raise NotImplementedError()

    def load_lora_adapter(
        self, recv_req: LoadLoRAAdapterReqInput
    ) -> LoadLoRAAdapterReqOutput:
        """Load a new LoRA adapter in place, from disk or Hugging Face."""
        result = self.tp_worker.load_lora_adapter(recv_req)
        return result

    def unload_lora_adapter(
        self, recv_req: UnloadLoRAAdapterReqInput
    ) -> UnloadLoRAAdapterReqOutput:
        """Unload the lora adapter."""
        result = self.tp_worker.unload_lora_adapter(recv_req)
        return result

    def init_weights_send_group_for_remote_instance(
        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
    ):
        """Init the seed and client instance communication group."""
        success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
            recv_req
        )
        return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)

    def send_weights_to_remote_instance(
        self, recv_req: SendWeightsToRemoteInstanceReqInput
    ):
        """Send the seed instance weights to the destination instance."""
        success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
        return SendWeightsToRemoteInstanceReqOutput(success, message)

    def slow_down(self, recv_req: SlowDownReqInput):
        t = recv_req.forward_sleep_time
        if t is not None and t <= 0:
            t = None
        self.forward_sleep_time = t
        return SlowDownReqOutput()

    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        action = recv_req.action
        if action == ExpertDistributionReqType.START_RECORD:
            get_global_expert_distribution_recorder().start_record()
        elif action == ExpertDistributionReqType.STOP_RECORD:
            get_global_expert_distribution_recorder().stop_record()
        elif action == ExpertDistributionReqType.DUMP_RECORD:
            get_global_expert_distribution_recorder().dump_record()
        else:
            raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
        return ExpertDistributionReqOutput()

    def open_session(self, recv_req: OpenSessionReqInput):
        # Handle errors
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exists, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        elif session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        else:
            self.sessions[session_id] = Session(
                recv_req.capacity_of_str_len, session_id
            )
            return OpenSessionReqOutput(session_id, True)

    def close_session(self, recv_req: CloseSessionReqInput):
        # Handle errors
        session_id = recv_req.session_id
        if session_id not in self.sessions:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")
        else:
            del self.sessions[session_id]

    def get_print_prefix(self):
        prefix = ""
        if self.attn_dp_rank is not None:
            prefix += f" DP{self.attn_dp_rank}"
        if self.server_args.tp_size > 1:
            prefix += f" TP{self.tp_rank}"
        if self.pp_size > 1:
            prefix += f" PP{self.pp_rank}"
        return prefix

    def current_scheduler_metrics_enabled(self):
        return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers

    def maybe_sleep_on_idle(self):
        if self.idle_sleeper is not None:
            self.idle_sleeper.maybe_sleep()

    def handle_freeze_gc(self, recv_req: FreezeGCReq):
        """Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
        freeze_gc("Scheduler")
        self.send_to_detokenizer.send_output(recv_req, recv_req)
        return None


class IdleSleeper:
    """
    In setups with long inactivity periods, it is desirable to reduce system
    power consumption when sglang does nothing. This leads not only to power
    savings but also to more CPU thermal headroom when a request eventually
    comes. This is important when multiple GPUs are connected, as each GPU
    would otherwise pin one thread at 100% CPU usage.

    The simplest solution is to use zmq.Poller on all sockets that may receive
    data that needs handling immediately.
    """

    def __init__(self, sockets):
        self.poller = zmq.Poller()
        self.last_empty_time = time.time()
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)

        self.empty_cache_interval = envs.SGLANG_EMPTY_CACHE_INTERVAL.get()

    def maybe_sleep(self):
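        # Block for up to one second waiting for any registered socket to
        # become readable, instead of spinning.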
        self.poller.poll(1000)
        if (
            self.empty_cache_interval > 0
            and time.time() - self.last_empty_time > self.empty_cache_interval
        ):
            self.last_empty_time = time.time()
            torch.cuda.empty_cache()


def is_health_check_generate_req(recv_req):
    rid = getattr(recv_req, "rid", None)
    return rid is not None and rid.startswith("HEALTH_CHECK")


def is_work_request(recv_req):
    return isinstance(
        recv_req,
        (
            TokenizedGenerateReqInput,
            TokenizedEmbeddingReqInput,
            BatchTokenizedGenerateReqInput,
            BatchTokenizedEmbeddingReqInput,
        ),
    )


def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    moe_ep_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
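    """Entry point of a scheduler subprocess: configure logging, CPU affinity,
    and tracing, construct the Scheduler, report readiness through
    pipe_writer, and run the event loop matching the server configuration."""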
    # Generate the logger prefix
    prefix = ""
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        # [For Router] if the env var "SGLANG_DP_RANK" exists, use it as dp_rank
        dp_rank = int(os.environ["SGLANG_DP_RANK"])
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.ep_size > 1:
        prefix += f" EP{moe_ep_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    kill_itself_when_parent_died()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set CPU affinity for this GPU process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(
            server_args.pp_size, server_args.tp_size, server_args.nnodes, gpu_id
        )
    if (numa_node := server_args.numa_node) is not None:
        numa_bind_to_node(numa_node[gpu_id])

    # Set up tracing
    if server_args.enable_trace:
        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
        if server_args.disaggregation_mode == "null":
            thread_label = "Scheduler"
            trace_set_thread_info(thread_label, tp_rank, dp_rank)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(
            server_args,
            port_args,
            gpu_id,
            tp_rank,
            moe_ep_rank,
            pp_rank,
            dp_rank,
        )
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )

        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                if server_args.pp_size > 1:
                    scheduler.event_loop_pp_disagg_prefill()
                else:
                    scheduler.event_loop_normal_disagg_prefill()

        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)