2024-11-22 22:16:53 +08:00
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
2024-06-08 02:06:52 -07:00
""" ModelRunner runs the forward passes of the models. """
2024-06-12 21:48:40 -07:00
2025-03-03 00:12:04 -08:00
import datetime
2024-08-20 13:48:24 -07:00
import gc
2024-10-14 02:00:41 -07:00
import json
2024-03-24 15:41:24 +08:00
import logging
2025-03-03 00:12:04 -08:00
import os
2024-12-02 20:45:53 -08:00
import time
2025-03-01 01:53:10 +08:00
from dataclasses import dataclass
from typing import List , Optional , Tuple , Union
2024-01-08 04:37:50 +00:00
import torch
2024-12-01 23:23:18 -08:00
import torch . distributed as dist
2025-01-17 22:31:51 +08:00
from sglang . srt . configs . device_config import DeviceConfig
from sglang . srt . configs . load_config import LoadConfig
from sglang . srt . configs . model_config import AttentionArch , ModelConfig
from sglang . srt . distributed import (
2024-07-18 04:55:39 +10:00
get_tp_group ,
init_distributed_environment ,
initialize_model_parallel ,
2024-08-20 23:44:12 +08:00
set_custom_all_reduce ,
2024-07-18 04:55:39 +10:00
)
2025-01-17 22:31:51 +08:00
from sglang . srt . distributed . parallel_state import monkey_patch_vllm_parallel_state
2025-01-16 11:15:00 -08:00
from sglang . srt . layers . dp_attention import (
get_attention_tp_group ,
2025-01-16 12:51:11 -08:00
get_attention_tp_size ,
2025-01-16 11:15:00 -08:00
initialize_dp_attention ,
)
2024-08-28 18:58:52 -07:00
from sglang . srt . layers . logits_processor import LogitsProcessorOutput
2025-03-10 03:06:21 -07:00
from sglang . srt . layers . quantization import monkey_patch_isinstance_for_vllm_base_layer
2025-04-22 07:52:53 +08:00
from sglang . srt . layers . quantization . deep_gemm import (
_ENABLE_JIT_DEEPGEMM ,
update_deep_gemm_config ,
)
2024-12-30 05:42:08 -08:00
from sglang . srt . layers . sampler import Sampler
2024-12-04 19:02:08 -08:00
from sglang . srt . layers . torchao_utils import apply_torchao_config_to_model
2024-09-12 16:46:14 -07:00
from sglang . srt . lora . lora_manager import LoRAManager
2024-09-30 06:41:49 -07:00
from sglang . srt . managers . schedule_batch import global_server_args_dict
2024-08-05 01:40:33 +08:00
from sglang . srt . mem_cache . memory_pool import (
2024-10-14 02:00:41 -07:00
DoubleSparseTokenToKVPool ,
2024-08-05 01:40:33 +08:00
MHATokenToKVPool ,
MLATokenToKVPool ,
ReqToTokenPool ,
2025-03-05 08:06:07 -08:00
TokenToKVPoolAllocator ,
2024-08-05 01:40:33 +08:00
)
2025-03-12 22:22:39 -07:00
from sglang . srt . mem_cache . paged_allocator import PagedTokenToKVPoolAllocator
2025-02-03 20:52:30 +08:00
from sglang . srt . model_executor . cuda_graph_runner import CudaGraphRunner
2024-09-30 02:41:11 -07:00
from sglang . srt . model_executor . forward_batch_info import ForwardBatch
2024-12-02 23:22:13 +08:00
from sglang . srt . model_loader import get_model
2025-03-12 15:36:13 -07:00
from sglang . srt . model_loader . loader import (
DefaultModelLoader ,
device_loading_context ,
get_model_loader ,
)
from sglang . srt . model_loader . utils import set_default_torch_dtype
2025-03-01 01:53:10 +08:00
from sglang . srt . model_loader . weight_utils import default_weight_loader
2025-04-15 18:37:07 -07:00
from sglang . srt . patch_torch import monkey_patch_torch_reductions
2025-03-03 00:12:04 -08:00
from sglang . srt . sampling . sampling_batch_info import SamplingBatchInfo
2024-05-21 11:46:35 -07:00
from sglang . srt . server_args import ServerArgs
2025-01-02 02:09:08 -08:00
from sglang . srt . speculative . spec_info import SpeculativeAlgorithm
2025-01-13 14:24:00 -08:00
from sglang . srt . torch_memory_saver_adapter import TorchMemorySaverAdapter
2024-06-12 21:48:40 -07:00
from sglang . srt . utils import (
2025-03-01 01:53:10 +08:00
MultiprocessingSerializer ,
2024-09-30 06:41:49 -07:00
enable_show_time_cost ,
2024-06-12 21:48:40 -07:00
get_available_gpu_memory ,
2025-04-18 10:51:39 +08:00
get_bool_env_var ,
2024-12-01 23:23:18 -08:00
init_custom_process_group ,
2025-01-13 13:17:11 +08:00
is_cuda ,
2025-04-15 14:45:15 -07:00
is_fa3_default_architecture ,
2025-04-05 01:23:02 -07:00
is_flashinfer_available ,
2024-11-19 14:06:29 -08:00
is_hip ,
2025-04-12 01:09:25 -07:00
is_hopper_with_cuda_12_3 ,
2025-04-15 14:45:15 -07:00
is_no_spec_infer_or_topk_one ,
2025-01-20 04:03:15 -08:00
monkey_patch_p2p_access_check ,
2024-11-30 22:14:48 -08:00
monkey_patch_vllm_gguf_config ,
2024-11-28 23:58:54 -08:00
set_cpu_offload_max_bytes ,
2025-02-14 08:50:14 +08:00
set_cuda_arch ,
2024-06-12 21:48:40 -07:00
)
2024-01-29 17:05:42 -08:00
# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

# Detect straggler ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

