sglang/python/sglang/srt/model_executor/model_runner.py
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""ModelRunner runs the forward passes of the models."""
import gc
import importlib
import importlib.resources
import logging
import pkgutil
from functools import lru_cache
from typing import Optional, Tuple, Type
import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
get_tp_group,
init_distributed_environment,
initialize_model_parallel,
set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_available_gpu_memory,
is_generation_model,
is_multimodal_model,
monkey_patch_vllm_dummy_weight_loader,
monkey_patch_vllm_p2p_access_check,
)
logger = logging.getLogger(__name__)
class ModelRunner:
"""ModelRunner runs the forward passes of the models."""
def __init__(
self,
model_config: ModelConfig,
mem_fraction_static: float,
gpu_id: int,
tp_rank: int,
tp_size: int,
nccl_port: int,
server_args: ServerArgs,
):
# Parse args
self.model_config = model_config
self.mem_fraction_static = mem_fraction_static
self.gpu_id = gpu_id
self.tp_rank = tp_rank
self.tp_size = tp_size
self.nccl_port = nccl_port
self.server_args = server_args
self.is_multimodal_model = is_multimodal_model(
self.model_config.hf_config.architectures
)
# Model-specific adjustment
if (
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
):
logger.info("MLA optimization is tunred on. Use triton backend.")
self.server_args.attention_backend = "triton"
if self.is_multimodal_model:
logger.info(
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
)
server_args.chunked_prefill_size = None
server_args.mem_fraction_static *= 0.95
global_server_args_dict.update(
{
"attention_backend": server_args.attention_backend,
"sampling_backend": server_args.sampling_backend,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"disable_mla": server_args.disable_mla,
"torchao_config": server_args.torchao_config,
}
)
# Init components
min_per_gpu_memory = self.init_torch_distributed()
self.sampler = Sampler()
self.load_model()
if server_args.lora_paths is not None:
self.init_lora_manager()
self.init_memory_pool(
min_per_gpu_memory,
server_args.max_running_requests,
server_args.max_total_tokens,
)
self.init_cublas()
self.init_attention_backend()
self.init_cuda_graphs()
def init_torch_distributed(self):
# Init torch distributed
torch.cuda.set_device(self.gpu_id)
logger.info("Init nccl begin.")
if not self.server_args.enable_p2p_check:
monkey_patch_vllm_p2p_access_check(self.gpu_id)
if self.server_args.dist_init_addr:
nccl_init_method = f"tcp://{self.server_args.dist_init_addr}"
else:
nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
init_distributed_environment(
backend="nccl",
world_size=self.tp_size,
rank=self.tp_rank,
local_rank=self.gpu_id,
distributed_init_method=nccl_init_method,
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
min_per_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
self.tp_group = get_tp_group()
# Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
# so we disable padding in cuda graph.
if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
self.server_args.disable_cuda_graph_padding = True
logger.info(
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
)
# Check memory for tensor parallelism
if self.tp_size > 1:
local_gpu_memory = get_available_gpu_memory(self.gpu_id)
if min_per_gpu_memory < local_gpu_memory * 0.9:
raise ValueError(
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
)
return min_per_gpu_memory
def load_model(self):
logger.info(
f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
# This can reduce thread conflicts and speed up weight loading.
torch.set_num_threads(1)
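# Devices below sm80 lack native bfloat16 support, so fall back to float16; anything below sm75 is rejected.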
if torch.cuda.get_device_capability()[0] < 8:
logger.info(
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
)
self.server_args.dtype = "float16"
if torch.cuda.get_device_capability()[1] < 5:
raise RuntimeError("SGLang only supports sm75 and above.")
# Prepare the vllm model config
monkey_patch_vllm_dummy_weight_loader()
self.device_config = DeviceConfig()
self.load_config = LoadConfig(load_format=self.server_args.load_format)
self.vllm_model_config = VllmModelConfig(
model=self.server_args.model_path,
quantization=self.server_args.quantization,
tokenizer=None,
tokenizer_mode=None,
trust_remote_code=self.server_args.trust_remote_code,
dtype=self.server_args.dtype,
seed=self.server_args.random_seed,
skip_tokenizer_init=True,
)
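# Apply any user-provided overrides on top of the Hugging Face config loaded by vLLM.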
if self.model_config.model_override_args is not None:
self.vllm_model_config.hf_config.update(
self.model_config.model_override_args
)
self.dtype = self.vllm_model_config.dtype
# Load the model
self.model = get_model(
model_config=self.vllm_model_config,
load_config=self.load_config,
device_config=self.device_config,
parallel_config=None,
scheduler_config=None,
lora_config=None,
cache_config=None,
)
self.sliding_window_size = (
self.model.get_attention_sliding_window_size()
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures, self.server_args.is_embedding
)
logger.info(
f"Load weight end. "
f"type={type(self.model).__name__}, "
f"dtype={self.dtype}, "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
def update_weights(self, model_path: str, load_format: str):
"""Update weights in-place."""
from vllm.model_executor.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
get_model_loader,
)
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
logger.info(
f"Update weights begin. "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
target_device = torch.device(self.device_config.device)
try:
# TODO: Use a better method to check this
vllm_model_config = VllmModelConfig(
model=model_path,
quantization=self.server_args.quantization,
tokenizer=None,
tokenizer_mode=None,
trust_remote_code=self.server_args.trust_remote_code,
dtype=self.server_args.dtype,
seed=self.server_args.random_seed,
skip_tokenizer_init=True,
)
except Exception as e:
message = f"Failed to load model config: {e}."
return False, message
load_config = LoadConfig(load_format=load_format)
# Only vLLM's DefaultModelLoader is supported for now
loader = get_model_loader(load_config)
if not isinstance(loader, DefaultModelLoader):
message = f"Failed to get model loader: {loader}."
return False, message
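# Helper closures: stream weights from the checkpoint and load them into the existing model
# in place, re-running each module's post-load quantization hook.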
def get_weight_iter(config):
iter = loader._get_weights_iterator(
config.model,
config.revision,
fall_back_to_pt=getattr(
self.model, "fall_back_to_pt_during_load", True
),
)
return iter
def model_load_weights(model, iter):
model.load_weights(iter)
for _, module in self.model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
return model
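# Load the new weights under the target dtype; if anything fails, reload the original
# checkpoint so the server keeps serving the previous weights.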
with set_default_torch_dtype(vllm_model_config.dtype):
try:
iter = get_weight_iter(vllm_model_config)
except Exception as e:
message = f"Failed to get weights iterator: {e}."
return False, message
try:
model = model_load_weights(self.model, iter)
except Exception as e:
message = (
f"Failed to update weights: {e}.\nRolling back to original weights."
)
del iter
gc.collect()
iter = get_weight_iter(self.vllm_model_config)
self.model = model_load_weights(self.model, iter)
return False, message
self.model = model
self.server_args.model_path = model_path
self.server_args.load_format = load_format
self.vllm_model_config = vllm_model_config
self.load_config = load_config
self.model_config.path = model_path
logger.info("Update weights end.")
return True, "Succeeded to update model weights."
def init_lora_manager(self):
self.lora_manager = LoRAManager(
base_model=self.model,
lora_paths=self.server_args.lora_paths,
base_hf_config=self.model_config.hf_config,
max_loras_per_batch=self.server_args.max_loras_per_batch,
load_config=self.load_config,
dtype=self.dtype,
)
logger.info("LoRA manager ready.")
def profile_max_num_token(self, total_gpu_memory: int):
available_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
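# Per-token KV cache cell size in bytes: MLA stores one compressed KV vector of size
# (kv_lora_rank + qk_rope_head_dim) per layer, while MHA stores separate K and V tensors
# (the factor of 2) of size num_kv_heads * head_dim per layer.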
if (
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
):
cell_size = (
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
* self.model_config.num_hidden_layers
* torch._utils._element_size(self.kv_cache_dtype)
)
else:
cell_size = (
self.model_config.get_num_kv_heads(self.tp_size)
* self.model_config.head_dim
* self.model_config.num_hidden_layers
* 2
* torch._utils._element_size(self.kv_cache_dtype)
)
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static
)
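# rest_memory is measured in GiB; (1 << 30) converts it to bytes before dividing by the per-token cell size.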
max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token
def init_memory_pool(
self,
total_gpu_memory: int,
max_num_reqs: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
if self.server_args.kv_cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
self.kv_cache_dtype = torch.float8_e5m2
else:
raise ValueError(
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
)
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if max_total_tokens is not None:
if max_total_tokens > self.max_total_num_tokens:
logger.warning(
f"max_total_tokens={max_total_tokens} is larger than the profiled value "
f"{self.max_total_num_tokens}. "
f"Use the profiled value instead."
)
self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
if self.max_total_num_tokens <= 0:
raise RuntimeError(
"Not enough memory. Please try to increase --mem-fraction-static."
)
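# Default the request pool size with a heuristic: scale with the token budget relative to
# the context length, then clamp to the range [2048, 4096].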
if max_num_reqs is None:
max_num_reqs = min(
max(
int(
self.max_total_num_tokens / self.model_config.context_len * 512
),
2048,
),
4096,
)
self.req_to_token_pool = ReqToTokenPool(
max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
)
if (
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
):
self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.model_config.num_hidden_layers,
)
else:
self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(self.tp_size),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
)
logger.info(
f"Memory pool end. "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
def init_cublas(self):
"""We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
dtype = torch.float16
device = "cuda"
a = torch.ones((16, 16), dtype=dtype, device=device)
b = torch.ones((16, 16), dtype=dtype, device=device)
c = a @ b
return c
def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.attention_backend == "flashinfer":
self.attn_backend = FlashInferAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
self.attn_backend = TritonAttnBackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
)
def init_cuda_graphs(self):
"""Capture cuda graphs."""
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
self.cuda_graph_runner = None
if not self.is_generation:
# TODO: Currently, cuda graph only captures decode steps, which only exist for generation models
return
if self.server_args.disable_cuda_graph:
return
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self)
def forward_decode(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
forward_batch.batch_size
):
return self.cuda_graph_runner.replay(forward_batch)
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
def forward_extend(self, forward_batch: ForwardBatch):
if self.is_generation:
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
else:
# Only embedding models have the get_embedding parameter
return self.model.forward(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
get_embedding=True,
)
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
# Attach attention information
forward_batch.req_to_token_pool = self.req_to_token_pool
forward_batch.token_to_kv_pool = self.token_to_kv_pool
forward_batch.attn_backend = self.attn_backend
forward_batch.attn_backend.init_forward_metadata(forward_batch)
# Attach lora information
if self.server_args.lora_paths is not None:
self.lora_manager.prepare_lora_batch(forward_batch)
if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
elif forward_batch.forward_mode.is_extend():
return self.forward_extend(forward_batch)
else:
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
def _apply_logits_bias(
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
):
# Apply logit_bias
if sampling_info.logit_bias is not None:
logits.add_(sampling_info.logit_bias)
# Apply the min-token, presence, and frequency penalties (precomputed as additive terms)
if sampling_info.linear_penalties is not None:
logits += sampling_info.linear_penalties
# Apply the repetition penalty: divide positive logits and multiply negative logits by the penalty
if sampling_info.scaling_penalties is not None:
logits = torch.where(
logits > 0,
logits / sampling_info.scaling_penalties,
logits * sampling_info.scaling_penalties,
)
# Apply regex vocab_mask
if sampling_info.vocab_mask is not None:
logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
return logits
def sample(
self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
) -> torch.Tensor:
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
batch.sampling_info.update_regex_vocab_mask(batch)
batch.sampling_info.update_penalties()
logits = self._apply_logits_bias(
logits_output.next_token_logits, batch.sampling_info
)
# Sample the next tokens.
next_token_ids = self.sampler(logits, batch.sampling_info)
return next_token_ids
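# Scan the sglang.srt.models package and map each module's EntryClass (one class or a list
# of classes) from model architecture name to its SGLang implementation.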
@lru_cache()
def import_model_classes():
model_arch_name_to_cls = {}
package_name = "sglang.srt.models"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
try:
module = importlib.import_module(name)
except Exception as e:
logger.warning(f"Ignore import error when loading {name}. " f"{e}")
continue
if hasattr(module, "EntryClass"):
entry = module.EntryClass
if isinstance(
entry, list
): # To support multiple model classes in one module
for tmp in entry:
assert (
tmp.__name__ not in model_arch_name_to_cls
), f"Duplicated model implementation for {tmp.__name__}"
model_arch_name_to_cls[tmp.__name__] = tmp
else:
assert (
entry.__name__ not in model_arch_name_to_cls
), f"Duplicated model implementation for {entry.__name__}"
model_arch_name_to_cls[entry.__name__] = entry
return model_arch_name_to_cls
def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
model_arch_name_to_cls = import_model_classes()
if model_arch not in model_arch_name_to_cls:
raise ValueError(
f"Unsupported architectures: {model_arch}. "
f"Supported list: {list(model_arch_name_to_cls.keys())}"
)
return model_arch_name_to_cls[model_arch]
# Monkey patch the vLLM model loader so architectures resolve to SGLang's model implementations
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)