"""ModelRunner runs the forward passes of the models."""

import importlib
import importlib.resources
import logging
import pkgutil
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Type

import numpy as np
import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import initialize_model_parallel, init_distributed_environment
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_multimodal_model,
    monkey_patch_vllm_p2p_access_check,
)

logger = logging.getLogger("srt.model_runner")


# for server args in model endpoints
global_server_args_dict = {}


@dataclass
class InputMetadata:
    model_runner: "ModelRunner"
    forward_mode: ForwardMode
    batch_size: int
    total_num_tokens: int
    max_seq_len: int
    req_pool_indices: torch.Tensor
    start_loc: torch.Tensor
    seq_lens: torch.Tensor
    prefix_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # for extend
    extend_seq_lens: torch.Tensor = None
    extend_start_loc: torch.Tensor = None
    max_extend_len: int = 0

    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None
    other_kv_index: torch.Tensor = None

    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # for flashinfer
    qo_indptr: torch.Tensor = None
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None
    kv_last_page_len: torch.Tensor = None
    prefill_wrapper = None
    decode_wrapper = None
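
    # The KV cache is exposed to flashinfer as paged memory with page size 1,
    # so every token occupies its own "page": kv_indptr[i]..kv_indptr[i + 1]
    # delimits request i's entries in kv_indices, and kv_last_page_len is
    # therefore all ones.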
    def init_flashinfer_args(self, tp_size):
        from flashinfer import (
            BatchDecodeWithPagedKVCacheWrapper,
            BatchPrefillWithPagedKVCacheWrapper,
        )

        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )
        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        self.kv_indices = torch.cat(
            [
                self.req_to_token_pool.req_to_token[
                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
                ]
                for i in range(self.batch_size)
            ],
            dim=0,
        ).contiguous()

        workspace_buffer = torch.empty(
            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
        )
        if (
            self.forward_mode == ForwardMode.PREFILL
            or self.forward_mode == ForwardMode.EXTEND
        ):
            self.qo_indptr = torch.zeros(
                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
            )
            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                workspace_buffer, "NHD"
            )
            args = [
                self.qo_indptr,
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
                self.model_runner.model_config.head_dim,
            ]
            self.prefill_wrapper.begin_forward(*args)
        else:
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                workspace_buffer, "NHD"
            )
            self.decode_wrapper.begin_forward(
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
                self.model_runner.model_config.head_dim,
                1,
                "NONE",
                "float16",
            )
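
    # EXTEND mode: a request already has prefix_lens[i] tokens cached, so only
    # the remaining seq_lens[i] - prefix_lens[i] tokens go through this pass.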
    def init_extend_args(self):
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))
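
    # Builds the per-batch metadata for one forward pass: cumulative start
    # locations, position ids, and (for decode) other_kv_index, one valid
    # KV-cache location that downstream kernels can use as a filler index.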
    @classmethod
    def create(
        cls,
        model_runner,
        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        top_logprobs_nums=None,
        return_logprob=False,
    ):
        batch_size = len(req_pool_indices)
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
        total_num_tokens = int(torch.sum(seq_lens))
        max_seq_len = int(torch.max(seq_lens))

        if forward_mode == ForwardMode.DECODE:
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            other_kv_index = model_runner.req_to_token_pool.req_to_token[
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
            seq_lens_cpu = seq_lens.cpu().numpy()
            prefix_lens_cpu = prefix_lens.cpu().numpy()
            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            other_kv_index = None

        ret = cls(
            model_runner=model_runner,
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            max_seq_len=max_seq_len,
            req_pool_indices=req_pool_indices,
            start_loc=start_loc,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            other_kv_index=other_kv_index,
            return_logprob=return_logprob,
            top_logprobs_nums=top_logprobs_nums,
        )

        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        if global_server_args_dict.get("enable_flashinfer", False):
            ret.init_flashinfer_args(tp_size)

        return ret
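

# ModelRunner owns one GPU: it initializes the tensor-parallel process group,
# loads the weights through vLLM's model loader, sizes the KV-cache pools, and
# runs prefill / extend / decode forward passes over a Batch.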
class ModelRunner:
    def __init__(
        self,
        model_config,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args

        global global_server_args_dict
        global_server_args_dict = {
            "enable_flashinfer": server_args.enable_flashinfer,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
        }

        # Init torch distributed
        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
        torch.cuda.set_device(self.gpu_id)
        logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
        monkey_patch_vllm_p2p_access_check()
        init_distributed_environment(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        total_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
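
        # Sanity check: if this GPU reports much less free memory than the value
        # aggregated across the tensor-parallel ranks, another process is likely
        # holding memory on one of the devices, which would skew the KV-cache
        # sizing below.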
        if self.tp_size > 1:
            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if total_local_gpu_memory < total_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        self.load_model()
        self.init_memory_pool(total_gpu_memory)
        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
        logger.info(
            f"[gpu_id={self.gpu_id}] Load weight begin. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        device_config = DeviceConfig()
        load_config = LoadConfig(load_format=self.server_args.load_format)
        vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=torch.float16,
            seed=42,
            skip_tokenizer_init=True,
        )
        if self.model_config.model_overide_args is not None:
            vllm_model_config.hf_config.update(self.model_config.model_overide_args)

        self.model = get_model(
            model_config=vllm_model_config,
            device_config=device_config,
            load_config=load_config,
            lora_config=None,
            vision_language_config=None,
            parallel_config=None,
            scheduler_config=None,
            cache_config=None,
        )
        logger.info(
            f"[gpu_id={self.gpu_id}] Load weight end. "
            f"type={type(self.model).__name__}, "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )
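
    # KV-cache accounting: one token in one layer stores a key and a value
    # vector (the first * 2) in fp16 (2 bytes per element, the second * 2),
    # each of size head_num * head_dim for this rank's share of the KV heads.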
    def profile_max_num_token(self, total_gpu_memory):
        available_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        head_dim = self.model_config.head_dim
        head_num = self.model_config.num_key_value_heads // self.tp_size
        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token
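
    # Two pools back the cache: req_to_token_pool maps (request slot, position)
    # -> token index, and token_to_kv_pool holds the actual K/V tensors for
    # every layer. The request-pool size below is a heuristic estimate of how
    # many concurrent requests the token budget can support.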
    def init_memory_pool(self, total_gpu_memory):
        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        self.req_to_token_pool = ReqToTokenPool(
            int(self.max_total_num_tokens / self.model_config.context_len * 256),
            self.model_config.context_len + 8,
        )
        self.token_to_kv_pool = TokenToKVPool(
            self.max_total_num_tokens,
            dtype=torch.float16,
            head_num=self.model_config.num_key_value_heads // self.tp_size,
            head_dim=self.model_config.head_dim,
            layer_num=self.model_config.num_hidden_layers,
        )
        logger.info(
            f"[gpu_id={self.gpu_id}] Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )
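
    # The forward_* variants below differ only in the ForwardMode they pass to
    # InputMetadata.create (and, for decode, the contiguous out-cache range);
    # the model is always invoked as model.forward(input_ids, positions, metadata).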
    @torch.inference_mode()
    def forward_prefill(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.PREFILL,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_decode(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.DECODE,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            out_cache_cont_start=batch.out_cache_cont_start,
            out_cache_cont_end=batch.out_cache_cont_end,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend_multi_modal(self, batch: Batch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids,
            input_metadata.positions,
            input_metadata,
            batch.pixel_values,
            batch.image_sizes,
            batch.image_offsets,
        )

    def forward(self, batch: Batch, forward_mode: ForwardMode):
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:
            return self.forward_decode(batch)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(batch)
        elif forward_mode == ForwardMode.PREFILL:
            return self.forward_prefill(batch)
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode}")
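

# Model classes are discovered by scanning sglang.srt.models: any module that
# defines an EntryClass (a single class or a list of classes) is registered
# under its class name, and vLLM's ModelRegistry is patched to look the class
# up here instead of in vLLM's own model zoo.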
@lru_cache()
def import_model_classes():
    model_arch_name_to_cls = {}
    package_name = "sglang.srt.models"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            module = importlib.import_module(name)
            if hasattr(module, "EntryClass"):
                entry = module.EntryClass
                if isinstance(entry, list):
                    # Support multiple model classes in one module
                    for tmp in entry:
                        model_arch_name_to_cls[tmp.__name__] = tmp
                else:
                    model_arch_name_to_cls[entry.__name__] = entry
    return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    model_arch_name_to_cls = import_model_classes()
    if model_arch not in model_arch_name_to_cls:
        raise ValueError(
            f"Unsupported architectures: {model_arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
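

# A minimal usage sketch (hypothetical values; the actual construction and
# forward calls happen in the tensor-parallel worker / controller code):
#
#   runner = ModelRunner(
#       model_config,
#       mem_fraction_static=0.9,
#       gpu_id=0,
#       tp_rank=0,
#       tp_size=1,
#       nccl_port=28888,
#       server_args=server_args,
#   )
#   output = runner.forward(batch, ForwardMode.EXTEND)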