logger = logging.getLogger(__name__)
2025-03-03 00:12:04 -08:00
2024-01-08 04:37:50 +00:00
class ModelRunner :
2024-09-11 11:44:26 -07:00
""" ModelRunner runs the forward passes of the models. """
2024-01-08 04:37:50 +00:00
def __init__(
    self,
    model_config: ModelConfig,
    mem_fraction_static: float,
    gpu_id: int,
    tp_rank: int,
    tp_size: int,
    nccl_port: int,
    server_args: ServerArgs,
    is_draft_worker: bool = False,
    req_to_token_pool: Optional[ReqToTokenPool] = None,
    token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
):
    """Construct a ModelRunner for one tensor-parallel rank.

    Args:
        model_config: Parsed model configuration (architecture, dtype, ...).
        mem_fraction_static: Fraction of GPU memory reserved for weights + KV cache.
        gpu_id: Local device index this rank runs on.
        tp_rank: This process's tensor-parallel rank.
        tp_size: Tensor-parallel world size.
        nccl_port: Port used to build the distributed init method (tcp://...).
        server_args: Full server configuration; several fields are mutated by
            model_specific_adjustment() before being published globally.
        is_draft_worker: True when this runner serves a speculative-decoding
            draft model (it then reuses the target worker's distributed env).
        req_to_token_pool: Optional pre-built pool shared with a target worker.
        token_to_kv_pool_allocator: Optional pre-built allocator, same reason.
    """
    # Parse args
    self.model_config = model_config
    self.mem_fraction_static = mem_fraction_static
    self.device = server_args.device
    self.gpu_id = gpu_id
    self.tp_rank = tp_rank
    self.tp_size = tp_size
    self.dist_port = nccl_port
    self.server_args = server_args
    self.is_draft_worker = is_draft_worker
    self.is_generation = model_config.is_generation
    self.is_multimodal = model_config.is_multimodal
    # Only rank 0 logs verbose information.
    self.should_log = tp_rank == 0
    self.spec_algorithm = SpeculativeAlgorithm.from_string(
        server_args.speculative_algorithm
    )
    self.page_size = server_args.page_size
    self.req_to_token_pool = req_to_token_pool
    self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
    self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
    self.attention_chunk_size = model_config.attention_chunk_size

    # Model-specific adjustment (may rewrite fields of server_args, e.g. the
    # attention backend — must run before global_server_args_dict is filled).
    self.model_specific_adjustment()

    if server_args.show_time_cost:
        enable_show_time_cost()

    # Global vars: publish the (possibly adjusted) settings for other modules.
    global_server_args_dict.update(
        {
            "attention_backend": server_args.attention_backend,
            "sampling_backend": server_args.sampling_backend,
            "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
            "torchao_config": server_args.torchao_config,
            "enable_nan_detection": server_args.enable_nan_detection,
            "enable_dp_attention": server_args.enable_dp_attention,
            "enable_ep_moe": server_args.enable_ep_moe,
            "enable_deepep_moe": server_args.enable_deepep_moe,
            "deepep_mode": server_args.deepep_mode,
            "device": server_args.device,
            "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
            "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
            "disable_radix_cache": server_args.disable_radix_cache,
            "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
            "moe_dense_tp_size": server_args.moe_dense_tp_size,
            "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
            "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
            "n_share_experts_fusion": server_args.n_share_experts_fusion,
            "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
            "use_mla_backend": self.use_mla_backend,
        }
    )

    # CPU offload budget in bytes (cpu_offload_gb is in GiB).
    set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

    # Get memory before model loading
    min_per_gpu_memory = self.init_torch_distributed()

    # Update deep gemm configure
    if _ENABLE_JIT_DEEPGEMM:
        update_deep_gemm_config(gpu_id, server_args)

    # If it is a draft model, tp_group can be different
    self.initialize(min_per_gpu_memory)
def initialize(self, min_per_gpu_memory: float):
    """Load the model and set up pools/backends after torch.distributed is up.

    Args:
        min_per_gpu_memory: Minimum available GPU memory (GB) across TP ranks,
            as returned by init_torch_distributed(); used to size the KV cache.
    """
    server_args = self.server_args
    self.memory_saver_adapter = TorchMemorySaverAdapter.create(
        enable=self.server_args.enable_memory_saver
    )

    # Load the model
    self.sampler = Sampler()
    self.load_model()

    # Apply torchao quantization
    torchao_applied = getattr(self.model, "torchao_applied", False)
    # In layered loading, torchao may have been applied
    if not torchao_applied:
        apply_torchao_config_to_model(
            self.model, global_server_args_dict["torchao_config"]
        )

    # Apply torch TP if the model supports it
    supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
    if self.tp_size > 1 and supports_torch_tp:
        self.apply_torch_tp()

    # Init lora
    if server_args.lora_paths is not None:
        self.init_lora_manager()

    # Init memory pool and attention backends
    self.init_memory_pool(
        min_per_gpu_memory,
        server_args.max_running_requests,
        server_args.max_total_tokens,
    )
    if self.device == "cuda":
        self.init_cublas()
        self.init_attention_backend()
        self.init_cuda_graphs()
    else:
        # Non-CUDA devices: no CUDA graph capture, attention backend only.
        self.cuda_graph_runner = None
        self.init_attention_backend()

    # auxiliary hidden capture mode. TODO: expose this to server args?
    if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
        self.model.set_eagle3_layers_to_capture()
