# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

import datetime
import gc
import inspect
import json
import logging
import os
import socket
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from urllib.parse import urlparse

import requests
import torch
import torch.distributed as dist

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
from sglang.srt.connector import ConnectorType
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.distributed import (
    get_pp_group,
    get_tp_group,
    get_world_group,
    init_distributed_environment,
    initialize_model_parallel,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.eplb.eplb_manager import EPLBManager
from sglang.srt.eplb.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
    set_global_expert_distribution_recorder,
)
from sglang.srt.eplb.expert_location import (
    ExpertLocationMetadata,
    compute_initial_expert_location_metadata,
    get_global_expert_location_metadata,
    set_global_expert_location_metadata,
)
from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
from sglang.srt.layers.attention.attention_registry import ATTENTION_BACKENDS
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    get_attention_tp_size,
    initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization import (
    deep_gemm_wrapper,
    monkey_patch_isinstance_for_vllm_base_layer,
)
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import (
    GLOBAL_SERVER_ARGS_KEYS,
    global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    PagedTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import (
    AscendMLAPagedTokenToKVPool,
    AscendTokenToKVPool,
    DoubleSparseTokenToKVPool,
    HybridLinearKVPool,
    HybridReqToTokenPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
    SWAKVPool,
)
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
    trigger_init_weights_send_group_for_remote_instance_request,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.offloader import (
    create_offloader_from_server_args,
    get_offloader,
    set_offloader,
)
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
    MultiprocessingSerializer,
    cpu_has_amx_support,
    dynamic_import,
    enable_show_time_cost,
    get_available_gpu_memory,
    get_bool_env_var,
    get_cpu_ids_by_node,
    init_custom_process_group,
    is_blackwell,
    is_fa3_default_architecture,
    is_flashinfer_available,
    is_hip,
    is_hopper_with_cuda_12_3,
    is_no_spec_infer_or_topk_one,
    is_npu,
    is_sm100_supported,
    log_info_on_rank0,
    monkey_patch_p2p_access_check,
    monkey_patch_vllm_gguf_config,
    parse_connector_type,
    set_cuda_arch,
)
from sglang.srt.weight_sync.tensor_bucket import (
    FlattenedTensorBucket,
    FlattenedTensorMetadata,
)

MLA_ATTENTION_BACKENDS = [
    "aiter",
    "flashinfer",
    "fa3",
    "fa4",
    "triton",
    "flashmla",
    "cutlass_mla",
    "trtllm_mla",
    "ascend",
]


def add_mla_attention_backend(backend_name):
    if backend_name not in MLA_ATTENTION_BACKENDS:
        MLA_ATTENTION_BACKENDS.append(backend_name)
        logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")
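
# Illustrative registration call (the backend name below is hypothetical):
# out-of-tree integrations can extend the allowlist at import time, e.g.
#   add_mla_attention_backend("my_custom_mla_backend")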

_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()

# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

# Detect straggler ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)


if _is_npu:
    import torch_npu

    torch.npu.config.allow_internal_format = True
    torch_npu.npu.set_compile_mode(jit_compile=False)


class RankZeroFilter(logging.Filter):
    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

    def __init__(self, is_rank_zero):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        if record.levelno == logging.INFO:
            return self.is_rank_zero
        return True
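
# Example attachment (this mirrors what ModelRunner.__init__ does below):
#   logger.addFilter(RankZeroFilter(tp_rank == 0))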

class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        moe_ep_rank: int,
        moe_ep_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        dp_rank: Optional[int] = None,
        is_draft_worker: bool = False,
        req_to_token_pool: Optional[ReqToTokenPool] = None,
        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.moe_ep_rank = moe_ep_rank
        self.moe_ep_size = moe_ep_size
        self.dp_size = server_args.dp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.model_config = model_config
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker
        self.is_generation = model_config.is_generation
        self.is_multimodal = model_config.is_multimodal
        self.is_multimodal_chunked_prefill_supported = (
            model_config.is_multimodal_chunked_prefill_supported
        )
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.page_size = server_args.page_size
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.is_hybrid = model_config.is_hybrid
        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
        self.attention_chunk_size = model_config.attention_chunk_size
        self.forward_pass_id = 0

        # Apply the rank zero filter to the logger
        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
            logger.addFilter(RankZeroFilter(tp_rank == 0))
        if server_args.show_time_cost:
            enable_show_time_cost()

        # Model-specific adjustment
        self.model_specific_adjustment()

        # Global vars
        global_server_args_dict.update(
            {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
            | {
                # TODO: this is not actually a server arg; it is derived state.
                "use_mla_backend": self.use_mla_backend,
                "speculative_algorithm": self.spec_algorithm,
            }
        )

        # Init OpenMP threads binding for CPU
        if self.device == "cpu":
            self.init_threads_binding()

        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()

        # CPU offload
        set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))

        # Update the DeepGEMM configuration
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

        # Initialize the model runner
        self.initialize(min_per_gpu_memory)

        # Temporary cached values
        self.support_pp = (
            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
        )

        # For weight updates
        self._model_update_group = {}
        self._weights_send_group = {}

    def initialize(self, min_per_gpu_memory: float):
        server_args = self.server_args

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=self.server_args.enable_memory_saver
        )

        if not self.is_draft_worker:
            set_global_expert_location_metadata(
                compute_initial_expert_location_metadata(server_args, self.model_config)
            )
            if self.tp_rank == 0 and get_bool_env_var(
                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
            ):
                logger.info(
                    f"Initial expert_location_metadata: {get_global_expert_location_metadata()}"
                )

            set_global_expert_distribution_recorder(
                ExpertDistributionRecorder.init_new(
                    server_args,
                    get_global_expert_location_metadata(),
                    rank=self.tp_rank,
                )
            )

        # Expert parallelism
        self.eplb_manager = (
            EPLBManager(self)
            if self.server_args.enable_eplb and (not self.is_draft_worker)
            else None
        )
        self.expert_location_updater = ExpertLocationUpdater()

        # Load the model
        self.sampler = Sampler()
        self.load_model()

        # Check if the model is using hybrid SWA
        if (
            not self.server_args.disable_hybrid_swa_memory
            and self.sliding_window_size is not None
            and self.sliding_window_size > 0
        ):
            architectures = self.model_config.hf_config.architectures
            if architectures and not any("Llama4" in arch for arch in architectures):
                self.is_hybrid = self.model_config.is_hybrid = True

        if self.is_hybrid_gdn:
            logger.warning("Hybrid GDN model detected; disabling the radix cache.")
            self.server_args.disable_radix_cache = True
            self.server_args.attention_backend = "hybrid_linear_attn"
            if self.server_args.max_mamba_cache_size is None:
                if self.server_args.max_running_requests is not None:
                    self.server_args.max_mamba_cache_size = (
                        self.server_args.max_running_requests
                    )
                else:
                    self.server_args.max_mamba_cache_size = 512
            self.server_args.max_mamba_cache_size = (
                self.server_args.max_mamba_cache_size
                // (
                    self.server_args.dp_size
                    if self.server_args.enable_dp_attention
                    else 1
                )
            )

        # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
        # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
        # determine the number of layers.
        model_has_mtp_layers = self.model_config.num_nextn_predict_layers is not None
        model_num_layers = (
            self.model_config.num_nextn_predict_layers
            if self.is_draft_worker and model_has_mtp_layers
            else max(
                self.model_config.num_hidden_layers,
                self.model_config.num_attention_layers,
            )
        )
        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(self.model, "end_layer", model_num_layers)
        self.num_effective_layers = self.end_layer - self.start_layer
        assert (
            (not model_has_mtp_layers)
            or (self.spec_algorithm.is_none())
            or (
                (not self.spec_algorithm.is_none())
                and (self.num_effective_layers == model_num_layers)
            )
        ), "PP is not compatible with MTP models."

        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have been applied already.
        if not torchao_applied:
            apply_torchao_config_to_model(
                self.model, global_server_args_dict["torchao_config"]
            )

        # Apply torch TP if the model supports it
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()

        # Init LoRA
        if server_args.enable_lora:
            self.init_lora_manager()

        # Init Double Sparsity
        if server_args.enable_double_sparsity:
            if server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

        # Enable batch-invariant mode
        if server_args.enable_deterministic_inference:
            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode

            enable_batch_invariant_mode()

        # Init memory pool and attention backends
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_device_graphs()
        elif self.device in ["npu", "cpu"]:
            self.init_attention_backend()
            self.init_device_graphs()
        else:
            self.graph_runner = None
            self.graph_mem_usage = 0
            self.init_attention_backend()

        # Auxiliary hidden-state capture mode. TODO: expose this to server args?
        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
            # Load the draft model config
            draft_model_config = ModelConfig.from_server_args(
                server_args,
                model_path=(server_args.speculative_draft_model_path),
                is_draft_model=True,
            )

            try:
                # Get the aux layer ids from the draft model config.
                # eagle_config may be None or missing the key.
                eagle_config = getattr(
                    draft_model_config.hf_config, "eagle_config", None
                )
                eagle_aux_hidden_state_layer_ids = eagle_config[
                    "eagle_aux_hidden_state_layer_ids"
                ]
            except (TypeError, KeyError):
                # If there is no aux layer config, fall back to None.
                eagle_aux_hidden_state_layer_ids = None

            self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)

    def model_specific_adjustment(self):
        server_args = self.server_args

        if (
            server_args.attention_backend == "intel_amx"
            and server_args.device == "cpu"
            and not _is_cpu_amx_available
        ):
            logger.info(
                "The current platform does not support Intel AMX; falling back to the torch_native backend."
            )
            server_args.attention_backend = "torch_native"

        if server_args.prefill_attention_backend is not None and (
            server_args.prefill_attention_backend
            == server_args.decode_attention_backend
        ):  # override the default attention backend
            server_args.attention_backend = server_args.prefill_attention_backend

        if (
            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
            is not None
        ):
            if server_args.attention_backend is None:
                server_args.attention_backend = "dual_chunk_flash_attn"
                logger.info("Dual chunk attention is turned on by default.")
            elif server_args.attention_backend != "dual_chunk_flash_attn":
                raise ValueError(
                    "Dual chunk attention is enabled, but attention backend is set to "
                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
                )

        if server_args.attention_backend is None:
            """
            Auto-select the fastest attention backend.

            1. Models with MHA architecture (e.g., Llama, Qwen)
                1.1 We turn on FA3 on Hopper unless the user uses speculative decoding with topk > 1 or page_size > 1.
                1.2 In other cases, we use flashinfer if available, otherwise triton.
            2. Models with MLA architecture
                2.1 We use the FA3 backend on Hopper.
                2.2 We use the flashinfer backend on Blackwell.
                2.3 Otherwise, we use the triton backend.
            """

            if not self.use_mla_backend:
                # MHA architecture
                if (
                    is_hopper_with_cuda_12_3()
                    and is_no_spec_infer_or_topk_one(server_args)
                    and is_fa3_default_architecture(self.model_config.hf_config)
                ):
                    server_args.attention_backend = "fa3"
                elif _is_hip:
                    server_args.attention_backend = "aiter"
                elif _is_npu:
                    server_args.attention_backend = "ascend"
                else:
                    server_args.attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
            else:
                # MLA architecture
                if is_hopper_with_cuda_12_3():
                    server_args.attention_backend = "fa3"
                elif is_sm100_supported():
                    server_args.attention_backend = "flashinfer"
                elif _is_hip:
                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
                    # TODO: aiter currently only supports head numbers 16 and 128.
                    if (
                        head_num == 128 or head_num == 16
                    ) and self.spec_algorithm.is_none():
                        server_args.attention_backend = "aiter"
                    else:
                        server_args.attention_backend = "triton"
                elif _is_npu:
                    server_args.attention_backend = "ascend"
                else:
                    server_args.attention_backend = "triton"
            logger.info(
                f"Attention backend not explicitly specified. Using the {server_args.attention_backend} backend by default."
            )
        elif self.use_mla_backend:
            if server_args.device != "cpu":
                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
                    logger.info(
                        f"MLA optimization is turned on. Using the {server_args.attention_backend} backend."
                    )
                else:
                    raise ValueError(
                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
                    )
            else:
                if server_args.attention_backend != "intel_amx":
                    raise ValueError(
                        "MLA optimization is not supported on CPU except for the intel_amx backend."
                    )

        if (
            server_args.attention_backend == "fa3"
            and server_args.kv_cache_dtype == "fp8_e5m2"
        ):
            logger.warning(
                "FlashAttention3 only supports fp8_e4m3 when using FP8; "
                "setting the attention backend to triton."
            )
            server_args.attention_backend = "triton"

        if server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Using the triton backend without CUDA graph."
            )
            server_args.attention_backend = "triton"
            server_args.disable_cuda_graph = True

        if self.is_multimodal:
            if not self.is_multimodal_chunked_prefill_supported:
                server_args.chunked_prefill_size = -1
                logger.info(
                    f"Automatically turning off --chunked-prefill-size as it is not supported for "
                    f"{self.model_config.hf_config.model_type}"
                )

        if not self.use_mla_backend:
            server_args.disable_chunked_prefix_cache = True

        if not server_args.disable_chunked_prefix_cache:
            logger.info("Chunked prefix cache is turned on.")

        if server_args.attention_backend == "aiter":
            if self.model_config.context_len > 8192:
                self.mem_fraction_static *= 0.85

        if (
            server_args.enable_hierarchical_cache
            and server_args.hicache_io_backend == "kernel"
        ):
            # Fix the compatibility issue between FlashAttention3 decoding and the HiCache kernel backend.
            if server_args.decode_attention_backend is None:
                if not self.use_mla_backend:
                    server_args.decode_attention_backend = (
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
                else:
                    server_args.decode_attention_backend = (
                        "flashinfer" if is_sm100_supported() else "triton"
                    )
            elif server_args.decode_attention_backend == "fa3":
                server_args.hicache_io_backend = "direct"
                logger.warning(
                    "The FlashAttention3 decode backend is not compatible with hierarchical cache. "
                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
                )

    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")

        try:
            torch.get_device_module(self.device).set_device(self.gpu_id)
        except Exception:
            logger.warning(
                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
            )
            raise

        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
            backend = "xccl"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"
        elif self.device == "npu":
            backend = "hccl"
        else:
            # Fail early on unsupported devices instead of hitting an
            # UnboundLocalError when `backend` is used below.
            raise ValueError(f"Unsupported device for torch.distributed: {self.device}")

        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)

        if not self.is_draft_worker:
            if self.device == "cpu":
                if _is_cpu_amx_available:
                    # Bind OpenMP threads to CPU cores
                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)

                    # Set the local size to hint SGLang to use shared-memory-based AllReduce
                    os.environ["LOCAL_SIZE"] = str(self.tp_size)
                    torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)

                    @torch.library.register_fake("sgl_kernel::shm_allgather")
                    def _(data, dim):
                        return torch.cat([data] * self.tp_size, dim=dim)

                else:
                    logger.warning(
                        "init_cpu_threads_env and shared-memory-based AllReduce are disabled because the intel_amx backend is not available."
                    )

            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
                expert_model_parallel_size=self.moe_ep_size,
                duplicate_tp_group=self.server_args.enable_pdmux,
            )
            initialize_dp_attention(
                server_args=self.server_args,
                model_config=self.model_config,
            )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        self.tp_group = get_tp_group()
        self.pp_group = get_pp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if self.tp_size > 1 and not self.is_draft_worker:
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                    logger.warning(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )
                else:
                    raise ValueError(
                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                    )

        logger.info(
            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
        )
        return min_per_gpu_memory
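
    # Rank layout used by init_torch_distributed above:
    # world_size = tp_size * pp_size and rank = tp_size * pp_rank + tp_rank,
    # i.e. TP ranks are contiguous within each pipeline stage.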

    def load_model(self):
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Using float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
            model_loader_extra_config=self.server_args.model_loader_extra_config,
            tp_rank=self.tp_rank,
            remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
            remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
            remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
        )
        if self.device == "cpu":
            self.model_config = adjust_config_with_unaligned_cpu_tp(
                self.model_config, self.load_config, self.tp_size
            )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
            if self.tp_rank == 0:
                instance_ip = socket.gethostbyname(socket.gethostname())
                t = threading.Thread(
                    target=trigger_init_weights_send_group_for_remote_instance_request,
                    args=(
                        self.server_args.remote_instance_weight_loader_seed_instance_ip,
                        self.server_args.remote_instance_weight_loader_seed_instance_service_port,
                        self.server_args.remote_instance_weight_loader_send_weights_group_ports,
                        instance_ip,
                    ),
                )
                t.start()

        # Load the model
        # Remove the monkey patches when linear.py's quantization no longer depends on vllm.
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
            self.model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device, self.gpu_id),
            )
        monkey_patch_vllm_parallel_state(reverse=True)
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        get_offloader().post_init()

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    raise RuntimeError(
                        "Using FP8 KV cache with scaling factors provided, but "
                        f"model {self.model.__class__} does not support loading scaling factors."
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = None
        if hasattr(self.model, "get_attention_sliding_window_size"):
            self.sliding_window_size = self.model.get_attention_sliding_window_size()
        elif self.model_config.attention_chunk_size is not None:
            self.sliding_window_size = self.model_config.attention_chunk_size
            logger.info(
                f"Setting sliding_window_size to attention_chunk_size: {self.sliding_window_size}"
            )

        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={self.weight_load_mem_usage:.2f} GB."
        )

        # Handle the case where some ranks do not finish loading.
        try:
            dist.monitored_barrier(
                group=get_tp_group().cpu_group,
                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
                wait_all_ranks=True,
            )
        except RuntimeError:
            raise ValueError(
                f"TP rank {self.tp_rank} finished model loading, but other ranks did not finish. "
                "This is likely due to unexpected failures (e.g., OOM) or a slow node."
            ) from None

    def update_expert_location(
        self,
        new_expert_location_metadata: ExpertLocationMetadata,
        update_layer_ids: List[int],
    ):
        self.expert_location_updater.update(
            self.model.routed_experts_weights_of_layer,
            new_expert_location_metadata,
            update_layer_ids=update_layer_ids,
            nnodes=self.server_args.nnodes,
            rank=self.tp_rank,
        )

    def update_weights_from_disk(
        self, model_path: str, load_format: str
    ) -> tuple[bool, str]:
        """Update engine weights in place from the disk."""
        logger.info(
            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)
        self.model_config.model_path = model_path
        load_config = LoadConfig(load_format=load_format)

        # Only DefaultModelLoader is supported for now.
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            # Named to avoid shadowing the builtin `iter`.
            weights_iter = loader._get_weights_iterator(
                DefaultModelLoader.Source.init_new(config, self.model)
            )
            return weights_iter

        def model_load_weights(model, weights_iter):
            DefaultModelLoader.load_weights_and_postprocess(
                model, weights_iter, target_device
            )
            return model

        with set_default_torch_dtype(self.model_config.dtype):
            try:
                weights_iter = get_weight_iter(self.model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, weights_iter)
            except Exception as e:
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del weights_iter
                gc.collect()
                weights_iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, weights_iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.load_config = load_config

        logger.info("Update weights end.")
        return True, "Successfully updated model weights."

    def init_weights_send_group_for_remote_instance(
        self,
        master_address,
        ports,
        group_rank,
        world_size,
        group_name,
        backend="nccl",
    ):
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        ports_list = ports.split(",")
        assert (
            len(ports_list) == self.tp_size
        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
        group_port = ports_list[self.tp_rank]
        group_name = f"{group_name}_{group_port}_{self.tp_rank}"

        logger.info(
            f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, "
            f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        torch.cuda.empty_cache()
        success = False
        message = ""
        try:
            self._weights_send_group[group_name] = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{group_port}",
                world_size=world_size,
                rank=group_rank,
                group_name=group_name,
                device_id=torch.device("cuda", self.gpu_id),
            )
            dist.barrier(group=self._weights_send_group[group_name])
            success = True
            message = (
                f"Successfully initialized send group through {master_address}:{group_port}."
            )
        except Exception as e:
            message = f"Failed to init group: {e}."
            logger.error(message)

        torch.cuda.empty_cache()
        return success, message

    def send_weights_to_remote_instance(
        self,
        master_address,
        ports,
        group_name,
    ):
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        ports_list = ports.split(",")
        assert (
            len(ports_list) == self.tp_size
        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
        group_port = ports_list[self.tp_rank]
        group_name = f"{group_name}_{group_port}_{self.tp_rank}"

        # Use .get() so a missing group returns an error message instead of raising KeyError.
        if self._weights_send_group.get(group_name) is not None:
            send_group = self._weights_send_group[group_name]
        else:
            message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first."
            logger.error(message)
            return False, message

        torch.cuda.empty_cache()
        success = False
        message = ""
        try:
            for _, weights in self.model.named_parameters():
                torch.distributed.broadcast(
                    weights,
                    src=0,
                    group=send_group,
                )
            success = True
            message = f"Successfully sent weights through {master_address}:{group_port} {group_name}."
        except Exception as e:
            message = f"Failed to send weights: {e}."
            logger.error(message)

        # Destroy the process group after sending the weights.
        del self._weights_send_group[group_name]
        torch.distributed.distributed_c10d.destroy_process_group(send_group)
        torch.cuda.empty_cache()
        return success, message

    def init_weights_update_group(
        self,
        master_address,
        master_port,
        rank_offset,
        world_size,
        group_name,
        backend="nccl",
    ):
        """Initialize the Torch process group for model parameter updates.

        `_model_update_group` is used in the RLHF workflow, where rank
        0 is the actor model in the training engine, and the other ranks are
        the inference engine, which is used for rollout.

        In the RLHF workflow, the training engine updates the model
        weights/parameters online, and broadcasts them to the inference
        engine through the `_model_update_group` process group.
        """
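        # Hedged usage sketch (values are illustrative): the training engine
        # creates the same group as rank 0, while each inference TP rank joins
        # here with rank = rank_offset + tp_rank; afterwards the trainer
        # broadcasts updated parameters, which are received in
        # `update_weights_from_distributed`.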
        assert (
            torch.distributed.is_initialized()
        ), "Default torch process group must be initialized"
        assert group_name != "", "Group name cannot be empty"

        rank = rank_offset + self.tp_rank

        logger.info(
            f"init custom process group: master_address={master_address}, master_port={master_port}, "
            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
        )

        try:
            self._model_update_group[group_name] = init_custom_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=rank,
                group_name=group_name,
            )
            return True, "Successfully initialized custom process group."
        except Exception as e:
            message = f"Failed to initialize custom process group: {e}."
            logger.error(message)
            return False, message

    def destroy_weights_update_group(self, group_name):
        try:
            if group_name in self._model_update_group:
                pg = self._model_update_group.pop(group_name)
                torch.distributed.destroy_process_group(pg)
                return True, "Successfully destroyed custom process group."
            else:
                return False, "The group to be destroyed does not exist."
        except Exception as e:
            message = f"Failed to destroy custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
        """
        Update specific parameters in the model weights online
        through the `_model_update_group` process group.

        Args:
            names: the names of the parameters to be updated.
            dtypes: the data types of the parameters to be updated.
            shapes: the shapes of the parameters to be updated.
            group_name: the name of the process group created by
                `init_weights_update_group`.
        """

        assert group_name in self._model_update_group, (
            f"Group {group_name} not in {list(self._model_update_group.keys())}. "
            "Please call `init_weights_update_group` first."
        )

        try:
            weights = []
            handles = []
            for name, dtype, shape in zip(names, dtypes, shapes):
                target_dtype = (
                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
                )
                weight = torch.empty(shape, dtype=target_dtype, device=self.device)
                handles.append(
                    torch.distributed.broadcast(
                        weight,
                        src=0,
                        group=self._model_update_group[group_name],
                        async_op=True,
                    )
                )
                weights.append((name, weight))
            for handle in handles:
                handle.wait()

            self.model.load_weights(weights)
            return True, "Successfully updated parameters online."

        except Exception as e:
            error_msg = (
                f"Failed to update parameters online: {e}. "
                "The weights of the ModelRunner may be partially updated. "
                "Please discard the whole set of weights."
            )
            logger.error(error_msg)
            return False, error_msg

    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
        load_format: Optional[str] = None,
    ):
        monkey_patch_torch_reductions()
        if load_format == "flattened_bucket":
            # Handle the flattened bucket format.
            return self._update_weights_from_flattened_bucket(
                flattened_tensor_bucket_dict=named_tensors
            )

        # We need to get the device after the patch; otherwise the device would be wrong.
        self.device_module = torch.get_device_module(self.device)
        inferred_device = self.device_module.current_device()

        named_tensors = [
            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=inferred_device))
            for name, tensor in named_tensors
        ]
        if load_format == "direct":
            _model_load_weights_direct(self.model, named_tensors)
        elif load_format in self.server_args.custom_weight_loader:
            custom_loader = dynamic_import(load_format)
            custom_loader(self.model, named_tensors)
        elif load_format is None:
            self.model.load_weights(named_tensors)
        else:
            raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"
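
    # Hedged usage sketch for update_weights_from_tensor (tensor names are
    # illustrative): callers either pass materialized tensors directly, e.g.
    #   runner.update_weights_from_tensor([("lm_head.weight", w)], load_format=None)
    # or pass LocalSerializedTensor handles, which _unwrap_tensor resolves to
    # this TP rank's shard on the current device.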

    def _update_weights_from_flattened_bucket(
        self,
        flattened_tensor_bucket_dict,
    ):
        """Handle the flattened bucket format for weight updates."""
        flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
        metadata = flattened_tensor_bucket_dict["metadata"]

        # Convert the metadata entries to our format.
        converted_metadata = []
        for meta in metadata:
            converted_meta = FlattenedTensorMetadata(
                name=meta.name,
                shape=meta.shape,
                dtype=meta.dtype,
                start_idx=meta.start_idx,
                end_idx=meta.end_idx,
                numel=meta.numel,
            )
            converted_metadata.append(converted_meta)

        # Create the bucket and reconstruct the tensors.
        bucket = FlattenedTensorBucket(
            flattened_tensor=flattened_tensor, metadata=converted_metadata
        )
        reconstructed_tensors = bucket.reconstruct_tensors()

        # Load the reconstructed tensors using the standard method.
        self.model.load_weights(reconstructed_tensors)

        return True, "Success"
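
    # Illustrative payload for the flattened-bucket path above (the exact
    # sender-side format is an assumption): {"flattened_tensor": <1-D tensor>,
    # "metadata": [...]}, where each entry carries name/shape/dtype/start_idx/
    # end_idx/numel so every weight can be sliced back out of the flat buffer.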

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit tests, with unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        # TODO: (chenyang) Add support for Qwen models.
        try:
            return self.model.get_weights_by_name(
                name, truncate_size, tp_size=self.tp_size
            )
        except Exception as e:
            logger.error(f"Error when getting parameter {name}: {e}")
            return None

    def init_lora_manager(self):
        self.lora_manager = LoRAManager(
            base_model=self.model,
            base_hf_config=self.model_config.hf_config,
            max_loras_per_batch=self.server_args.max_loras_per_batch,
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=self.server_args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            max_lora_rank=self.server_args.max_lora_rank,
            target_modules=self.server_args.lora_target_modules,
            lora_paths=self.server_args.lora_paths,
            server_args=self.server_args,
        )

    def load_lora_adapter(self, lora_ref: LoRARef):
        """Load a new LoRA adapter from disk or Hugging Face."""

        logger.info(
            f"LoRA adapter loading starts: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        result = self.lora_manager.load_lora_adapter(lora_ref)

        logger.info(
            f"LoRA adapter loading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        return result

    def unload_lora_adapter(self, lora_ref: LoRARef):
        """Unload a LoRA adapter that was previously loaded during initialization or dynamic loading."""

        logger.info(
            f"LoRA adapter unloading starts: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        result = self.lora_manager.unload_lora_adapter(lora_ref)

        logger.info(
            f"LoRA adapter unloading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        return result

    def profile_max_num_token(self, total_gpu_memory: int):
        available_gpu_memory = get_available_gpu_memory(
            self.device,
            self.gpu_id,
            distributed=get_world_group().world_size > 1,
            cpu_group=get_world_group().cpu_group,
        )
        if self.is_draft_worker:
            num_layers = getattr(
                self.model_config.hf_config,
                "num_nextn_predict_layers",
                self.num_effective_layers,
            )
        elif self.is_hybrid_gdn:
            num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
        else:
            num_layers = self.num_effective_layers
        if self.use_mla_backend:
            # MLA stores one compressed latent (plus rope dims) per token per layer.
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * num_layers
                * torch._utils._element_size(self.kv_cache_dtype)
            )
        else:
            # MHA stores both K and V (hence the factor of 2) per head per layer.
            cell_size = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * num_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )
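        # Worked example with illustrative (not model-specific) numbers: an MHA
        # model with 8 KV heads, head_dim 128, and 32 layers using an fp16 KV
        # cache needs 8 * 128 * 32 * 2 (K and V) * 2 bytes = 128 KiB per token.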
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        if self.is_hybrid_gdn:
            rest_memory -= (
                self.server_args.max_mamba_cache_size
                * self.model_config.hf_config.mamba_cache_per_req
                / (1 << 30)
            )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    @property
    def is_hybrid_gdn(self):
        return self.model_config.hf_config.architectures[0] in [
            "Qwen3NextForCausalLM",
            "Qwen3NextForCausalLMMTP",
        ]

    def set_num_token_hybrid(self):
        if (
            "Llama4ForConditionalGeneration"
            in self.model_config.hf_config.architectures
        ):
            temp_ratio = (
                (1 - self.is_hybrid)
                + self.is_hybrid
                * self.attention_chunk_size
                / self.model_config.context_len
            )
            self.swa_max_total_num_tokens = (
                4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
            self.full_max_total_num_tokens = (
                4 * self.max_total_num_tokens
                - 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
            )
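            # The 4x and 3:1 coefficients above appear to assume Llama4's
            # interleaving of three chunked-attention layers per full-attention
            # layer (an assumption noted here; the surrounding code does not
            # spell it out).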
            self.swa_max_total_num_tokens = int(
                self.swa_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.full_max_total_num_tokens = int(
                self.full_max_total_num_tokens
                // self.server_args.page_size
                * self.server_args.page_size
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens
        else:
            assert self.sliding_window_size is not None and self.sliding_window_size > 0
            full_attention_layer_ids = []
            swa_attention_layer_ids = []

            # Find the decoder layers; the attribute path differs across model
            # classes. Fall back to non-hybrid if none of the paths exists.
            layers = None
            for getter in (
                lambda m: m.model.layers,
                lambda m: m.language_model.model.layers,
                lambda m: m.language_model.layers,
            ):
                try:
                    layers = getter(self.model)
                    break
                except AttributeError:
                    continue
            if layers is None:
                self.is_hybrid = False
                return

            for layer in layers:
                if (
                    layer.self_attn.attn.sliding_window_size is None
                    or layer.self_attn.attn.sliding_window_size == -1
                ):
                    full_attention_layer_ids.append(layer.layer_id)
                else:
                    swa_attention_layer_ids.append(layer.layer_id)
            self.model_config.swa_attention_layer_ids = swa_attention_layer_ids
            self.model_config.full_attention_layer_ids = full_attention_layer_ids

            # Algorithm:
            # The existing max_total_num_tokens is per layer and assumes all layers hold the same number of tokens.
            # - Find the total # of tokens available across layers.
            # - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio.
            total_tokens = (
                self.max_total_num_tokens * self.model_config.num_hidden_layers
            )
            full_layers_num = len(full_attention_layer_ids)
            swa_layers_num = len(swa_attention_layer_ids)
            swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio

            # Solve the equations:
            # 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens
            # 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens
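            # Substituting (2) into (1) gives
            #   full * (swa_full_tokens_ratio * swa_layers_num + full_layers_num) == total_tokens,
            # which is solved by the two assignments below.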
            denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
            self.full_max_total_num_tokens = int(total_tokens / denominator)
            self.swa_max_total_num_tokens = int(
                self.full_max_total_num_tokens * swa_full_tokens_ratio
            )
            self.max_total_num_tokens = self.full_max_total_num_tokens

            logger.info(
                f"Using the sliding-window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
            )

    def init_memory_pool(
        self,
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ):
        # Determine the KV cache dtype
        if self.server_args.kv_cache_dtype == "auto":
            quant_config = getattr(self.model, "quant_config", None)
            kv_cache_quant_algo = getattr(quant_config, "kv_cache_quant_algo", None)
            if (
                isinstance(kv_cache_quant_algo, str)
                and kv_cache_quant_algo.upper() == "FP8"
            ):
                if _is_hip:
                    self.kv_cache_dtype = torch.float8_e4m3fnuz
                else:
                    self.kv_cache_dtype = torch.float8_e4m3fn
            else:
                self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
            if _is_hip:  # Using the natively supported format
                self.kv_cache_dtype = torch.float8_e5m2fnuz
            else:
                self.kv_cache_dtype = torch.float8_e5m2
        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
            if _is_hip:  # Using the natively supported format
                self.kv_cache_dtype = torch.float8_e4m3fnuz
            else:
                self.kv_cache_dtype = torch.float8_e4m3fn
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
            )

        log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}")
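
        # Summary of the mapping above (HIP picks the *fnuz variants because
        # those are its natively supported FP8 formats):
        #   auto     -> model dtype, or an e4m3 variant when the checkpoint
        #               quantizes the KV cache to FP8
        #   fp8_e5m2 -> torch.float8_e5m2 (torch.float8_e5m2fnuz on HIP)
        #   fp8_e4m3 -> torch.float8_e4m3fn (torch.float8_e4m3fnuz on HIP)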

        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
        if SGLANG_CI_SMALL_KV_SIZE:
            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)

        if max_num_reqs is None:
            max_num_reqs = min(
                max(
                    int(
                        self.max_total_num_tokens / self.model_config.context_len * 512
                    ),
                    2048,
                ),
                4096,
            )
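        # Illustrative arithmetic for the heuristic above (numbers are
        # examples, not defaults): with max_total_num_tokens=1_000_000 and
        # context_len=131_072, int(1_000_000 / 131_072 * 512) = 3906, which
        # already lies in [2048, 4096], so max_num_reqs = 3906.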

        if self.is_hybrid_gdn:
            max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)

        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                max_num_reqs = self.server_args.max_num_reqs
            else:
                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
                # can be allocated concurrently, so we should leave some headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
                    # draft
                    + max_num_reqs
                    * self.server_args.speculative_num_steps
                    * self.server_args.speculative_eagle_topk
                    # verify
                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
                    # buffer
                    + 100
                )
                # The target worker and the draft worker share the same indices for the
                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
                logger.warning(
                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                    f"{self.max_total_num_tokens}. "
                    f"Using the profiled value instead."
                )
            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

        self.max_total_num_tokens = (
            self.max_total_num_tokens
            // self.server_args.page_size
            * self.server_args.page_size
        )
        # Different pp ranks may have different numbers of layers, so reduce
        # max_total_num_tokens to the minimum across ranks.
        if self.pp_size > 1:
            tensor = torch.tensor(self.max_total_num_tokens, dtype=torch.int64)
            torch.distributed.all_reduce(
                tensor,
                op=torch.distributed.ReduceOp.MIN,
                group=get_world_group().cpu_group,
            )
            self.max_total_num_tokens = tensor.item()

        # Create the token sizes for the hybrid cache.
        if self.is_hybrid:
            self.set_num_token_hybrid()

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        # Initialize req_to_token_pool
        if self.req_to_token_pool is None:
            # FIXME(lsyin): this is a temporary fix for the context length issue when using speculative decoding
            extra_max_context_len = 4
            if self.server_args.speculative_num_draft_tokens is not None:
                extra_max_context_len += self.server_args.speculative_num_draft_tokens

            if self.server_args.disaggregation_mode == "decode":
                from sglang.srt.disaggregation.decode import DecodeReqToTokenPool

                # Reserve memory for pre-allocated requests:
                # if max_num_reqs <= 32, we pre-allocate 2x requests.
                pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
                self.req_to_token_pool = DecodeReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    pre_alloc_size=pre_alloc_size,
                )
            elif self.is_hybrid_gdn:
                config = self.model_config.hf_config
                (
                    conv_state_shape,
                    temporal_state_shape,
                    conv_dtype,
                    ssm_dtype,
                    mamba_layers,
                ) = config.hybrid_gdn_params
                self.req_to_token_pool = HybridReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    conv_state_shape=conv_state_shape,
                    temporal_state_shape=temporal_state_shape,
                    conv_dtype=conv_dtype,
                    ssm_dtype=ssm_dtype,
                    mamba_layers=mamba_layers,
                    speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
                )
            else:
                self.req_to_token_pool = ReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
        else:
            # The draft worker shares req_to_token_pool with the target worker.
            assert self.is_draft_worker

        # Initialize token_to_kv_pool
        if self.server_args.attention_backend == "ascend":
            if self.use_mla_backend:
                self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    kv_lora_rank=self.model_config.kv_lora_rank,
                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )
            else:
                self.token_to_kv_pool = AscendTokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.model_config.num_hidden_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                )
        elif self.use_mla_backend:
            self.token_to_kv_pool = MLATokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                kv_lora_rank=self.model_config.kv_lora_rank,
                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                page_size=self.page_size,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            if self.is_hybrid:
                self.token_to_kv_pool = SWAKVPool(
                    size=self.full_max_total_num_tokens,
                    size_swa=self.swa_max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    swa_attention_layer_ids=self.model_config.swa_attention_layer_ids,
                    full_attention_layer_ids=self.model_config.full_attention_layer_ids,
                    enable_kvcache_transpose=False,
                    device=self.device,
                )
            elif self.is_hybrid_gdn:
                self.token_to_kv_pool = HybridLinearKVPool(
                    page_size=self.page_size if _is_npu else 1,
                    size=self.max_total_num_tokens,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    # For a draft worker, we only need one attention layer's KV pool.
                    full_attention_layer_ids=(
                        [0]
                        if self.is_draft_worker
                        else self.model_config.hf_config.full_attention_layer_ids
                    ),
                    enable_kvcache_transpose=False,
                    device=self.device,
                )
            else:
                self.token_to_kv_pool = MHATokenToKVPool(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )
# Initialize token_to_kv_pool_allocator
|
|
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
|
|
if self.token_to_kv_pool_allocator is None:
|
|
if _is_npu and self.server_args.attention_backend in [
|
|
"ascend",
|
|
"hybrid_linear_attn",
|
|
]:
|
|
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
|
|
self.max_total_num_tokens,
|
|
page_size=self.page_size,
|
|
dtype=self.kv_cache_dtype,
|
|
device=self.device,
|
|
kvcache=self.token_to_kv_pool,
|
|
need_sort=need_sort,
|
|
)
|
|
else:
|
|
if self.page_size == 1:
|
|
if self.is_hybrid:
|
|
self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
|
|
self.full_max_total_num_tokens,
|
|
self.swa_max_total_num_tokens,
|
|
dtype=self.kv_cache_dtype,
|
|
device=self.device,
|
|
kvcache=self.token_to_kv_pool,
|
|
need_sort=need_sort,
|
|
)
|
|
else:
|
|
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
|
self.max_total_num_tokens,
|
|
dtype=self.kv_cache_dtype,
|
|
device=self.device,
|
|
kvcache=self.token_to_kv_pool,
|
|
need_sort=need_sort,
|
|
)
|
|
else:
|
|
assert not self.is_hybrid
|
|
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
|
|
self.max_total_num_tokens,
|
|
page_size=self.page_size,
|
|
dtype=self.kv_cache_dtype,
|
|
device=self.device,
|
|
kvcache=self.token_to_kv_pool,
|
|
need_sort=need_sort,
|
|
)
|
|
else:
|
|
assert self.is_draft_worker
|
|
|
|
logger.info(
|
|
f"Memory pool end. "
|
|
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
|
)
|
|
|
|
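    # A note on the allocator selection above: page_size == 1 uses the flat
    # TokenToKVPoolAllocator (or its SWA variant for hybrid models), while
    # page_size > 1 uses PagedTokenToKVPoolAllocator, which hands out whole
    # pages. Roughly speaking, with page_size=16 a 40-token request occupies
    # ceil(40 / 16) = 3 pages, i.e. 48 token slots.
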
    def init_cublas(self):
        """Run a small matmul to initialize cublas. Otherwise, the first real matmul may fail with cublas initialization errors later."""
        dtype = torch.float16
        device = "cuda"
        a = torch.ones((16, 16), dtype=dtype, device=device)
        b = torch.ones((16, 16), dtype=dtype, device=device)
        c = a @ b
        return c

    def init_attention_backend(self):
        """Init attention kernel backend."""
        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
            self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
        else:
            self.attn_backend = self._get_attention_backend()

    def _get_attention_backend(self):
        """Init attention kernel backend."""
        self.decode_attention_backend_str = (
            self.server_args.decode_attention_backend
            if self.server_args.decode_attention_backend
            else self.server_args.attention_backend
        )
        self.prefill_attention_backend_str = (
            self.server_args.prefill_attention_backend
            if self.server_args.prefill_attention_backend
            else self.server_args.attention_backend
        )
        if self.decode_attention_backend_str != self.prefill_attention_backend_str:
            from sglang.srt.layers.attention.hybrid_attn_backend import (
                HybridAttnBackend,
            )

            attn_backend = HybridAttnBackend(
                self,
                decode_backend=self._get_attention_backend_from_str(
                    self.decode_attention_backend_str
                ),
                prefill_backend=self._get_attention_backend_from_str(
                    self.prefill_attention_backend_str
                ),
            )
            logger.info(
                f"Using hybrid attention backend for decode and prefill: "
                f"decode_backend={self.decode_attention_backend_str}, "
                f"prefill_backend={self.prefill_attention_backend_str}."
            )
            logger.warning(
                "The attention backend specified by --attention-backend or the default backend may be overridden. "
                "The hybrid attention backend feature is experimental and unstable. Please raise an issue if you encounter any problems."
            )
        else:
            attn_backend = self._get_attention_backend_from_str(
                self.server_args.attention_backend
            )

        global_server_args_dict.update(
            {
                "decode_attention_backend": self.decode_attention_backend_str,
                "prefill_attention_backend": self.prefill_attention_backend_str,
            }
        )
        return attn_backend

    def _get_attention_backend_from_str(self, backend_str: str):
        if backend_str not in ATTENTION_BACKENDS:
            raise ValueError(f"Invalid attention backend: {backend_str}")
        return ATTENTION_BACKENDS[backend_str](self)

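    # Illustrative only: assuming the usual dash-style CLI flags derived from the
    # server args above, a hybrid prefill/decode setup could be requested like
    #   python -m sglang.launch_server --model-path <model> \
    #       --prefill-attention-backend triton --decode-attention-backend flashinfer
    # which makes the two backend strings differ and triggers HybridAttnBackend.
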
    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.start_layer, self.end_layer):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

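    # Sketch of the expected channel-config JSON layout, inferred from the lookup
    # above (the values here are hypothetical): one entry per layer/projection,
    # whose value is a 2D list of channel indices sorted by importance, e.g.
    #   {"model.layers.0.self_attn.q_proj": [[31, 7, 102, ...], [12, 88, 5, ...]], ...}
    # Only the first ds_heavy_channel_num columns of each entry are kept.
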
    def init_device_graphs(self):
        """Capture device graphs."""
        self.graph_runner = None
        self.graph_mem_usage = 0

        if not self.is_generation:
            # TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
            return

        if self.device != "cpu" and self.server_args.disable_cuda_graph:
            return

        if self.device == "cpu" and not self.server_args.enable_torch_compile:
            return

        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        # CudaGraphRunner is the default for any device not listed below.
        graph_runners = defaultdict(
            lambda: CudaGraphRunner,
            {
                "cpu": CPUGraphRunner,
                "npu": NPUGraphRunner,
            },
        )
        self.graph_runner = graph_runners[self.device](self)

        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        self.graph_mem_usage = before_mem - after_mem
        logger.info(
            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
        )

    def init_threads_binding(self):
        omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
        cpu_ids_by_node = get_cpu_ids_by_node()
        n_numa_node = len(cpu_ids_by_node)
        if omp_cpuids == "all":
            assert self.tp_size <= n_numa_node, (
                f"SGLANG_CPU_OMP_THREADS_BIND is not set. In this case, "
                f"tp_size {self.tp_size} should be less than or equal to the number of numa nodes on the machine ({n_numa_node}). "
                f"If you need tp_size to be larger than the number of numa nodes, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. "
                f"For example, on a machine with 2 numa nodes, where cores 0-31 are on numa node 0 and cores 32-63 are on numa node 1, "
                f"it is suggested to use -tp 2 and bind tp rank 0 to cores 0-31 and tp rank 1 to cores 32-63. "
                f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
                f"If you do need tp_size to be larger than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly, for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63, and run with -tp 4. "
                f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2."
            )
            if self.tp_size < n_numa_node:
                logger.warning(
                    f"Detected that the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so only {self.tp_size} numa nodes are used."
                )
            self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
        else:
            threads_bind_list = omp_cpuids.split("|")
            assert self.tp_size == len(threads_bind_list), (
                f"SGLANG_CPU_OMP_THREADS_BIND setting must be aligned with the TP size parameter ({self.tp_size}). "
                f"Please double check your settings."
            )
            self.local_omp_cpuid = threads_bind_list[self.tp_rank]
            if self.tp_size > n_numa_node:
                logger.warning(
                    f"TP size ({self.tp_size}) is larger than the number of numa nodes ({n_numa_node}), "
                    f"in which case the available memory amount of each rank cannot be determined in advance. "
                    f"Please set a proper `--max-total-tokens` to avoid out-of-memory errors."
                )

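    # Example drawn from the assertion message above: on a machine with two numa
    # nodes (cores 0-31 on node 0, cores 32-63 on node 1), explicit binding with
    #   SGLANG_CPU_OMP_THREADS_BIND="0-31|32-63" ... -tp 2
    # assigns tp rank i the i-th "|"-separated core range.
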
    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.layers.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)

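    # The forward_* methods below are the per-mode entry points; forward() and
    # _forward_raw() further down dispatch to them based on
    # forward_batch.forward_mode.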
    def forward_decode(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)
        # FIXME: add pp_proxy_tensors arg to all models
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        if forward_batch.input_embeds is not None:
            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
        if not self.is_generation:
            kwargs["get_embedding"] = True
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_split_prefill(
        self,
        forward_batch: ForwardBatch,
        reinit_attn_backend: bool = False,
        forward_count: int = 1,
    ) -> LogitsProcessorOutput:
        if forward_batch.split_index == 0 or reinit_attn_backend:
            self.attn_backend.init_forward_metadata(forward_batch)
        next_split_index = min(
            forward_batch.split_index + forward_count,
            self.model_config.num_hidden_layers,
        )
        ret = self.model.forward_split_prefill(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            (forward_batch.split_index, next_split_index),
        )
        forward_batch.split_index = next_split_index
        return ret

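    # Split prefill runs the model forward a few layers at a time. For example,
    # with num_hidden_layers=32 and forward_count=8, split_index advances
    # 0 -> 8 -> 16 -> 24 -> 32 across four calls, and the layer range
    # (split_index, next_split_index) is handed to the model on each call.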
    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        self.forward_pass_id += 1

        with get_global_expert_distribution_recorder().with_forward_pass(
            self.forward_pass_id,
            forward_batch,
        ):
            output = self._forward_raw(
                forward_batch,
                skip_attn_backend_init,
                pp_proxy_tensors,
                reinit_attn_backend,
                split_forward_count,
            )

        if self.eplb_manager is not None:
            self.eplb_manager.on_forward_pass_end()

        return output

    def _forward_raw(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool,
        pp_proxy_tensors: Optional[PPProxyTensors],
        reinit_attn_backend: bool = False,
        split_forward_count: int = 1,
    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
        mode_check = (
            forward_batch.forward_mode.is_cpu_graph
            if self.device == "cpu"
            else forward_batch.forward_mode.is_cuda_graph
        )
        can_run_graph = bool(
            mode_check()
            and self.graph_runner
            and self.graph_runner.can_run(forward_batch)
        )

        if can_run_graph:
            ret = self.graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
            return ret, can_run_graph

        # For MLP sync
        if forward_batch.global_num_tokens_cpu is not None:
            forward_batch.prepare_mlp_sync_batch(self)

        if forward_batch.forward_mode.is_decode():
            ret = self.forward_decode(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_extend():
            ret = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_split_prefill():
            ret = self.forward_split_prefill(
                forward_batch,
                reinit_attn_backend=reinit_attn_backend,
                forward_count=split_forward_count,
            )
        elif forward_batch.forward_mode.is_idle():
            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

        if (
            forward_batch.global_num_tokens_cpu is not None
            and self.pp_group.is_last_rank
        ):
            forward_batch.post_forward_mlp_sync_batch(ret)

        return ret, can_run_graph

    def _preprocess_logits(
        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
    ):
        # Apply logit bias
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.
            if sampling_info.grammars:
                sampling_info.sampling_info_done.wait()
        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
        sampling_info.apply_logits_bias(logits_output.next_token_logits)

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Sample next tokens, compute logprobs, and update logits_output.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output

        Returns:
            A tensor of next_token_ids
        """
        # For duplex models with multiple output streams.
        if isinstance(logits_output, tuple):
            return torch.stack(
                [self.sample(values, forward_batch) for values in logits_output],
                axis=-1,
            )

        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
            # For prefill, we only use the position of the last token.
            (
                forward_batch.positions
                if forward_batch.forward_mode.is_decode()
                else forward_batch.seq_lens - 1
            ),
        )
        return next_token_ids

    def compute_logprobs_only(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
    ) -> None:
        """
        Compute token_ids_logprobs without performing sampling.

        Optimized path for prefill-only requests that need token_ids_logprobs but don't
        require next token generation. Skips expensive sampling operations
        while still providing requested probability information.

        Args:
            logits_output: The logits output from the model forward
            forward_batch: The forward batch that generates logits_output
        """
        if not forward_batch.token_ids_logprobs:
            return

        # Preprocess logits (same as in the sample method)
        self._preprocess_logits(logits_output, forward_batch.sampling_info)

        # Delegate to the sampler for logprob-only computation.
        # This populates logits_output with the requested token probabilities.
        self.sampler.compute_logprobs_only(
            logits_output,
            forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
            forward_batch.token_ids_logprobs,
        )

    @property
    def model_is_mrope(self) -> bool:
        """Detect whether the model uses the "mrope" rope_scaling type.

        mrope requires keeping "rope_deltas" between the prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        is_mrope_enabled = "mrope_section" in rope_scaling
        return is_mrope_enabled

    def save_remote_model(self, url: str):
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)

    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        from sglang.srt.model_loader.loader import ShardedStateLoader

        logger.info(
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)


def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    params_dict = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params_dict[name], tensor)


def _unwrap_tensor(tensor, tp_rank, device):
    if isinstance(tensor, LocalSerializedTensor):
        tensor = tensor.get(tp_rank)
    return tensor.to(device)


@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
    The i-th element in the list corresponds to the i-th rank's GPU."""

    values: List[bytes]

    def get(self, rank: int):
        return MultiprocessingSerializer.deserialize(self.values[rank])
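

# Illustrative sketch, not part of this module's API: assuming
# MultiprocessingSerializer exposes a serialize() matching the deserialize()
# used above, a sender would serialize one tensor per TP rank and each
# receiver rank would recover only its own shard, e.g.
#   lst = LocalSerializedTensor(
#       values=[MultiprocessingSerializer.serialize(t) for t in per_rank_shards]
#   )
#   local = _unwrap_tensor(lst, tp_rank=rank, device="cuda")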