2024-07-28 23:07:12 +10:00
|
|
|
"""
|
|
|
|
|
Copyright 2023-2024 SGLang Team
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License.
|
|
|
|
|
"""
|
|
|
|
|
|
2024-06-08 02:06:52 -07:00
|
|
|
"""ModelRunner runs the forward passes of the models."""
|
2024-06-12 21:48:40 -07:00
|
|
|
|
2024-08-20 13:48:24 -07:00
|
|
|
import gc
|
2024-01-25 15:29:07 -08:00
|
|
|
import importlib
|
2024-03-24 15:41:24 +08:00
|
|
|
import importlib.resources
|
2024-10-14 02:00:41 -07:00
|
|
|
import json
|
2024-03-24 15:41:24 +08:00
|
|
|
import logging
|
|
|
|
|
import pkgutil
|
2024-01-25 15:29:07 -08:00
|
|
|
from functools import lru_cache
|
2024-09-30 06:41:49 -07:00
|
|
|
from typing import Optional, Type
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
import torch
|
2024-05-21 09:13:37 -07:00
|
|
|
import torch.nn as nn
|
|
|
|
|
from vllm.config import DeviceConfig, LoadConfig
|
|
|
|
|
from vllm.config import ModelConfig as VllmModelConfig
|
2024-07-18 04:55:39 +10:00
|
|
|
from vllm.distributed import (
|
|
|
|
|
get_tp_group,
|
|
|
|
|
init_distributed_environment,
|
|
|
|
|
initialize_model_parallel,
|
2024-08-20 23:44:12 +08:00
|
|
|
set_custom_all_reduce,
|
2024-07-18 04:55:39 +10:00
|
|
|
)
|
2024-08-16 01:39:24 -07:00
|
|
|
from vllm.distributed.parallel_state import in_the_same_node_as
|
2024-08-15 23:29:35 +08:00
|
|
|
from vllm.model_executor.model_loader import get_model
|
2024-05-21 09:13:37 -07:00
|
|
|
from vllm.model_executor.models import ModelRegistry
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-09-10 15:15:08 -07:00
|
|
|
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
2024-09-30 06:41:49 -07:00
|
|
|
from sglang.srt.constrained import disable_cache
|
2024-10-14 02:00:41 -07:00
|
|
|
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
|
2024-09-30 15:54:18 -07:00
|
|
|
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
|
|
|
|
|
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
2024-08-28 18:58:52 -07:00
|
|
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
2024-09-16 21:23:31 -07:00
|
|
|
from sglang.srt.layers.sampler import Sampler
|
2024-09-12 16:46:14 -07:00
|
|
|
from sglang.srt.lora.lora_manager import LoRAManager
|
2024-09-30 06:41:49 -07:00
|
|
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
2024-08-05 01:40:33 +08:00
|
|
|
from sglang.srt.mem_cache.memory_pool import (
|
2024-10-14 02:00:41 -07:00
|
|
|
DoubleSparseTokenToKVPool,
|
2024-08-05 01:40:33 +08:00
|
|
|
MHATokenToKVPool,
|
|
|
|
|
MLATokenToKVPool,
|
|
|
|
|
ReqToTokenPool,
|
|
|
|
|
)
|
2024-09-30 02:41:11 -07:00
|
|
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
2024-09-13 20:27:53 -07:00
|
|
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
2024-05-21 11:46:35 -07:00
|
|
|
from sglang.srt.server_args import ServerArgs
|
2024-06-12 21:48:40 -07:00
|
|
|
from sglang.srt.utils import (
|
2024-09-30 06:41:49 -07:00
|
|
|
enable_show_time_cost,
|
2024-06-12 21:48:40 -07:00
|
|
|
get_available_gpu_memory,
|
2024-08-08 16:31:19 -07:00
|
|
|
is_generation_model,
|
2024-06-12 21:48:40 -07:00
|
|
|
is_multimodal_model,
|
2024-06-25 03:38:04 -07:00
|
|
|
monkey_patch_vllm_dummy_weight_loader,
|
2024-06-12 21:48:40 -07:00
|
|
|
monkey_patch_vllm_p2p_access_check,
|
|
|
|
|
)
|
2024-01-29 17:05:42 -08:00
|
|
|
|
2024-07-28 23:01:45 -07:00
|
|
|
logger = logging.getLogger(__name__)
|
2024-01-20 23:20:35 -08:00
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

    def __init__(
        self,
        model_config: ModelConfig,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        """Set up the runner: parse args, init distributed state, load the
        model, allocate memory pools, and prepare the attention backend.

        Args:
            model_config: Parsed model configuration (architectures, dims).
            mem_fraction_static: Fraction of GPU memory reserved for weights + KV cache.
            gpu_id: Local device index for this process.
            tp_rank: Rank of this process in the tensor-parallel group.
            tp_size: Tensor-parallel world size.
            nccl_port: Port used to build the TCP rendezvous address.
            server_args: Full server configuration; several fields are mutated
                here (attention_backend, disable_cuda_graph, dtype, ...).
        """
        # Parse args
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.device = server_args.device
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(
            self.model_config.hf_config.architectures
        )

        # Model-specific adjustment
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla
        ):
            # MLA models are forced onto the triton attention backend.
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            # Double sparsity also requires triton and is incompatible with
            # CUDA graph capture, so the graph path is disabled.
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal_model:
            # Shrink the static memory fraction, presumably to leave headroom
            # for the image/vision components — TODO confirm the 0.95 factor.
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
            server_args.chunked_prefill_size = None
            server_args.mem_fraction_static *= 0.95

        # Global vars
        if server_args.show_time_cost:
            enable_show_time_cost()
        if server_args.disable_disk_cache:
            disable_cache()

        # Mirror a subset of server args into the process-global dict so other
        # modules (e.g. schedule_batch) can read them without a runner handle.
        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "sampling_backend": server_args.sampling_backend,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "disable_mla": server_args.disable_mla,
                "torchao_config": server_args.torchao_config,
                "disable_penalizer": server_args.disable_penalizer,
                "disable_nan_detection": server_args.disable_nan_detection,
            }
        )

        # Init components (order matters: distributed env before model load,
        # model load before memory-pool sizing).
        min_per_gpu_memory = self.init_torch_distributed()
        self.sampler = Sampler()
        self.load_model()
        if server_args.lora_paths is not None:
            self.init_lora_manager()
        self.init_memory_pool(
            min_per_gpu_memory,
            server_args.max_running_requests,
            server_args.max_total_tokens,
        )
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            # Non-CUDA devices never use CUDA graphs.
            self.cuda_graph_runner = None
            self.init_attention_backend()