2025-03-06 01:51:12 -08:00
def model_specific_adjustment(self):
    """Adjust server_args in place based on the model/hardware combination.

    Mutates (at least) attention_backend, disable_cuda_graph,
    chunked_prefill_size, disable_chunked_prefix_cache, and
    self.mem_fraction_static. Must run before the settings are published.
    """
    server_args = self.server_args

    if server_args.attention_backend is None:
        """
        Auto select the fastest attention backend.

        1. Models with MHA Architecture (e.g: Llama, QWen)
            1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
            1.2 In other cases, we will use flashinfer if available, otherwise use triton.
        2. Models with MLA Architecture and using FA3
            2.1 We will use FA3 backend on hopper.
            2.2 Otherwise, we will use triton backend.
        """

        if not self.use_mla_backend:
            # MHA architecture
            if (
                is_hopper_with_cuda_12_3()
                and is_no_spec_infer_or_topk_one(server_args)
                and is_fa3_default_architecture(self.model_config.hf_config)
            ):
                server_args.attention_backend = "fa3"
            else:
                server_args.attention_backend = (
                    "flashinfer" if is_flashinfer_available() else "triton"
                )
        else:
            # MLA architecture
            if is_hopper_with_cuda_12_3():
                server_args.attention_backend = "fa3"
            else:
                server_args.attention_backend = "triton"
        logger.info(
            f"Attention backend not set. Use {server_args.attention_backend} backend by default."
        )
    elif self.use_mla_backend:
        # User picked a backend explicitly; validate it against MLA support.
        if server_args.device != "cpu":
            if server_args.attention_backend in [
                "flashinfer",
                "fa3",
                "triton",
                "flashmla",
                "cutlass_mla",
            ]:
                logger.info(
                    f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                )
            else:
                raise ValueError(
                    f"Invalid attention backend for MLA: {server_args.attention_backend}"
                )
        else:
            raise ValueError("MLA optimization not supported on CPU.")

    # FA3 + fp8_e5m2 KV cache is an unsupported combination; fall back.
    if (
        server_args.attention_backend == "fa3"
        and server_args.kv_cache_dtype == "fp8_e5m2"
    ):
        logger.warning(
            "FlashAttention3 only supports fp8_e4m3 if using FP8; "
            "Setting attention backend to triton."
        )
        server_args.attention_backend = "triton"

    if server_args.enable_double_sparsity:
        logger.info(
            "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
        )
        server_args.attention_backend = "triton"
        server_args.disable_cuda_graph = True
        if server_args.ds_heavy_channel_type is None:
            raise ValueError(
                "Please specify the heavy channel type for double sparsity optimization."
            )
        self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

    if self.is_multimodal:
        # Leave headroom for image/vision towers and disable chunked prefill.
        self.mem_fraction_static *= 0.90
        logger.info(
            f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
            f"because this is a multimodal model."
        )
        logger.info(
            "Automatically turn off --chunked-prefill-size for multimodal model."
        )
        server_args.chunked_prefill_size = -1

    # Chunked prefix cache is only meaningful for MLA models with page_size == 1.
    if not self.use_mla_backend:
        server_args.disable_chunked_prefix_cache = True
    elif self.page_size > 1:
        logger.info("Disable chunked prefix cache when page size > 1.")
        server_args.disable_chunked_prefix_cache = True

    if not server_args.disable_chunked_prefix_cache:
        logger.info("Chunked prefix cache is turned on.")
2024-08-24 08:02:23 -07:00
def init_torch_distributed(self):
    """Initialize torch.distributed and the TP/DP process groups.

    Returns:
        The minimum available GPU memory (GB) across all TP ranks, used later
        to size the KV cache pool.

    Raises:
        ValueError: if the device type is unsupported, or if the available
            memory across TP ranks is unbalanced (unless explicitly disabled
            via SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK).
    """
    logger.info("Init torch distributed begin.")

    try:
        torch.get_device_module(self.device).set_device(self.gpu_id)
    except Exception:
        # Log the full device context before re-raising to ease debugging.
        logger.warning(
            f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
        )
        raise

    if self.device == "cuda":
        backend = "nccl"
    elif self.device == "xpu":
        backend = "xccl"
    elif self.device == "hpu":
        backend = "hccl"
    elif self.device == "cpu":
        backend = "gloo"
    else:
        # Fix: previously `backend` stayed unbound for unknown devices and the
        # code below crashed with a confusing UnboundLocalError. Fail fast.
        raise ValueError(
            f"Unsupported device for torch.distributed backend: {self.device}"
        )

    before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
    if not self.server_args.enable_p2p_check:
        monkey_patch_p2p_access_check()

    if self.server_args.dist_init_addr:
        dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
    else:
        dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
    set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

    if not self.is_draft_worker:
        # Only initialize the distributed environment on the target model worker.
        init_distributed_environment(
            backend=backend,
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
            timeout=self.server_args.dist_timeout,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        initialize_dp_attention(
            enable_dp_attention=self.server_args.enable_dp_attention,
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
            dp_size=self.server_args.dp_size,
        )

    min_per_gpu_memory = get_available_gpu_memory(
        self.device, self.gpu_id, distributed=self.tp_size > 1
    )
    self.tp_group = get_tp_group()
    self.attention_tp_group = get_attention_tp_group()

    # Check memory for tensor parallelism
    local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
    if self.tp_size > 1:
        if min_per_gpu_memory < local_gpu_memory * 0.9:
            if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
                logger.warning(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                )
            else:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                )

    logger.info(
        f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
    )
    return min_per_gpu_memory
2024-07-13 05:29:46 -07:00
2024-01-08 04:37:50 +00:00
def load_model(self):
    """Load model weights onto the device and record memory usage.

    Side effects: sets self.model, self.load_config, self.sliding_window_size,
    self.dtype; may downgrade self.server_args.dtype / model_config.dtype to
    float16 on pre-sm80 GPUs.

    Raises:
        RuntimeError: on GPUs older than sm75, or when FP8 KV-cache scaling
            factors are provided but the model cannot load them.
        ValueError: when other TP ranks fail to finish loading in time.
    """
    before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
    logger.info(
        f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
    )

    # This can reduce thread conflicts and speed up weight loading.
    if self.device != "cpu":
        torch.set_num_threads(1)
    if self.device == "cuda":
        if torch.cuda.get_device_capability()[0] < 8:
            logger.info(
                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
            )
            self.server_args.dtype = "float16"
            self.model_config.dtype = torch.float16
            if torch.cuda.get_device_capability()[1] < 5:
                raise RuntimeError("SGLang only supports sm75 and above.")

    set_cuda_arch()

    # Prepare the model config
    self.load_config = LoadConfig(
        load_format=self.server_args.load_format,
        download_dir=self.server_args.download_dir,
    )
    if self.server_args.load_format == "gguf":
        monkey_patch_vllm_gguf_config()

    # Load the model
    # Remove monkey_patch when linear.py quant remove dependencies with vllm
    monkey_patch_vllm_parallel_state()
    monkey_patch_isinstance_for_vllm_base_layer()

    with self.memory_saver_adapter.region():
        self.model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
        )
    monkey_patch_vllm_parallel_state(reverse=True)
    monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

    if self.server_args.kv_cache_dtype == "fp8_e4m3":
        if self.server_args.quantization_param_path is not None:
            if callable(getattr(self.model, "load_kv_cache_scales", None)):
                self.model.load_kv_cache_scales(
                    self.server_args.quantization_param_path
                )
                logger.info(
                    "Loaded KV cache scaling factors from %s",
                    self.server_args.quantization_param_path,
                )
            else:
                # Fix: this previously passed printf-style args to the
                # RuntimeError constructor, so the class name was never
                # interpolated into the message.
                raise RuntimeError(
                    "Using FP8 KV cache and scaling factors provided but "
                    f"model {self.model.__class__} does not support loading scaling factors."
                )
        else:
            logger.warning(
                "Using FP8 KV cache but no scaling factors "
                "provided. Defaulting to scaling factors of 1.0. "
                "This may lead to less accurate results!"
            )

    # Parse other args
    self.sliding_window_size = (
        self.model.get_attention_sliding_window_size()
        if hasattr(self.model, "get_attention_sliding_window_size")
        else None
    )
    self.dtype = self.model_config.dtype

    after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
    logger.info(
        f"Load weight end. "
        f"type={type(self.model).__name__}, "
        f"dtype={self.dtype}, "
        f"avail mem={after_avail_memory:.2f} GB, "
        f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
    )

    # Handle the case where some ranks do not finish loading.
    try:
        dist.monitored_barrier(
            group=get_tp_group().cpu_group,
            timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
            wait_all_ranks=True,
        )
    except RuntimeError:
        raise ValueError(
            f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
        ) from None
