# Source file: sglang/python/sglang/srt/managers/router/model_runner.py
#
# (Web file-viewer header: 477 lines, 17 KiB, Python; "Raw / Normal View /
# History" navigation links.)

import logging
from dataclasses import dataclass
from enum import Enum, auto
from typing import List
import numpy as np
import torch
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model
from sglang.utils import get_available_gpu_memory
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
# (file-viewer blame timestamp: 2024-01-20 23:20:35 -08:00)
# Module-wide logger. The name is the fixed string "model_runner" rather than
# being derived from __name__.
logger = logging.getLogger("model_runner")
# for model_mode
# Module-level copy of the active model-mode flags. ModelRunner.__init__
# rebinds this name to the `model_mode` value it is constructed with.
global_model_mode: List[str] = []
@dataclass
class InputMetadata:
    """Batch layout and KV-cache handles for a single model forward pass.

    Instances are normally built with :meth:`create`, which derives the
    offsets, totals and position ids from the per-request tensors and then
    runs the mode-specific initialisers (:meth:`init_extend_args`,
    :meth:`init_flashinfer_args`).
    """

    model_runner: "ModelRunner"
    forward_mode: ForwardMode
    batch_size: int
    # Sum and maximum of `seq_lens` (computed in `create`).
    total_num_tokens: int
    max_seq_len: int
    req_pool_indices: torch.Tensor
    # Exclusive prefix sum of `seq_lens` (computed in `create`).
    start_loc: torch.Tensor
    seq_lens: torch.Tensor
    prefix_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # for extend
    extend_seq_lens: torch.Tensor = None
    extend_start_loc: torch.Tensor = None
    max_extend_len: int = 0

    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    # NOTE: `create` stores a Python scalar here (result of `.item()`) in
    # DECODE mode and None otherwise, despite the Tensor annotation.
    other_kv_index: torch.Tensor = None
    return_normalized_logprob: bool = False

    # for flashinfer
    use_flashinfer: bool = False
    qo_indptr: torch.Tensor = None
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None
    kv_last_page_len: torch.Tensor = None
    # Deliberately un-annotated: these are plain class attributes, not
    # dataclass fields, so they are not part of the generated __init__.
    prefill_wrapper = None
    decode_wrapper = None

    def init_flashinfer_args(self, tp_size):
        """Build the flashinfer index tensors and start the matching wrapper.

        Populates ``kv_indptr``, ``kv_indices`` and ``kv_last_page_len`` for
        every mode. For PREFILL/EXTEND it additionally fills ``qo_indptr`` and
        creates ``prefill_wrapper``; for any other mode it creates
        ``decode_wrapper``.

        Args:
            tp_size: Tensor-parallel world size; the attention/KV head counts
                from the model config are integer-divided by it.
        """
        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)

        # Convert the per-request rows/lengths to Python ints once instead of
        # calling `.item()` twice for every request in the batch.
        req_to_token = self.req_to_token_pool.req_to_token
        pool_rows = self.req_pool_indices[: self.batch_size].tolist()
        kv_lens = self.seq_lens[: self.batch_size].tolist()
        self.kv_indices = torch.cat(
            [req_to_token[row, :kv_len] for row, kv_len in zip(pool_rows, kv_lens)],
            dim=0,
        ).contiguous()

        # All ones; presumably because the paged KV cache is used with a page
        # size of 1 (see the literal `1` passed to the decode wrapper below)
        # -- TODO confirm against the flashinfer API.
        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )

        # Imported lazily so flashinfer is only required when it is enabled.
        from flashinfer.ops import (
            BatchDecodeWithPagedKVCacheWrapper,
            BatchPrefillWithPagedKVCacheWrapper,
        )

        model_config = self.model_runner.model_config
        if self.forward_mode in (ForwardMode.PREFILL, ForwardMode.EXTEND):
            # `create` only runs `init_extend_args()` for EXTEND, so in
            # PREFILL mode `extend_seq_lens` is still None and passing it to
            # torch.cumsum would raise. Fall back to the same quantity
            # `init_extend_args` would have computed.
            query_lens = self.extend_seq_lens
            if query_lens is None:
                query_lens = self.seq_lens - self.prefix_lens

            self.qo_indptr = torch.zeros(
                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
            )
            self.qo_indptr[1:] = torch.cumsum(query_lens, dim=0)
            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper()
            self.prefill_wrapper.begin_forward(
                self.qo_indptr,
                self.batch_size,
                model_config.num_attention_heads // tp_size,
                model_config.num_key_value_heads // tp_size,
            )
        else:
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper()
            # NOTE(review): the trailing literals are presumably page size,
            # positional-encoding mode and data type -- confirm against the
            # flashinfer version in use.
            self.decode_wrapper.begin_forward(
                self.kv_indptr,
                self.kv_last_page_len,
                self.batch_size,
                model_config.num_attention_heads // tp_size,
                model_config.num_key_value_heads // tp_size,
                model_config.head_dim,
                1,
                "NONE",
                "float16",
            )

    def init_extend_args(self):
        """Derive the EXTEND-mode fields from ``seq_lens`` and ``prefix_lens``.

        Sets ``extend_seq_lens`` (``seq_lens - prefix_lens``),
        ``extend_start_loc`` (its exclusive prefix sum) and
        ``max_extend_len`` (its maximum, as a Python int).
        """
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))

    @classmethod
    def create(
        cls,
        model_runner,
        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        return_normalized_logprob=False,
    ):
        """Build a fully initialised ``InputMetadata`` for one forward pass.

        Args:
            model_runner: Owner object; supplies ``req_to_token_pool``,
                ``token_to_kv_pool``, ``model_mode`` and ``model_config``.
            tp_size: Tensor-parallel world size, forwarded to
                :meth:`init_flashinfer_args`.
            forward_mode: A ``ForwardMode`` value. DECODE gets one position
                per request; every other mode gets one position per token in
                ``[prefix_len, seq_len)``.
            req_pool_indices: Per-request row indices into
                ``req_to_token_pool.req_to_token``; its length defines the
                batch size.
            seq_lens: Per-request sequence lengths.
            prefix_lens: Per-request prefix lengths. Read in every non-DECODE
                mode, so it must be a tensor there.
            position_ids_offsets: Per-request offsets added to the positions.
            out_cache_loc: Stored unchanged on the result.
            out_cache_cont_start: Stored unchanged on the result.
            out_cache_cont_end: Stored unchanged on the result.
            return_normalized_logprob: Stored unchanged on the result.

        Returns:
            The new ``InputMetadata``. EXTEND-specific fields are filled when
            ``forward_mode`` is EXTEND, and flashinfer fields when
            ``"flashinfer"`` is in ``model_runner.model_mode``.
        """
        batch_size = len(req_pool_indices)
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
        total_num_tokens = int(torch.sum(seq_lens))
        max_seq_len = int(torch.max(seq_lens))

        if forward_mode == ForwardMode.DECODE:
            # One position per request: the last token of each sequence.
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            # req_to_token entry for the last token of the first request.
            other_kv_index = model_runner.req_to_token_pool.req_to_token[
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
            # Positions for the not-yet-cached tokens of every request,
            # concatenated in batch order. Built with NumPy on the host and
            # then moved to the GPU in one go.
            seq_lens_np = seq_lens.cpu().numpy()
            prefix_lens_np = prefix_lens.cpu().numpy()
            position_ids_offsets_np = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_np[i] + position_ids_offsets_np[i],
                            seq_lens_np[i] + position_ids_offsets_np[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            other_kv_index = None

        ret = cls(
            model_runner=model_runner,
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            max_seq_len=max_seq_len,
            req_pool_indices=req_pool_indices,
            start_loc=start_loc,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            return_normalized_logprob=return_normalized_logprob,
            other_kv_index=other_kv_index,
        )

        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        ret.use_flashinfer = "flashinfer" in model_runner.model_mode
        if ret.use_flashinfer:
            ret.init_flashinfer_args(tp_size)

        return ret
class ModelRunner:
    """Runs a model on one tensor-parallel rank.

    Construction initialises ``torch.distributed`` (NCCL), loads the model
    weights and allocates the request/token KV-cache pools. The ``forward*``
    methods then run the model on a batch in PREFILL, EXTEND or DECODE mode.
    """

    def __init__(
        self,
        model_config,
        mem_fraction_static,
        tp_rank,
        tp_size,
        nccl_port,
        load_format="auto",
        trust_remote_code=True,
        model_mode: List[str] = (),
    ):
        """Set up distributed state, load the weights and size the KV cache.

        Args:
            model_config: Model configuration object (``hf_config``, ``path``,
                ``context_len`` and the head/layer counts are read from it).
            mem_fraction_static: Fraction used by `profile_max_num_token`
                when sizing the KV cache.
            tp_rank: This process's tensor-parallel rank; also used as the
                CUDA device index.
            tp_size: Tensor-parallel world size.
            nccl_port: Port on 127.0.0.1 used for the process-group rendezvous.
            load_format: Forwarded to the model's ``load_weights``.
            trust_remote_code: Stored on the instance.
            model_mode: Mode flags; stored on the instance and mirrored into
                the module-level ``global_model_mode``.
        """
        global global_model_mode

        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.load_format = load_format
        self.trust_remote_code = trust_remote_code
        self.model_mode = model_mode
        global_model_mode = model_mode

        # Init torch distributed
        rendezvous_url = f"tcp://127.0.0.1:{self.nccl_port}"
        torch.cuda.set_device(self.tp_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            init_method=rendezvous_url,
        )

        # A small all_reduce for warmup.
        is_distributed = self.tp_size > 1
        if is_distributed:
            torch.distributed.all_reduce(torch.zeros(1).cuda())
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)

        # Memory reading taken *before* the weights are loaded; the 1 << 30
        # factor presumably converts GiB to bytes -- TODO confirm the unit
        # returned by get_available_gpu_memory.
        gpu_memory_before_load = get_available_gpu_memory(
            self.tp_rank, distributed=is_distributed
        ) * (1 << 30)

        self.load_model()
        self.init_memory_pool(gpu_memory_before_load)
        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
        """Instantiate the model class for the config and load its weights.

        See also vllm/model_executor/model_loader.py::get_model

        Raises:
            ValueError: If none of the config's architectures is supported.
        """
        from sglang.srt.models.llama2 import LlamaForCausalLM
        from sglang.srt.models.llava import LlavaLlamaForCausalLM
        from sglang.srt.models.mixtral import MixtralForCausalLM
        from sglang.srt.models.qwen import QWenLMHeadModel

        # Architecture name -> implementing class. Mistral reuses the Llama
        # implementation.
        supported_classes = {
            "LlamaForCausalLM": LlamaForCausalLM,
            "MistralForCausalLM": LlamaForCausalLM,
            "LlavaLlamaForCausalLM": LlavaLlamaForCausalLM,
            "MixtralForCausalLM": MixtralForCausalLM,
            "QWenLMHeadModel": QWenLMHeadModel,
        }

        # Select model class: the first listed architecture we support wins.
        hf_config = self.model_config.hf_config
        architectures = getattr(hf_config, "architectures", [])
        model_class = next(
            (supported_classes[name] for name in architectures
             if name in supported_classes),
            None,
        )
        if model_class is None:
            raise ValueError(f"Unsupported architectures: {architectures}")

        logger.info(f"Rank {self.tp_rank}: load weight begin.")

        # Load weights
        linear_method = None
        with _set_default_torch_dtype(torch.float16):
            # NOTE(review): the original indentation of this section was not
            # recoverable; the nesting below follows the vllm get_model
            # pattern referenced in the docstring (model constructed under the
            # CUDA default device, weights loaded outside it) -- confirm.
            with torch.device("cuda"):
                hf_quant_config = getattr(
                    self.model_config.hf_config, "quantization_config", None
                )
                if hf_quant_config is not None:
                    # TODO: config quantization awq etc
                    quant_config = AWQConfig.from_config(hf_quant_config)
                    logger.info(f"quant_config: {quant_config}")
                    linear_method = quant_config.get_linear_method()
                model = model_class(
                    config=self.model_config.hf_config, linear_method=linear_method
                )
            model.load_weights(
                self.model_config.path,
                cache_dir=None,
                load_format=self.load_format,
                revision=None,
            )

        self.model = model.eval()
        logger.info(f"Rank {self.tp_rank}: load weight end.")

    def profile_max_num_token(self, total_gpu_memory):
        """Return how many tokens of KV cache fit in the remaining GPU memory.

        Args:
            total_gpu_memory: The memory reading taken in ``__init__`` before
                the model was loaded.

        Returns:
            ``int`` token count; may be zero or negative when the memory left
            after loading is smaller than the reserved non-static share.
        """
        config = self.model_config
        memory_now = get_available_gpu_memory(
            self.tp_rank, distributed=self.tp_size > 1
        ) * (1 << 30)

        per_head_dim = config.hidden_size // config.num_attention_heads
        local_kv_heads = config.num_key_value_heads // self.tp_size
        # The two trailing factors of 2 are presumably "K and V" and
        # "2 bytes per float16 element" -- TODO confirm.
        bytes_per_token = (
            local_kv_heads * per_head_dim * config.num_hidden_layers * 2 * 2
        )

        # Keep (1 - mem_fraction_static) of the pre-load reading out of the
        # KV-cache budget.
        budget = memory_now - total_gpu_memory * (1 - self.mem_fraction_static)
        return int(budget // bytes_per_token)

    def init_memory_pool(self, total_gpu_memory):
        """Allocate ``req_to_token_pool`` and ``token_to_kv_pool``.

        Args:
            total_gpu_memory: Forwarded to `profile_max_num_token`.

        Raises:
            RuntimeError: If the profiled token capacity is not positive.
        """
        token_capacity = self.profile_max_num_token(total_gpu_memory)
        self.max_total_num_token = token_capacity
        if token_capacity <= 0:
            raise RuntimeError(
                "Not enought memory. Please try to increase --mem-fraction-static."
            )

        context_len = self.model_config.context_len
        self.req_to_token_pool = ReqToTokenPool(
            int(token_capacity / context_len * 256),
            context_len + 8,
        )

        config = self.model_config
        self.token_to_kv_pool = TokenToKVPool(
            token_capacity,
            dtype=torch.float16,
            head_num=config.num_key_value_heads // self.tp_size,
            head_dim=config.hidden_size // config.num_attention_heads,
            layer_num=config.num_hidden_layers,
        )

    def _build_input_metadata(
        self,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        return_normalized_logprob=False,
    ):
        """Create the InputMetadata for this runner; shared by all forward_* methods."""
        return InputMetadata.create(
            self,
            tp_size=self.tp_size,
            forward_mode=forward_mode,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            position_ids_offsets=position_ids_offsets,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            return_normalized_logprob=return_normalized_logprob,
        )

    @torch.inference_mode()
    def forward_prefill(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_normalized_logprob,
    ):
        """Run the model in PREFILL mode and return its output unchanged."""
        metadata = self._build_input_metadata(
            ForwardMode.PREFILL,
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
            return_normalized_logprob=return_normalized_logprob,
        )
        return self.model.forward(input_ids, metadata.positions, metadata)

    @torch.inference_mode()
    def forward_extend(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_normalized_logprob,
    ):
        """Run the model in EXTEND mode and return its output unchanged."""
        metadata = self._build_input_metadata(
            ForwardMode.EXTEND,
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
            return_normalized_logprob=return_normalized_logprob,
        )
        return self.model.forward(input_ids, metadata.positions, metadata)

    @torch.inference_mode()
    def forward_decode(
        self,
        input_ids,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start,
        out_cache_cont_end,
    ):
        """Run the model in DECODE mode and return element 0 of its output."""
        metadata = self._build_input_metadata(
            ForwardMode.DECODE,
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
        )
        outputs = self.model.forward(input_ids, metadata.positions, metadata)
        return outputs[0]

    @torch.inference_mode()
    def forward_extend_multi_modal(
        self,
        input_ids,
        pixel_values,
        image_offsets,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        return_normalized_logprob,
    ):
        """Run the model in EXTEND mode, also passing the image inputs."""
        metadata = self._build_input_metadata(
            ForwardMode.EXTEND,
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
            return_normalized_logprob=return_normalized_logprob,
        )
        return self.model.forward(
            input_ids,
            metadata.positions,
            metadata,
            pixel_values,
            image_offsets,
        )

    def forward(
        self, batch: Batch, forward_mode: ForwardMode, return_normalized_logprob=False
    ):
        """Dispatch ``batch`` to the forward_* method matching ``forward_mode``.

        Multimodal models use `forward_extend_multi_modal` for EXTEND; every
        other combination goes to the plain decode/extend/prefill method.
        ``return_normalized_logprob`` is not forwarded in DECODE mode.

        Raises:
            ValueError: If ``forward_mode`` is not DECODE, EXTEND or PREFILL.
        """
        multimodal_extend = (
            self.is_multimodal_model and forward_mode == ForwardMode.EXTEND
        )

        # Arguments every forward_* variant takes.
        shared_args = {
            "input_ids": batch.input_ids,
            "req_pool_indices": batch.req_pool_indices,
            "seq_lens": batch.seq_lens,
            "prefix_lens": batch.prefix_lens,
            "position_ids_offsets": batch.position_ids_offsets,
            "out_cache_loc": batch.out_cache_loc,
        }

        if multimodal_extend:
            return self.forward_extend_multi_modal(
                pixel_values=batch.pixel_values,
                image_offsets=batch.image_offsets,
                return_normalized_logprob=return_normalized_logprob,
                **shared_args,
            )

        if forward_mode == ForwardMode.DECODE:
            return self.forward_decode(
                out_cache_cont_start=batch.out_cache_cont_start,
                out_cache_cont_end=batch.out_cache_cont_end,
                **shared_args,
            )

        if forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(
                return_normalized_logprob=return_normalized_logprob,
                **shared_args,
            )

        if forward_mode == ForwardMode.PREFILL:
            return self.forward_prefill(
                return_normalized_logprob=return_normalized_logprob,
                **shared_args,
            )

        raise ValueError(f"Invaid forward mode: {forward_mode}")