|
|
    def init_torch_distributed(self):
        """Initialize the distributed environment and tensor parallelism.

        Returns:
            The minimum available GPU memory (in GB) across all TP ranks,
            used later to size the KV cache pools.
        """
        logger.info("Init torch distributed begin.")
        # Init torch distributed
        if self.device == "cuda":
            torch.cuda.set_device(self.gpu_id)
            backend = "nccl"
        # TODO(liangan1): Just use gloo to bypass the initialization failure.
        # Need to use xccl for xpu backend in the future.
        elif self.device == "xpu":
            torch.xpu.set_device(self.gpu_id)
            backend = "gloo"

        if not self.server_args.enable_p2p_check:
            # Skip vLLM's slow P2P capability probe unless explicitly requested.
            monkey_patch_vllm_p2p_access_check(self.gpu_id)
        if self.server_args.dist_init_addr:
            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
        else:
            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        init_distributed_environment(
            backend=backend,
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        # Measured across ranks (distributed=True when tp_size > 1) so every
        # rank agrees on the tightest memory budget.
        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()

        # Currently, there is a bug with multi-node tensor parallelism + padded
        # cuda graph, so we disable padding in cuda graph.
        if self.device == "cuda" and not all(
            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
        ):
            self.server_args.disable_cuda_graph_padding = True
            logger.info(
                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
            )

        # Check memory for tensor parallelism
        if self.tp_size > 1:
            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            # If this rank has noticeably more free memory than the minimum,
            # some other rank's GPU is likely shared with another process.
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        return min_per_gpu_memory
|
2024-07-13 05:29:46 -07:00
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
    def load_model(self):
        """Load model weights via vLLM's model loader.

        Side effects: sets self.load_config, self.vllm_model_config,
        self.dtype, self.model, self.sliding_window_size,
        self.has_cross_attention, and self.is_generation. May downgrade
        self.server_args.dtype to float16 on pre-sm80 GPUs.
        """
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                # Pre-Ampere GPUs lack bfloat16; force float16.
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                # Within the major-version-<8 branch, a minor version below 5
                # means the card is older than sm75 (e.g. sm70), which is not
                # supported.
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        # Prepare the vllm model config
        monkey_patch_vllm_dummy_weight_loader()
        self.load_config = LoadConfig(load_format=self.server_args.load_format)
        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=self.server_args.dtype,
            seed=self.server_args.random_seed,
            skip_tokenizer_init=True,
        )
        if self.model_config.model_override_args is not None:
            # User-provided overrides take precedence over the HF config.
            self.vllm_model_config.hf_config.update(
                self.model_config.model_override_args
            )
        self.dtype = self.vllm_model_config.dtype

        # Load the model
        self.model = get_model(
            model_config=self.vllm_model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
            parallel_config=None,
            scheduler_config=None,
            lora_config=None,
            cache_config=None,
        )
        # Optional per-model hooks: only some architectures define these.
        self.sliding_window_size = (
            self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
        )
        self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