2024-12-09 06:30:35 -08:00
def update_weights_from_disk(
    self, model_path: str, load_format: str
) -> tuple[bool, str]:
    """Update engine weights in-place from the disk.

    Args:
        model_path: Path (or hub id) of the checkpoint to load.
        load_format: Weight format understood by LoadConfig.

    Returns:
        (success, message). On a failed load the original weights are
        re-loaded from self.model_config so the runner stays usable.
    """
    logger.info(
        f"Update engine weights online from disk begin. "
        f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
    )

    target_device = torch.device(self.device)
    self.model_config.model_path = model_path
    load_config = LoadConfig(load_format=load_format)

    # Only support DefaultModelLoader for now
    loader = get_model_loader(load_config)
    if not isinstance(loader, DefaultModelLoader):
        message = f"Failed to get model loader: {loader}."
        return False, message

    def get_weight_iter(config):
        # Build the named-weights iterator for the given model config.
        # Fix: renamed the local from `iter` — it shadowed the builtin.
        weights_iter = loader._get_weights_iterator(
            DefaultModelLoader.Source(
                config.model_path,
                revision=config.revision,
                fall_back_to_pt=getattr(
                    self.model, "fall_back_to_pt_during_load", True
                ),
            )
        )
        return weights_iter

    def model_load_weights(model, weights_iter):
        # Load weights, then let each quant method post-process on the target
        # device (weights may have been staged elsewhere during loading).
        model.load_weights(weights_iter)
        for _, module in self.model.named_modules():
            quant_method = getattr(module, "quant_method", None)
            if quant_method is not None:
                with device_loading_context(module, target_device):
                    quant_method.process_weights_after_loading(module)
        return model

    with set_default_torch_dtype(self.model_config.dtype):
        try:
            weights_iter = get_weight_iter(self.model_config)
        except Exception as e:
            message = f"Failed to get weights iterator: {e}."
            return False, message
        try:
            model = model_load_weights(self.model, weights_iter)
        except Exception as e:
            message = (
                f"Failed to update weights: {e}.\nRolling back to original weights."
            )
            # Roll back: drop the failed iterator and re-load from the
            # (already updated) model_config.
            del weights_iter
            gc.collect()
            weights_iter = get_weight_iter(self.model_config)
            self.model = model_load_weights(self.model, weights_iter)
            return False, message

    self.model = model
    self.server_args.model_path = model_path
    self.server_args.load_format = load_format
    self.load_config = load_config

    logger.info("Update weights end.")
    return True, "Succeeded to update model weights."
2024-12-01 23:23:18 -08:00
def init_weights_update_group(
    self,
    master_address,
    master_port,
    rank_offset,
    world_size,
    group_name,
    backend="nccl",
):
    """Initialize the Torch process group for model parameter updates.

    `_model_update_group` is used in the RLHF workflow, where rank
    0 is the actor model in the training engine, and the other ranks are
    the inference engine, which is used for rollout.

    In the RLHF workflow, the training engine updates the model
    weights/parameters online, and broadcasts them to the inference
    engine through the `_model_update_group` process group.

    Returns:
        (success, message) tuple; failures are logged and reported, not raised.
    """
    assert (
        torch.distributed.is_initialized()
    ), "Default torch process group must be initialized"
    assert group_name != "", "Group name cannot be empty"

    # This rank's position inside the new group (rank 0 is the trainer).
    rank = rank_offset + self.tp_rank

    logger.info(
        f"init custom process group: master_address={master_address}, master_port={master_port}, "
        f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
    )

    try:
        self._model_update_group = init_custom_process_group(
            backend=backend,
            init_method=f"tcp://{master_address}:{master_port}",
            world_size=world_size,
            rank=rank,
            group_name=group_name,
        )
        # Synchronize all members before any weight broadcast happens.
        dist.barrier(group=self._model_update_group, device_ids=[rank])
        return True, "Succeeded to initialize custom process group."
    except Exception as e:
        message = f"Failed to initialize custom process group: {e}."
        logger.error(message)
        return False, message
def update_weights_from_distributed(self, name, dtype, shape):
    """
    Update specific parameter in the model weights online
    through `_model_update_group` process group.

    Args:
        name: the name of the parameter to be updated.
        dtype: the data type of the parameter to be updated.
        shape: the shape of the parameter to be updated.
    """
    # Accept either a torch.dtype or its string name (e.g. "bfloat16").
    if isinstance(dtype, torch.dtype):
        target_dtype = dtype
    else:
        target_dtype = getattr(torch, dtype)

    assert (
        self._model_update_group is not None
    ), "model update group must be initialized"

    try:
        # Receive the new parameter from rank 0 of the update group,
        # then splice it into the model by name.
        recv_buf = torch.empty(shape, dtype=target_dtype, device=self.device)
        torch.distributed.broadcast(recv_buf, src=0, group=self._model_update_group)
        self.model.load_weights([(name, recv_buf)])
    except Exception as e:
        error_msg = (
            f"Failed to update parameter online: {e}. "
            f"The full weights of the ModelRunner are partially updated. "
            f"Please discard the whole weights."
        )
        logger.error(error_msg)
        return False, error_msg
    return True, f"Succeeded to update parameter {name} online."
