2024-06-08 02:06:52 -07:00
|
|
|
"""ModelRunner runs the forward passes of the models."""
|
2024-01-25 15:29:07 -08:00
|
|
|
import importlib
|
2024-03-24 15:41:24 +08:00
|
|
|
import importlib.resources
|
|
|
|
|
import logging
|
|
|
|
|
import pkgutil
|
2024-01-08 04:37:50 +00:00
|
|
|
from dataclasses import dataclass
|
2024-01-25 15:29:07 -08:00
|
|
|
from functools import lru_cache
|
2024-05-21 09:13:37 -07:00
|
|
|
from typing import List, Optional, Type
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
2024-05-21 09:13:37 -07:00
|
|
|
import torch.nn as nn
|
|
|
|
|
from vllm.config import DeviceConfig, LoadConfig
|
|
|
|
|
from vllm.config import ModelConfig as VllmModelConfig
|
2024-06-07 12:11:31 -07:00
|
|
|
from vllm.distributed import initialize_model_parallel, init_distributed_environment
|
2024-05-21 09:13:37 -07:00
|
|
|
from vllm.model_executor.model_loader import get_model
|
|
|
|
|
from vllm.model_executor.models import ModelRegistry
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-05-27 21:24:10 -07:00
|
|
|
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
|
2024-04-22 22:38:09 +08:00
|
|
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
2024-05-21 11:46:35 -07:00
|
|
|
from sglang.srt.server_args import ServerArgs
|
2024-06-07 19:22:34 -07:00
|
|
|
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
|
2024-04-22 22:38:09 +08:00
|
|
|
|
2024-01-29 17:05:42 -08:00
|
|
|
|
2024-05-27 21:24:10 -07:00
|
|
|
# Module-level logger for the SRT model runner.
logger = logging.getLogger("srt.model_runner")