|
2024-01-20 23:20:35 -08:00
|
|
|
|
2024-08-24 08:02:23 -07:00
|
|
|
    def update_weights(self, model_path: str, load_format: str):
        """Update weights in-place.

        Loads new weights from ``model_path`` into the existing model object.
        On failure during weight loading, attempts to roll back by reloading
        the original weights.

        Args:
            model_path: Path of the new checkpoint to load.
            load_format: vLLM load format string (e.g. "auto", "safetensors").

        Returns:
            Tuple of (success: bool, message: str).
        """
        from vllm.model_executor.model_loader.loader import (
            DefaultModelLoader,
            device_loading_context,
            get_model_loader,
        )
        from vllm.model_executor.model_loader.utils import set_default_torch_dtype

        logger.info(
            f"Update weights begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        target_device = torch.device(self.device)

        try:
            # TODO: Use a better method to check this
            vllm_model_config = VllmModelConfig(
                model=model_path,
                quantization=self.server_args.quantization,
                tokenizer=None,
                tokenizer_mode=None,
                trust_remote_code=self.server_args.trust_remote_code,
                dtype=self.server_args.dtype,
                seed=self.server_args.random_seed,
                skip_tokenizer_init=True,
            )
        except Exception as e:
            message = f"Failed to load model config: {e}."
            return False, message

        load_config = LoadConfig(load_format=load_format)

        # Only support vllm DefaultModelLoader for now
        loader = get_model_loader(load_config)
        if not isinstance(loader, DefaultModelLoader):
            message = f"Failed to get model loader: {loader}."
            return False, message

        def get_weight_iter(config):
            # Uses vLLM's private weight iterator; depends on its internals.
            iter = loader._get_weights_iterator(
                config.model,
                config.revision,
                fall_back_to_pt=getattr(
                    self.model, "fall_back_to_pt_during_load", True
                ),
            )
            return iter

        def model_load_weights(model, iter):
            # Load raw weights, then let each quantization method finalize
            # (e.g. repack) its module weights on the target device.
            model.load_weights(iter)
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
            return model

        with set_default_torch_dtype(vllm_model_config.dtype):
            try:
                iter = get_weight_iter(vllm_model_config)
            except Exception as e:
                message = f"Failed to get weights iterator: {e}."
                return False, message
            try:
                model = model_load_weights(self.model, iter)
            except Exception as e:
                # Roll back: drop the partially-consumed iterator and reload
                # the original weights into the (now dirty) model object.
                message = (
                    f"Failed to update weights: {e}.\nRolling back to original weights."
                )
                del iter
                gc.collect()
                iter = get_weight_iter(self.vllm_model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        # Commit the new configuration only after a fully successful load.
        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."
|
2024-08-20 13:48:24 -07:00
|
|
|
|
2024-09-12 16:46:14 -07:00
|
|
|
def init_lora_manager(self):
|
|
|
|
|
self.lora_manager = LoRAManager(
|
|
|
|
|
base_model=self.model,
|
|
|
|
|
lora_paths=self.server_args.lora_paths,
|
|
|
|
|
base_hf_config=self.model_config.hf_config,
|
|
|
|
|
max_loras_per_batch=self.server_args.max_loras_per_batch,
|
|
|
|
|
load_config=self.load_config,
|
|
|
|
|
dtype=self.dtype,
|
|
|
|
|
)
|
|
|
|
|
logger.info("LoRA manager ready.")
|
|
|
|
|
|
2024-08-24 08:02:23 -07:00
|
|
|
def profile_max_num_token(self, total_gpu_memory: int):
|
2024-05-27 21:24:10 -07:00
|
|
|
available_gpu_memory = get_available_gpu_memory(
|
2024-10-11 17:05:58 +08:00
|
|
|
self.device, self.gpu_id, distributed=self.tp_size > 1
|
2024-05-27 21:24:10 -07:00
|
|
|
)
|
2024-08-05 01:40:33 +08:00
|
|
|
if (
|
|
|
|
|
self.model_config.attention_arch == AttentionArch.MLA
|
2024-09-17 19:42:48 +08:00
|
|
|
and not self.server_args.disable_mla
|
2024-08-05 01:40:33 +08:00
|
|
|
):
|
|
|
|
|
cell_size = (
|
|
|
|
|
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
|
|
|
|
|
* self.model_config.num_hidden_layers
|
2024-08-26 08:38:11 +08:00
|
|
|
* torch._utils._element_size(self.kv_cache_dtype)
|
2024-08-05 01:40:33 +08:00
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
cell_size = (
|
|
|
|
|
self.model_config.get_num_kv_heads(self.tp_size)
|
|
|
|
|
* self.model_config.head_dim
|
|
|
|
|
* self.model_config.num_hidden_layers
|
|
|
|
|
* 2
|
2024-08-26 08:38:11 +08:00
|
|
|
* torch._utils._element_size(self.kv_cache_dtype)
|
2024-08-05 01:40:33 +08:00
|
|
|
)
|
2024-01-08 04:37:50 +00:00
|
|
|
rest_memory = available_gpu_memory - total_gpu_memory * (
|
|
|
|
|
1 - self.mem_fraction_static
|
|
|
|
|
)
|
2024-05-24 03:48:53 -07:00
|
|
|
max_num_token = int(rest_memory * (1 << 30) // cell_size)
|
2024-01-08 04:37:50 +00:00
|
|
|
return max_num_token
|
|
|
|
|
|
2024-07-30 13:33:55 -07:00
|
|
|
def init_memory_pool(
|
2024-08-24 08:02:23 -07:00
|
|
|
self,
|
|
|
|
|
total_gpu_memory: int,
|
2024-09-10 17:11:16 -07:00
|
|
|
max_num_reqs: Optional[int] = None,
|
|
|
|
|
max_total_tokens: Optional[int] = None,
|
2024-07-30 13:33:55 -07:00
|
|
|
):
|
2024-08-26 08:38:11 +08:00
|
|
|
if self.server_args.kv_cache_dtype == "auto":
|
|
|
|
|
self.kv_cache_dtype = self.dtype
|
|
|
|
|
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
|
2024-09-01 17:46:40 +08:00
|
|
|
self.kv_cache_dtype = torch.float8_e5m2
|
2024-08-26 08:38:11 +08:00
|
|
|
else:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
|
|
|
|
|
)
|
|
|
|
|
|
2024-05-26 12:51:45 -07:00
|
|
|
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
|
2024-07-30 13:33:55 -07:00
|
|
|
if max_total_tokens is not None:
|
|
|
|
|
if max_total_tokens > self.max_total_num_tokens:
|
2024-08-20 08:31:29 -07:00
|
|
|
logging.warning(
|
2024-07-30 13:33:55 -07:00
|
|
|
f"max_total_tokens={max_total_tokens} is larger than the profiled value "
|
|
|
|
|
f"{self.max_total_num_tokens}. "
|
|
|
|
|
f"Use the profiled value instead."
|
|
|
|
|
)
|
|
|
|
|
self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
|
2024-01-19 17:03:33 -08:00
|
|
|
|
2024-05-26 12:51:45 -07:00
|
|
|
if self.max_total_num_tokens <= 0:
|
2024-01-21 01:39:23 -08:00
|
|
|
raise RuntimeError(
|
2024-06-17 20:41:24 -07:00
|
|
|
"Not enough memory. Please try to increase --mem-fraction-static."
|
2024-01-21 01:39:23 -08:00
|
|
|
)
|
2024-01-19 17:03:33 -08:00
|
|
|
|
2024-07-26 17:10:07 -07:00
|
|
|
if max_num_reqs is None:
|
2024-07-30 01:58:31 -07:00
|
|
|
max_num_reqs = min(
|
|
|
|
|
max(
|
|
|
|
|
int(
|
|
|
|
|
self.max_total_num_tokens / self.model_config.context_len * 512
|
|
|
|
|
),
|
|
|
|
|
2048,
|
|
|
|
|
),
|
2024-09-10 17:11:16 -07:00
|
|
|
4096,
|
2024-07-26 17:10:07 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.req_to_token_pool = ReqToTokenPool(
|
2024-10-03 18:29:49 -07:00
|
|
|
size=max_num_reqs + 1,
|
|
|
|
|
max_context_len=self.model_config.context_len + 4,
|
2024-10-11 17:05:58 +08:00
|
|
|
device=self.device,
|
2024-01-08 04:37:50 +00:00
|
|
|
)
|
2024-08-05 01:40:33 +08:00
|
|
|
if (
|
|
|
|
|
self.model_config.attention_arch == AttentionArch.MLA
|
2024-09-17 19:42:48 +08:00
|
|
|
and not self.server_args.disable_mla
|
2024-08-05 01:40:33 +08:00
|
|
|
):
|
|
|
|
|
self.token_to_kv_pool = MLATokenToKVPool(
|
|
|
|
|
self.max_total_num_tokens,
|
2024-08-26 08:38:11 +08:00
|
|
|
dtype=self.kv_cache_dtype,
|
2024-08-05 01:40:33 +08:00
|
|
|
kv_lora_rank=self.model_config.kv_lora_rank,
|
|
|
|
|
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
|
|
|
|
layer_num=self.model_config.num_hidden_layers,
|
2024-10-11 17:05:58 +08:00
|
|
|
device=self.device,
|
2024-08-05 01:40:33 +08:00
|
|
|
)
|
2024-10-14 02:00:41 -07:00
|
|
|
elif self.server_args.enable_double_sparsity:
|
|
|
|
|
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
|
|
|
|
|
self.max_total_num_tokens,
|
|
|
|
|
dtype=self.kv_cache_dtype,
|
|
|
|
|
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
|
|
|
|
head_dim=self.model_config.head_dim,
|
|
|
|
|
layer_num=self.model_config.num_hidden_layers,
|
|
|
|
|
device=self.device,
|
|
|
|
|
heavy_channel_num=self.server_args.ds_heavy_channel_num,
|
|
|
|
|
)
|
2024-08-05 01:40:33 +08:00
|
|
|
else:
|
|
|
|
|
self.token_to_kv_pool = MHATokenToKVPool(
|
|
|
|
|
self.max_total_num_tokens,
|
2024-08-26 08:38:11 +08:00
|
|
|
dtype=self.kv_cache_dtype,
|
2024-08-05 01:40:33 +08:00
|
|
|
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
|
|
|
|
head_dim=self.model_config.head_dim,
|
|
|
|
|
layer_num=self.model_config.num_hidden_layers,
|
2024-10-11 17:05:58 +08:00
|
|
|
device=self.device,
|
2024-08-05 01:40:33 +08:00
|
|
|
)
|
2024-05-27 21:24:10 -07:00
|
|
|
logger.info(
|
2024-08-25 14:46:34 -07:00
|
|
|
f"Memory pool end. "
|
2024-10-11 17:05:58 +08:00
|
|
|
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
2024-05-27 21:24:10 -07:00
|
|
|
)
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-06-25 12:46:00 -07:00
|
|
|
def init_cublas(self):
|
|
|
|
|
"""We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
|
|
|
|
|
dtype = torch.float16
|
|
|
|
|
device = "cuda"
|
|
|
|
|
a = torch.ones((16, 16), dtype=dtype, device=device)
|
|
|
|
|
b = torch.ones((16, 16), dtype=dtype, device=device)
|
|
|
|
|
c = a @ b
|
|
|
|
|
return c
|
|
|
|
|
|
2024-09-11 11:44:26 -07:00
|
|
|
def init_attention_backend(self):
|
|
|
|
|
"""Init attention kernel backend."""
|
|
|
|
|
if self.server_args.attention_backend == "flashinfer":
|
|
|
|
|
self.attn_backend = FlashInferAttnBackend(self)
|
|
|
|
|
elif self.server_args.attention_backend == "triton":
|
|
|
|
|
assert self.sliding_window_size is None, (
|
|
|
|
|
"Window attention is not supported in the triton attention backend. "
|
|
|
|
|
"Please use `--attention-backend flashinfer`."
|
2024-08-13 17:01:26 -07:00
|
|
|
)
|
2024-10-01 00:28:42 -07:00
|
|
|
assert not self.has_cross_attention, (
|
|
|
|
|
"Cross attention is not supported in the triton attention backend. "
|
|
|
|
|
"Please use `--attention-backend flashinfer`."
|
|
|
|
|
)
|
2024-10-14 02:00:41 -07:00
|
|
|
if self.server_args.enable_double_sparsity:
|
|
|
|
|
self.attn_backend = DoubleSparseAttnBackend(self)
|
|
|
|
|
else:
|
|
|
|
|
self.attn_backend = TritonAttnBackend(self)
|
2024-08-13 17:01:26 -07:00
|
|
|
else:
|
2024-09-11 11:44:26 -07:00
|
|
|
raise ValueError(
|
|
|
|
|
f"Invalid attention backend: {self.server_args.attention_backend}"
|
2024-08-13 17:01:26 -07:00
|
|
|
)
|
2024-06-20 20:29:06 -07:00
|
|
|
|
2024-10-14 02:00:41 -07:00
|
|
|
def init_double_sparsity_channel_config(self, selected_channel):
|
|
|
|
|
|
|
|
|
|
selected_channel = "." + selected_channel + "_proj"
|
|
|
|
|
self.sorted_channels = []
|
|
|
|
|
# load channel config
|
|
|
|
|
with open(self.server_args.ds_channel_config_path, "r") as f:
|
|
|
|
|
channel_config = json.load(f)
|
|
|
|
|
|
|
|
|
|
for i in range(self.model_config.num_hidden_layers):
|
|
|
|
|
key = "model.layers." + str(i) + ".self_attn" + selected_channel
|
|
|
|
|
self.sorted_channels.append(
|
|
|
|
|
torch.tensor(channel_config[key])[
|
|
|
|
|
:, : self.server_args.ds_heavy_channel_num
|
|
|
|
|
]
|
|
|
|
|
.contiguous()
|
|
|
|
|
.cuda()
|
|
|
|
|
)
|
|
|
|
|
|
2024-07-13 05:29:46 -07:00
|
|
|
def init_cuda_graphs(self):
|
2024-08-24 08:02:23 -07:00
|
|
|
"""Capture cuda graphs."""
|
2024-09-11 11:44:26 -07:00
|
|
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
|
|
|
|
|
|
|
|
|
self.cuda_graph_runner = None
|
|
|
|
|
|
2024-08-24 08:02:23 -07:00
|
|
|
if not self.is_generation:
|
|
|
|
|
# TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
|
|
|
|
|
return
|
|
|
|
|
|
2024-09-11 11:44:26 -07:00
|
|
|
if self.server_args.disable_cuda_graph:
|
|
|
|
|
return
|
2024-07-13 05:29:46 -07:00
|
|
|
|
2024-08-25 14:46:34 -07:00
|
|
|
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
2024-09-11 11:44:26 -07:00
|
|
|
self.cuda_graph_runner = CudaGraphRunner(self)
|
2024-07-13 05:29:46 -07:00
|
|
|
|
2024-09-30 02:41:11 -07:00
|
|
|
def forward_decode(self, forward_batch: ForwardBatch):
|
2024-09-29 20:28:45 -07:00
|
|
|
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
|
2024-09-30 02:41:11 -07:00
|
|
|
forward_batch.batch_size
|
2024-09-29 20:28:45 -07:00
|
|
|
):
|
2024-09-30 02:41:11 -07:00
|
|
|
return self.cuda_graph_runner.replay(forward_batch)
|
2024-08-08 01:11:22 -07:00
|
|
|
|
2024-10-17 22:54:14 -07:00
|
|
|
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
|
|
|
|
|
self.attn_backend.init_forward_metadata(forward_batch)
|
2024-03-24 19:48:37 +08:00
|
|
|
return self.model.forward(
|
2024-09-30 02:41:11 -07:00
|
|
|
forward_batch.input_ids, forward_batch.positions, forward_batch
|
2024-01-08 04:37:50 +00:00
|
|
|
)
|
|
|
|
|
|
2024-09-30 02:41:11 -07:00
|
|
|
def forward_extend(self, forward_batch: ForwardBatch):
|
2024-10-17 22:54:14 -07:00
|
|
|
self.attn_backend.init_forward_metadata(forward_batch)
|
2024-08-26 01:29:12 +08:00
|
|
|
if self.is_generation:
|
|
|
|
|
return self.model.forward(
|
2024-09-30 02:41:11 -07:00
|
|
|
forward_batch.input_ids, forward_batch.positions, forward_batch
|
2024-08-26 01:29:12 +08:00
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
# Only embedding models have get_embedding parameter
|
|
|
|
|
return self.model.forward(
|
2024-09-30 02:41:11 -07:00
|
|
|
forward_batch.input_ids,
|
|
|
|
|
forward_batch.positions,
|
|
|
|
|
forward_batch,
|
2024-08-26 01:29:12 +08:00
|
|
|
get_embedding=True,
|
|
|
|
|
)
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-09-30 02:41:11 -07:00
|
|
|
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
|
|
|
|
|
if forward_batch.forward_mode.is_decode():
|
|
|
|
|
return self.forward_decode(forward_batch)
|
|
|
|
|
elif forward_batch.forward_mode.is_extend():
|
|
|
|
|
return self.forward_extend(forward_batch)
|
2024-01-08 04:37:50 +00:00
|
|
|
else:
|
2024-09-30 02:41:11 -07:00
|
|
|
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
|
2024-05-21 09:13:37 -07:00
|
|
|
|
2024-09-30 06:41:49 -07:00
|
|
|
def sample(
|
|
|
|
|
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
|
|
|
|
sampling_info = forward_batch.sampling_info
|
|
|
|
|
sampling_info.update_regex_vocab_mask()
|
|
|
|
|
sampling_info.update_penalties()
|
|
|
|
|
logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
|
|
|
|
|
|
|
|
|
|
# Sample the next tokens.
|
|
|
|
|
next_token_ids = self.sampler(logits, sampling_info)
|
|
|
|
|
return next_token_ids
|
|
|
|
|
|
|
|
|
|
def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
2024-09-13 20:27:53 -07:00
|
|
|
# Apply logit_bias
|
|
|
|
|
if sampling_info.logit_bias is not None:
|
|
|
|
|
logits.add_(sampling_info.logit_bias)
|
|
|
|
|
|
|
|
|
|
# min-token, presence, frequency
|
|
|
|
|
if sampling_info.linear_penalties is not None:
|
2024-09-30 06:41:49 -07:00
|
|
|
logits.add_(sampling_info.linear_penalties)
|
2024-09-13 20:27:53 -07:00
|
|
|
|
|
|
|
|
# repetition
|
|
|
|
|
if sampling_info.scaling_penalties is not None:
|
|
|
|
|
logits = torch.where(
|
|
|
|
|
logits > 0,
|
|
|
|
|
logits / sampling_info.scaling_penalties,
|
|
|
|
|
logits * sampling_info.scaling_penalties,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Apply regex vocab_mask
|
|
|
|
|
if sampling_info.vocab_mask is not None:
|
|
|
|
|
logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
|
|
|
|
|
|
|
|
|
|
return logits
|
|
|
|
|
|
2024-05-21 09:13:37 -07:00
|
|
|
|
|
|
|
|
@lru_cache()
def import_model_classes():
    """Scan sglang.srt.models and map architecture class names to classes.

    Each model module exports EntryClass as either a single class or a list
    of classes. The result is cached for the process lifetime.

    Returns:
        Dict mapping architecture class name -> model class.
    """
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if ispkg:
            continue
        try:
            module = importlib.import_module(name)
        except Exception as e:
            # Optional model deps may be missing; skip the module but keep going.
            logger.warning(f"Ignore import error when loading {name}. " f"{e}")
            continue
        if not hasattr(module, "EntryClass"):
            continue
        entry = module.EntryClass
        # Normalize: a module may export one class or a list of classes.
        # (Previously the two cases were handled by duplicated branches.)
        entry_classes = entry if isinstance(entry, list) else [entry]
        for model_cls in entry_classes:
            assert (
                model_cls.__name__ not in model_arch_name_to_cls
            ), f"Duplicated model implementation for {model_cls.__name__}"
            model_arch_name_to_cls[model_cls.__name__] = model_cls

    return model_arch_name_to_cls
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    """Resolve an architecture name to its SGLang model class.

    Raises:
        ValueError: If the architecture has no registered implementation.
    """
    registry = import_model_classes()

    if model_arch not in registry:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(registry.keys())}"
        )
    return registry[model_arch]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Monkey patch model loader: route vLLM's ModelRegistry class lookup through
# SGLang's own registry so vllm's get_model() instantiates SGLang model classes.
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
|