2025-03-01 01:53:10 +08:00
def update_weights_from_tensor(
    self,
    named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
    load_format: Optional[str] = None,
):
    """Update model weights from in-memory (name, tensor) pairs.

    Serialized tensors are first unwrapped for this TP rank; the resulting
    pairs are loaded either directly or via the model's own load_weights.
    """
    unwrapped: List[Tuple[str, torch.Tensor]] = []
    for name, tensor in named_tensors:
        unwrapped.append((name, _unwrap_tensor(tensor, tp_rank=self.tp_rank)))

    if load_format == "direct":
        _model_load_weights_direct(self.model, unwrapped)
    elif load_format is None:
        self.model.load_weights(unwrapped)
    else:
        raise NotImplementedError(f"Unknown load_format={load_format}")
    return True, "Success"
2024-11-29 23:36:38 -08:00
def get_weights_by_name(
    self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]:
    """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

    Only used for unit test with an unoptimized performance.
    For optimized performance, please use torch.save and torch.load.
    """
    # TODO: (chenyang) Add support for Qwen models.
    try:
        weight = self.model.get_weights_by_name(
            name, truncate_size, tp_size=self.tp_size
        )
    except Exception as err:
        logger.error(f"Error when getting parameter {name}: {err}")
        return None
    return weight
2024-09-12 16:46:14 -07:00
def init_lora_manager(self):
    """Create the LoRAManager that serves self.server_args.lora_paths.

    Called from initialize() only when lora_paths is set; the manager wraps
    self.model and handles per-batch adapter selection.
    """
    self.lora_manager = LoRAManager(
        base_model=self.model,
        lora_paths=self.server_args.lora_paths,
        base_hf_config=self.model_config.hf_config,
        max_loras_per_batch=self.server_args.max_loras_per_batch,
        load_config=self.load_config,
        dtype=self.dtype,
        lora_backend=self.server_args.lora_backend,
        tp_size=self.tp_size,
        tp_rank=self.tp_rank,
    )
    logger.info("LoRA manager ready.")