# for server args in model endpoints
# Populated by ModelRunner.__init__ with a subset of ServerArgs
# (e.g. "enable_flashinfer"); read via global lookup in InputMetadata.create.
global_server_args_dict = {}
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class InputMetadata:
    """Per-forward-pass metadata for a batch: sequence bookkeeping, KV-cache
    locations, position ids, and (optionally) flashinfer attention wrappers.

    Instances are built via :meth:`create` before each model forward.
    """

    # Owning runner; used to reach model_config and the memory pools.
    model_runner: "ModelRunner"
    # PREFILL / EXTEND / DECODE.
    forward_mode: ForwardMode
    # Number of requests in the batch.
    batch_size: int
    # Sum of seq_lens over the batch.
    total_num_tokens: int
    # Max of seq_lens over the batch.
    max_seq_len: int
    # Indices into req_to_token_pool for each request.
    req_pool_indices: torch.Tensor
    # Exclusive prefix sums of seq_lens (per-request start offsets).
    start_loc: torch.Tensor
    # Total sequence length per request (prefix + new tokens).
    seq_lens: torch.Tensor
    # Length of the cached (already computed) prefix per request.
    prefix_lens: torch.Tensor
    # Position ids for the tokens being processed in this pass.
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # for extend
    # seq_lens - prefix_lens, i.e. number of new tokens per request.
    extend_seq_lens: torch.Tensor = None
    # Exclusive prefix sums of extend_seq_lens.
    extend_start_loc: torch.Tensor = None
    max_extend_len: int = 0

    # Where to write new KV entries in the token_to_kv_pool.
    out_cache_loc: torch.Tensor = None
    # Contiguous output-cache range (decode fast path); optional.
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    # A valid KV-cache slot index used as a filler during decode.
    other_kv_index: torch.Tensor = None
    return_logprob: bool = False
    # Per-request number of top logprobs to return; None when unused.
    top_logprobs_nums: List[int] = None

    # for flashinfer
    qo_indptr: torch.Tensor = None
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None
    kv_last_page_len: torch.Tensor = None
    # NOTE: intentionally unannotated -> plain class attributes, NOT dataclass
    # fields, so they do not appear in __init__.
    prefill_wrapper = None
    decode_wrapper = None

    def init_flashinfer_args(self, tp_size):
        """Build flashinfer paged-KV index tensors and begin_forward on the
        appropriate wrapper (prefill/extend vs. decode) for this batch.

        Imported lazily so flashinfer is only required when enabled.
        """
        from flashinfer import (
            BatchDecodeWithPagedKVCacheWrapper,
            BatchPrefillWithPagedKVCacheWrapper,
        )

        # kv_indptr[i] = start offset of request i's KV entries in kv_indices.
        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
        # Page size is 1 token, so every last page holds exactly one entry.
        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )
        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        # Flatten each request's token->KV mapping into one contiguous tensor.
        self.kv_indices = torch.cat(
            [
                self.req_to_token_pool.req_to_token[
                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
                ]
                for i in range(self.batch_size)
            ],
            dim=0,
        ).contiguous()

        # Scratch space required by the flashinfer wrappers (32 MiB).
        workspace_buffer = torch.empty(
            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
        )
        if (
            self.forward_mode == ForwardMode.PREFILL
            or self.forward_mode == ForwardMode.EXTEND
        ):
            # qo_indptr delimits each request's query tokens (the new ones).
            self.qo_indptr = torch.zeros(
                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
            )
            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                workspace_buffer, "NHD"
            )
            args = [
                self.qo_indptr,
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
                self.model_runner.model_config.head_dim,
            ]

            self.prefill_wrapper.begin_forward(*args)
        else:
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                workspace_buffer, "NHD"
            )
            self.decode_wrapper.begin_forward(
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
                self.model_runner.model_config.head_dim,
                1,  # page size
                "NONE",
                "float16",
            )

    def init_extend_args(self):
        """Derive per-request extend lengths and their start offsets."""
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        # Exclusive prefix sum: request 0 starts at 0.
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))

    @classmethod
    def create(
        cls,
        model_runner,
        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        top_logprobs_nums=None,
        return_logprob=False,
    ):
        """Build an InputMetadata for one forward pass, computing positions
        and (when enabled) initializing extend/flashinfer state.
        """
        batch_size = len(req_pool_indices)
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
        total_num_tokens = int(torch.sum(seq_lens))
        max_seq_len = int(torch.max(seq_lens))

        if forward_mode == ForwardMode.DECODE:
            # Decode processes exactly one new token per request.
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            # Any valid cache slot works as a filler index; use the last
            # token of the first request.
            other_kv_index = model_runner.req_to_token_pool.req_to_token[
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
            seq_lens_cpu = seq_lens.cpu().numpy()
            prefix_lens_cpu = prefix_lens.cpu().numpy()
            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
            # One position id per new (non-prefix) token, per request.
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            other_kv_index = None

        ret = cls(
            model_runner=model_runner,
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            max_seq_len=max_seq_len,
            req_pool_indices=req_pool_indices,
            start_loc=start_loc,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            other_kv_index=other_kv_index,
            return_logprob=return_logprob,
            top_logprobs_nums=top_logprobs_nums,
        )

        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        if global_server_args_dict.get("enable_flashinfer", False):
            ret.init_flashinfer_args(tp_size)

        return ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelRunner:
    """Owns one model replica on one GPU.

    Responsibilities: initialize torch distributed / tensor parallelism,
    load the model weights through vLLM's loader, allocate the request and
    KV-cache memory pools, and dispatch forward passes per ForwardMode.
    """

    def __init__(
        self,
        model_config,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args

        # Publish a subset of server args globally; model code (e.g.
        # InputMetadata.create) reads these without a direct reference
        # to the server args object.
        global global_server_args_dict
        global_server_args_dict = {
            "enable_flashinfer": server_args.enable_flashinfer,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
        }

        # Init torch distributed
        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
        torch.cuda.set_device(self.gpu_id)
        logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
        monkey_patch_vllm_p2p_access_check()
        init_distributed_environment(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        total_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )

        if self.tp_size > 1:
            # All TP ranks should have comparable free memory; a rank with
            # much less would OOM while the others over-allocate.
            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if total_local_gpu_memory < total_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        self.load_model()
        self.init_memory_pool(total_gpu_memory)
        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
        """Load model weights onto this GPU via vLLM's model loader."""
        logger.info(
            f"[gpu_id={self.gpu_id}] Load weight begin. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        device_config = DeviceConfig()
        load_config = LoadConfig(load_format=self.server_args.load_format)
        vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=torch.float16,
            seed=42,
            skip_tokenizer_init=True,
        )
        # Apply user-supplied hf_config overrides before instantiating.
        if self.model_config.model_overide_args is not None:
            vllm_model_config.hf_config.update(self.model_config.model_overide_args)

        self.model = get_model(
            model_config=vllm_model_config,
            device_config=device_config,
            load_config=load_config,
            lora_config=None,
            vision_language_config=None,
            parallel_config=None,
            scheduler_config=None,
            cache_config=None,
        )
        logger.info(
            f"[gpu_id={self.gpu_id}] Load weight end. "
            f"type={type(self.model).__name__}, "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def profile_max_num_token(self, total_gpu_memory):
        """Estimate how many KV-cache tokens fit in the memory budget.

        cell_size is the per-token KV footprint in bytes:
        heads * head_dim * layers * 2 (K and V) * 2 (fp16 bytes).
        """
        available_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        head_dim = self.model_config.head_dim
        head_num = self.model_config.num_key_value_heads // self.tp_size
        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(self, total_gpu_memory):
        """Allocate the request-to-token and token-to-KV memory pools.

        Raises RuntimeError if the static memory fraction leaves no room
        for the KV cache.
        """
        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                # Fixed typo: "enought" -> "enough".
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

        self.req_to_token_pool = ReqToTokenPool(
            int(self.max_total_num_tokens / self.model_config.context_len * 256),
            self.model_config.context_len + 8,
        )
        self.token_to_kv_pool = TokenToKVPool(
            self.max_total_num_tokens,
            dtype=torch.float16,
            head_num=self.model_config.num_key_value_heads // self.tp_size,
            head_dim=self.model_config.head_dim,
            layer_num=self.model_config.num_hidden_layers,
        )
        logger.info(
            f"[gpu_id={self.gpu_id}] Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    @torch.inference_mode()
    def forward_prefill(self, batch: Batch):
        """Run a prefill forward pass (no cached prefix)."""
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.PREFILL,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend(self, batch: Batch):
        """Run an extend forward pass (new tokens appended after a cached prefix)."""
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_decode(self, batch: Batch):
        """Run a decode forward pass (one new token per request)."""
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.DECODE,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            out_cache_cont_start=batch.out_cache_cont_start,
            out_cache_cont_end=batch.out_cache_cont_end,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend_multi_modal(self, batch: Batch):
        """Extend forward pass for multimodal models, passing image inputs."""
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )
        return self.model.forward(
            batch.input_ids,
            input_metadata.positions,
            input_metadata,
            batch.pixel_values,
            batch.image_sizes,
            batch.image_offsets,
        )

    def forward(self, batch: Batch, forward_mode: ForwardMode):
        """Dispatch a batch to the forward path matching *forward_mode*."""
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:
            return self.forward_decode(batch)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(batch)
        elif forward_mode == ForwardMode.PREFILL:
            return self.forward_prefill(batch)
        else:
            # Fixed typo: "Invaid" -> "Invalid".
            raise ValueError(f"Invalid forward mode: {forward_mode}")
|
2024-05-21 09:13:37 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache()
def import_model_classes():
    """Discover sglang-native model implementations.

    Walks every non-package module under ``sglang.srt.models`` and collects
    the class(es) exported through its ``EntryClass`` attribute. Returns a
    dict mapping architecture class names to the classes themselves. The
    result is cached after the first call.
    """
    registry = {}
    pkg_name = "sglang.srt.models"
    pkg = importlib.import_module(pkg_name)
    for _, mod_name, is_pkg in pkgutil.iter_modules(pkg.__path__, pkg_name + "."):
        if is_pkg:
            continue
        mod = importlib.import_module(mod_name)
        if not hasattr(mod, "EntryClass"):
            continue
        entry = mod.EntryClass
        # A module may export a single class or a list of classes.
        candidates = entry if isinstance(entry, list) else [entry]
        for model_cls in candidates:
            registry[model_cls.__name__] = model_cls
    return registry
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    """Resolve *model_arch* to its sglang-native model class.

    Raises:
        ValueError: if no registered model matches the architecture name.
    """
    registry = import_model_classes()
    if model_arch in registry:
        return registry[model_arch]
    raise ValueError(
        f"Unsupported architectures: {model_arch}. "
        f"Supported list: {list(registry.keys())}"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Monkey patch model loader:
# make vLLM's ModelRegistry resolve architectures through sglang's own
# model implementations. Direct attribute assignment replaces
# setattr(ModelRegistry, "load_model_cls", ...) — identical behavior,
# but idiomatic for a constant attribute name (ruff B010).
ModelRegistry.load_model_cls = load_model_cls_srt
|