2024-08-24 08:02:23 -07:00
def profile_max_num_token ( self , total_gpu_memory : int ) :
2024-05-27 21:24:10 -07:00
available_gpu_memory = get_available_gpu_memory (
2024-10-11 17:05:58 +08:00
self . device , self . gpu_id , distributed = self . tp_size > 1
2024-05-27 21:24:10 -07:00
)
2025-04-05 01:23:02 -07:00
if self . use_mla_backend :
2024-08-05 01:40:33 +08:00
cell_size = (
( self . model_config . kv_lora_rank + self . model_config . qk_rope_head_dim )
* self . model_config . num_hidden_layers
2024-08-26 08:38:11 +08:00
* torch . _utils . _element_size ( self . kv_cache_dtype )
2024-08-05 01:40:33 +08:00
)
else :
cell_size = (
2025-01-16 12:51:11 -08:00
self . model_config . get_num_kv_heads ( get_attention_tp_size ( ) )
2024-08-05 01:40:33 +08:00
* self . model_config . head_dim
* self . model_config . num_hidden_layers
* 2
2024-08-26 08:38:11 +08:00
* torch . _utils . _element_size ( self . kv_cache_dtype )
2024-08-05 01:40:33 +08:00
)
2024-01-08 04:37:50 +00:00
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self . mem_fraction_static
)
2024-05-24 03:48:53 -07:00
max_num_token = int ( rest_memory * ( 1 << 30 ) / / cell_size )
2024-01-08 04:37:50 +00:00
return max_num_token
2024-07-30 13:33:55 -07:00
def init_memory_pool (
2024-08-24 08:02:23 -07:00
self ,
total_gpu_memory : int ,
2024-09-10 17:11:16 -07:00
max_num_reqs : Optional [ int ] = None ,
max_total_tokens : Optional [ int ] = None ,
2024-07-30 13:33:55 -07:00
) :
2024-08-26 08:38:11 +08:00
if self . server_args . kv_cache_dtype == " auto " :
self . kv_cache_dtype = self . dtype
elif self . server_args . kv_cache_dtype == " fp8_e5m2 " :
2025-03-07 10:27:52 -08:00
if is_hip ( ) : # Using natively supported format
2024-11-19 14:06:29 -08:00
self . kv_cache_dtype = torch . float8_e5m2fnuz
else :
self . kv_cache_dtype = torch . float8_e5m2
2025-01-13 13:17:11 +08:00
elif self . server_args . kv_cache_dtype == " fp8_e4m3 " :
if is_cuda ( ) :
self . kv_cache_dtype = torch . float8_e4m3fn
2024-08-26 08:38:11 +08:00
else :
raise ValueError (
f " Unsupported kv_cache_dtype: { self . server_args . kv_cache_dtype } . "
)
2024-05-26 12:51:45 -07:00
self . max_total_num_tokens = self . profile_max_num_token ( total_gpu_memory )
2025-01-02 02:09:08 -08:00
if max_num_reqs is None :
max_num_reqs = min (
max (
int (
self . max_total_num_tokens / self . model_config . context_len * 512
) ,
2048 ,
) ,
4096 ,
)
2025-03-03 00:12:04 -08:00
if SGLANG_CI_SMALL_KV_SIZE :
self . max_total_num_tokens = int ( SGLANG_CI_SMALL_KV_SIZE )
2025-01-02 02:09:08 -08:00
if not self . spec_algorithm . is_none ( ) :
if self . is_draft_worker :
self . max_total_num_tokens = self . server_args . draft_runner_cache_size
2025-03-05 08:06:07 -08:00
max_num_reqs = self . server_args . max_num_reqs
2025-01-02 02:09:08 -08:00
else :
2025-03-05 08:06:07 -08:00
# We are sharing the `token_to_kv_pool`, and both verify and draft tokens
# can be concurrently allocated, so we should give a headroom for it.
2025-01-02 02:09:08 -08:00
self . server_args . draft_runner_cache_size = (
self . max_total_num_tokens
2025-03-05 08:06:07 -08:00
# draft
+ max_num_reqs
* self . server_args . speculative_num_steps
* self . server_args . speculative_eagle_topk
# verify
+ max_num_reqs * self . server_args . speculative_num_draft_tokens
# buffer
2025-01-02 02:09:08 -08:00
+ 100
)
2025-03-05 08:06:07 -08:00
# Target worker and draft worker shares the same indices for the
# token_to_kv_pool, so we should make sure to match max_total_num_tokens.
self . max_total_num_tokens = self . server_args . draft_runner_cache_size
self . server_args . max_num_reqs = max_num_reqs
2025-01-02 02:09:08 -08:00
2024-07-30 13:33:55 -07:00
if max_total_tokens is not None :
if max_total_tokens > self . max_total_num_tokens :
2024-08-20 08:31:29 -07:00
logging . warning (
2024-07-30 13:33:55 -07:00
f " max_total_tokens= { max_total_tokens } is larger than the profiled value "
f " { self . max_total_num_tokens } . "
f " Use the profiled value instead. "
)
self . max_total_num_tokens = min ( self . max_total_num_tokens , max_total_tokens )
2024-01-19 17:03:33 -08:00
2025-03-12 16:21:49 -07:00
self . max_total_num_tokens = (
self . max_total_num_tokens
/ / self . server_args . page_size
* self . server_args . page_size
)
2024-05-26 12:51:45 -07:00
if self . max_total_num_tokens < = 0 :
2024-01-21 01:39:23 -08:00
raise RuntimeError (
2024-06-17 20:41:24 -07:00
" Not enough memory. Please try to increase --mem-fraction-static. "
2024-01-21 01:39:23 -08:00
)
2024-01-19 17:03:33 -08:00
2025-03-05 08:06:07 -08:00
if self . req_to_token_pool is None :
self . req_to_token_pool = ReqToTokenPool (
size = max_num_reqs + 1 ,
max_context_len = self . model_config . context_len + 4 ,
device = self . device ,
enable_memory_saver = self . server_args . enable_memory_saver ,
)
else :
# Draft worker shares req_to_token_pool with the target worker.
assert self . is_draft_worker
2025-04-05 01:23:02 -07:00
if self . use_mla_backend :
2024-08-05 01:40:33 +08:00
self . token_to_kv_pool = MLATokenToKVPool (
self . max_total_num_tokens ,
2025-03-12 22:22:39 -07:00
page_size = self . page_size ,
2024-08-26 08:38:11 +08:00
dtype = self . kv_cache_dtype ,
2024-08-05 01:40:33 +08:00
kv_lora_rank = self . model_config . kv_lora_rank ,
qk_rope_head_dim = self . model_config . qk_rope_head_dim ,
layer_num = self . model_config . num_hidden_layers ,
2024-10-11 17:05:58 +08:00
device = self . device ,
2025-01-14 03:38:51 +08:00
enable_memory_saver = self . server_args . enable_memory_saver ,
2024-08-05 01:40:33 +08:00
)
2024-10-14 02:00:41 -07:00
elif self . server_args . enable_double_sparsity :
self . token_to_kv_pool = DoubleSparseTokenToKVPool (
self . max_total_num_tokens ,
2025-03-12 22:22:39 -07:00
page_size = self . page_size ,
2024-10-14 02:00:41 -07:00
dtype = self . kv_cache_dtype ,
2025-01-16 12:51:11 -08:00
head_num = self . model_config . get_num_kv_heads ( get_attention_tp_size ( ) ) ,
2024-10-14 02:00:41 -07:00
head_dim = self . model_config . head_dim ,
layer_num = self . model_config . num_hidden_layers ,
device = self . device ,
heavy_channel_num = self . server_args . ds_heavy_channel_num ,
2025-01-14 03:38:51 +08:00
enable_memory_saver = self . server_args . enable_memory_saver ,
2024-10-14 02:00:41 -07:00
)
2024-08-05 01:40:33 +08:00
else :
self . token_to_kv_pool = MHATokenToKVPool (
self . max_total_num_tokens ,
2025-03-12 22:22:39 -07:00
page_size = self . page_size ,
2024-08-26 08:38:11 +08:00
dtype = self . kv_cache_dtype ,
2025-01-16 12:51:11 -08:00
head_num = self . model_config . get_num_kv_heads ( get_attention_tp_size ( ) ) ,
2024-08-05 01:40:33 +08:00
head_dim = self . model_config . head_dim ,
layer_num = self . model_config . num_hidden_layers ,
2024-10-11 17:05:58 +08:00
device = self . device ,
2025-01-14 03:38:51 +08:00
enable_memory_saver = self . server_args . enable_memory_saver ,
2024-08-05 01:40:33 +08:00
)
2025-03-05 21:39:07 -08:00
if self . token_to_kv_pool_allocator is None :
2025-03-12 22:22:39 -07:00
if self . page_size == 1 :
self . token_to_kv_pool_allocator = TokenToKVPoolAllocator (
self . max_total_num_tokens ,
dtype = self . kv_cache_dtype ,
device = self . device ,
kvcache = self . token_to_kv_pool ,
)
else :
self . token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator (
self . max_total_num_tokens ,
page_size = self . page_size ,
dtype = self . kv_cache_dtype ,
device = self . device ,
kvcache = self . token_to_kv_pool ,
)
2025-03-05 21:39:07 -08:00
else :
assert self . is_draft_worker
2024-05-27 21:24:10 -07:00
logger . info (
2024-08-25 14:46:34 -07:00
f " Memory pool end. "
2024-10-11 17:05:58 +08:00
f " avail mem= { get_available_gpu_memory ( self . device , self . gpu_id ) : .2f } GB "
2024-05-27 21:24:10 -07:00
)
2024-01-08 04:37:50 +00:00
2024-06-25 12:46:00 -07:00
def init_cublas ( self ) :
""" We need to run a small matmul to init cublas. Otherwise, it will raise some errors later. """
dtype = torch . float16
device = " cuda "
a = torch . ones ( ( 16 , 16 ) , dtype = dtype , device = device )
b = torch . ones ( ( 16 , 16 ) , dtype = dtype , device = device )
c = a @ b
return c
2024-09-11 11:44:26 -07:00
    def init_attention_backend(self):
        """Init attention kernel backend."""
        # Dispatch on --attention-backend; each branch imports its backend
        # lazily so unused backends (and their dependencies) are never loaded.
        if self.server_args.attention_backend == "flashinfer":
            if not self.use_mla_backend:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferAttnBackend,
                )

                # Init streams
                # NOTE(review): a dedicated stream is created only for EAGLE
                # speculative decoding; presumably used to overlap flashinfer
                # plan calls — confirm against FlashInferAttnBackend.
                if self.server_args.speculative_algorithm == "EAGLE":
                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
                self.attn_backend = FlashInferAttnBackend(self)
            else:
                # MLA models use the dedicated flashinfer MLA backend.
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAAttnBackend,
                )

                self.attn_backend = FlashInferMLAAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
            # The triton backend supports neither sliding-window nor
            # encoder-decoder (cross) attention; reject those configs early.
            assert self.sliding_window_size is None, (
                "Window attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                from sglang.srt.layers.attention.double_sparsity_backend import (
                    DoubleSparseAttnBackend,
                )

                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

                self.attn_backend = TritonAttnBackend(self)
        elif self.server_args.attention_backend == "torch_native":
            # Pure-PyTorch fallback backend.
            from sglang.srt.layers.attention.torch_native_backend import (
                TorchNativeAttnBackend,
            )

            self.attn_backend = TorchNativeAttnBackend(self)
        elif self.server_args.attention_backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

            self.attn_backend = FlashMLABackend(self)
        elif self.server_args.attention_backend == "fa3":
            # FlashAttention v3 kernels require Hopper-class (SM90+) GPUs.
            assert torch.cuda.get_device_capability()[0] >= 9, (
                "FlashAttention v3 Backend requires SM>=90. "
                "Please use `--attention-backend flashinfer`."
            )
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            self.attn_backend = FlashAttentionBackend(self)
        elif self.server_args.attention_backend == "cutlass_mla":
            from sglang.srt.layers.attention.cutlass_mla_backend import (
                CutlassMLABackend,
            )

            self.attn_backend = CutlassMLABackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )
2024-06-20 20:29:06 -07:00
2024-10-14 02:00:41 -07:00
def init_double_sparsity_channel_config ( self , selected_channel ) :
selected_channel = " . " + selected_channel + " _proj "
self . sorted_channels = [ ]
# load channel config
with open ( self . server_args . ds_channel_config_path , " r " ) as f :
channel_config = json . load ( f )
for i in range ( self . model_config . num_hidden_layers ) :
key = " model.layers. " + str ( i ) + " .self_attn " + selected_channel
self . sorted_channels . append (
torch . tensor ( channel_config [ key ] ) [
: , : self . server_args . ds_heavy_channel_num
]
. contiguous ( )
. cuda ( )
)
2024-07-13 05:29:46 -07:00
def init_cuda_graphs ( self ) :
2024-08-24 08:02:23 -07:00
""" Capture cuda graphs. """
2024-09-11 11:44:26 -07:00
self . cuda_graph_runner = None
2024-08-24 08:02:23 -07:00
if not self . is_generation :
# TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
return
2024-09-11 11:44:26 -07:00
if self . server_args . disable_cuda_graph :
return
2024-07-13 05:29:46 -07:00
2024-12-02 20:45:53 -08:00
tic = time . time ( )
2025-03-04 21:23:47 -08:00
before_mem = get_available_gpu_memory ( self . device , self . gpu_id )
2025-03-03 00:12:04 -08:00
logger . info (
2025-03-04 21:23:47 -08:00
f " Capture cuda graph begin. This can take up to several minutes. avail mem= { before_mem : .2f } GB "
2025-03-03 00:12:04 -08:00
)
2024-09-11 11:44:26 -07:00
self . cuda_graph_runner = CudaGraphRunner ( self )
2025-03-04 21:23:47 -08:00
after_mem = get_available_gpu_memory ( self . device , self . gpu_id )
2025-03-03 00:12:04 -08:00
logger . info (
2025-03-04 21:23:47 -08:00
f " Capture cuda graph end. Time elapsed: { time . time ( ) - tic : .2f } s. "
2025-04-28 10:57:17 -07:00
f " mem usage= { ( before_mem - after_mem ) : .2f } GB. avail mem= { after_mem : .2f } GB. "
2025-03-03 00:12:04 -08:00
)
2024-07-13 05:29:46 -07:00
2024-11-15 21:26:00 -08:00
    def apply_torch_tp(self):
        # Shard the model across `tp_size` devices using torch-native tensor
        # parallelism (a 1-D device mesh) via sglang's tensor_parallel helper.
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
        from sglang.srt.model_parallel import tensor_parallel

        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
        tensor_parallel(self.model, device_mesh)
2024-09-30 02:41:11 -07:00
def forward_decode ( self , forward_batch : ForwardBatch ) :
2024-10-17 22:54:14 -07:00
self . attn_backend . init_forward_metadata ( forward_batch )
2024-03-24 19:48:37 +08:00
return self . model . forward (
2024-09-30 02:41:11 -07:00
forward_batch . input_ids , forward_batch . positions , forward_batch
2024-01-08 04:37:50 +00:00
)
2025-03-04 21:23:47 -08:00
def forward_extend (
self , forward_batch : ForwardBatch , skip_attn_backend_init : bool = False
) :
if not skip_attn_backend_init :
self . attn_backend . init_forward_metadata ( forward_batch )
2024-08-26 01:29:12 +08:00
if self . is_generation :
2024-11-25 19:35:04 -05:00
if forward_batch . input_embeds is None :
return self . model . forward (
forward_batch . input_ids , forward_batch . positions , forward_batch
)
else :
return self . model . forward (
forward_batch . input_ids ,
forward_batch . positions ,
forward_batch ,
input_embeds = forward_batch . input_embeds . bfloat16 ( ) ,
)
2024-08-26 01:29:12 +08:00
else :
# Only embedding models have get_embedding parameter
return self . model . forward (
2024-09-30 02:41:11 -07:00
forward_batch . input_ids ,
forward_batch . positions ,
forward_batch ,
2024-08-26 01:29:12 +08:00
get_embedding = True ,
)
2024-01-08 04:37:50 +00:00
2024-11-16 17:01:43 +08:00
def forward_idle ( self , forward_batch : ForwardBatch ) :
return self . model . forward (
forward_batch . input_ids , forward_batch . positions , forward_batch
)
2025-03-06 01:51:12 -08:00
def forward (
self , forward_batch : ForwardBatch , skip_attn_backend_init : bool = False
) - > LogitsProcessorOutput :
2025-01-02 02:09:08 -08:00
if (
forward_batch . forward_mode . is_cuda_graph ( )
and self . cuda_graph_runner
and self . cuda_graph_runner . can_run ( forward_batch )
) :
2025-03-06 01:51:12 -08:00
return self . cuda_graph_runner . replay (
forward_batch , skip_attn_backend_init = skip_attn_backend_init
)
2025-01-02 02:09:08 -08:00
2024-09-30 02:41:11 -07:00
if forward_batch . forward_mode . is_decode ( ) :
return self . forward_decode ( forward_batch )
elif forward_batch . forward_mode . is_extend ( ) :
2025-03-06 01:51:12 -08:00
return self . forward_extend (
forward_batch , skip_attn_backend_init = skip_attn_backend_init
)
2024-11-16 17:01:43 +08:00
elif forward_batch . forward_mode . is_idle ( ) :
return self . forward_idle ( forward_batch )
2024-01-08 04:37:50 +00:00
else :
2025-01-10 13:14:51 -08:00
raise ValueError ( f " Invalid forward mode: { forward_batch . forward_mode } " )
2024-05-21 09:13:37 -07:00
2025-03-03 00:12:04 -08:00
def _preprocess_logits (
self , logits_output : LogitsProcessorOutput , sampling_info : SamplingBatchInfo
) :
2024-12-30 04:51:38 -08:00
# Apply logit bias
2024-11-19 15:04:43 -08:00
if sampling_info . sampling_info_done :
# Overlap mode: the function update_regex_vocab_mask was executed
# in process_batch_result of the last batch.
if sampling_info . grammars :
sampling_info . sampling_info_done . wait ( )
else :
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
sampling_info . update_regex_vocab_mask ( )
2024-12-30 04:51:38 -08:00
sampling_info . apply_logits_bias ( logits_output . next_token_logits )
2025-03-03 00:12:04 -08:00
def sample (
self ,
logits_output : LogitsProcessorOutput ,
forward_batch : ForwardBatch ,
) - > torch . Tensor :
""" Sample and compute logprobs and update logits_output.
Args :
logits_output : The logits output from the model forward
forward_batch : The forward batch that generates logits_output
Returns :
A list of next_token_ids
"""
# For duplex models with multiple output streams.
if isinstance ( logits_output , tuple ) :
return torch . stack (
[ self . sample ( values , forward_batch ) for values in logits_output ] ,
axis = - 1 ,
)
self . _preprocess_logits ( logits_output , forward_batch . sampling_info )
2024-12-30 04:51:38 -08:00
# Sample the next tokens
next_token_ids = self . sampler (
logits_output ,
2025-03-03 00:12:04 -08:00
forward_batch . sampling_info ,
2024-12-30 04:51:38 -08:00
forward_batch . return_logprob ,
forward_batch . top_logprobs_nums ,
2025-03-03 00:12:04 -08:00
forward_batch . token_ids_logprobs ,
2024-12-30 04:51:38 -08:00
)
2024-09-30 06:41:49 -07:00
return next_token_ids
2024-10-19 21:44:38 -07:00
@property
def model_is_mrope ( self ) - > bool :
""" Detect if the model has " mrope " rope_scaling type.
mrope requires keep " rope_deltas " between prompt and decoding phases . """
rope_scaling = getattr ( self . model_config . hf_config , " rope_scaling " , { } )
if rope_scaling is None :
return False
2025-04-11 16:29:45 +08:00
is_mrope_enabled = " mrope_section " in rope_scaling
return is_mrope_enabled
2025-03-01 01:53:10 +08:00
2025-03-14 15:40:44 +08:00
    def save_remote_model(self, url: str):
        """Persist the current model weights to the remote store at `url`."""
        from sglang.srt.model_loader.loader import RemoteModelLoader

        logger.info(f"Saving model to {url}")
        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)
    def save_sharded_model(
        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
    ):
        """Save the model as sharded state files under `path`.

        `pattern` and `max_size` are forwarded unchanged to
        ShardedStateLoader.save_model.
        """
        from sglang.srt.model_loader.loader import ShardedStateLoader

        logger.info(
            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
        )
        ShardedStateLoader.save_model(self.model, path, pattern, max_size)
2025-03-01 01:53:10 +08:00
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
    """Write each named tensor straight into the matching model parameter
    using the default weight loader (no model-specific remapping)."""
    params = dict(model.named_parameters())
    for name, tensor in named_tensors:
        default_weight_loader(params[name], tensor)
def _unwrap_tensor(tensor, tp_rank):
    """Materialize `tensor` on the current CUDA device, first deserializing
    LocalSerializedTensor handles for this TP rank."""
    if isinstance(tensor, LocalSerializedTensor):
        # Serialized handles need the torch reduction monkey-patch applied
        # before deserialization.
        monkey_patch_torch_reductions()
        local = tensor.get(tp_rank)
    else:
        local = tensor
    return local.to(torch.cuda.current_device())
2025-03-01 01:53:10 +08:00
@dataclass
class LocalSerializedTensor:
    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
    The i-th element in the list corresponds to i-th rank's GPU."""

    # One serialized payload per TP rank.
    values: List[bytes]

    def get(self, rank: int):
        # Deserialize the payload belonging to `rank` back into a tensor.
        return MultiprocessingSerializer.deserialize(self.values[rank])