# sglang/python/sglang/srt/managers/router/model_rpc.py
# Commit ffe4aaee1d "Fix for T4 GPUs (#16)" by Ying Sheng
# Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com> (2024-01-16)
import asyncio
import logging
import multiprocessing
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer

from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    get_exception_traceback,
    get_int_token_logit_bias,
    is_multimodal_model,
    set_random_seed,
)

logger = logging.getLogger("model_rpc")
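

# ModelRpcServer runs the model for one tensor-parallel rank. With tp_size == 1
# it is used as a plain in-process object; with tp_size > 1 each rank is hosted
# in a separate process behind an rpyc ThreadedServer (see ModelRpcClient and
# start_model_process at the bottom of this file).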
class ModelRpcServer(rpyc.Service):
    def exposed_init_model(
        self,
        tp_rank: int,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        server_args, port_args = [obtain(x) for x in [server_args, port_args]]

        # Copy arguments
        self.model_mode = server_args.model_mode
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_heuristic = server_args.schedule_heuristic

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path, server_args.trust_remote_code
        )
        self.model_runner = ModelRunner(
            self.model_config,
            server_args.mem_fraction_static,
            tp_rank,
            server_args.tp_size,
            port_args.nccl_port,
            server_args.load_format,
            server_args.trust_remote_code,
            server_args.model_mode,
        )
        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
        self.eos_token_id = self.tokenizer.eos_token_id
        self.max_total_num_token = self.model_runner.max_total_num_token
        self.max_num_running_seq = self.max_total_num_token // 2
        self.max_prefill_num_token = max(
            self.model_config.context_len, self.max_total_num_token // 6
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)
        logger.info(
            f"Rank {self.tp_rank}: "
            f"max_total_num_token={self.max_total_num_token}, "
            f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
            f"model_mode={self.model_mode}"
        )

        # Init cache
        self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
        self.scheduler = Scheduler(
            self.schedule_heuristic,
            self.max_num_running_seq,
            self.max_prefill_num_token,
            self.max_total_num_token,
            self.tree_cache,
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # Init running status
        self.forward_queue: List[Req] = []
        self.running_batch: Optional[Batch] = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = 2

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(self.tokenizer)
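
    # exposed_step is the RPC entry point called by the router on every tick:
    # ingest newly tokenized requests, run one scheduling + forward step, and
    # hand back any BatchTokenIDOut objects produced for the detokenizer.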
    def exposed_step(self, recv_reqs):
        if self.tp_size != 1:
            recv_reqs = obtain(recv_reqs)

        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()
        except Exception:
            logger.error("Exception in ModelRpcServer:\n" + get_exception_traceback())

        # Return results
        ret = self.out_pyobjs
        self.out_pyobjs = []
        return ret
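
    # Scheduling policy: prefill takes priority. If a new fill batch can be
    # formed it runs first and is merged into the running batch; otherwise up
    # to 10 consecutive decode steps run on the running batch to amortize
    # scheduling overhead. When fully idle, a KV-pool accounting check guards
    # against cache leaks.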
    @torch.inference_mode()
    def forward_step(self):
        new_batch = self.get_new_fill_batch()

        if new_batch is not None:
            # Run new fill batch
            self.forward_fill_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(10):
                    self.forward_decode_batch(self.running_batch)

                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break
            else:
                # check the available size
                available_size = (
                    self.token_to_kv_pool.available_size()
                    + self.tree_cache.evictable_size()
                )
                if available_size != self.max_total_num_token:
                    warnings.warn(
                        "Warning: "
                        f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
                        "KV cache pool leak detected!"
                    )

            if self.running_batch is not None and self.tp_rank == 0:
                if self.decode_forward_ct >= 20:
                    self.decode_forward_ct = 0
                    num_used = self.max_total_num_token - (
                        self.token_to_kv_pool.available_size()
                        + self.tree_cache.evictable_size()
                    )
                    logger.info(
                        f"#running-req: {len(self.running_batch.reqs)}, "
                        f"#token: {num_used}, "
                        f"token usage: {num_used / self.max_total_num_token:.2f}, "
                        f"#queue-req: {len(self.forward_queue)}"
                    )
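
    # Convert a tokenized request into a Req. For image inputs, the pad tokens
    # are derived from the image hash, so identical images produce identical
    # token prefixes and can share radix-cache entries.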
    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        req = Req(recv_req.rid)
        req.input_ids = recv_req.input_ids
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
            pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
            req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                req.input_ids, pad_value
            )
        req.sampling_params = recv_req.sampling_params
        req.return_normalized_logprob = recv_req.return_normalized_logprob
        req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # Init the regex FSM
        if req.sampling_params.regex is not None:
            req.regex_fsm_state = 0
            req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)

        # Truncate long prompts
        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
        req.sampling_params.max_new_tokens = min(
            req.sampling_params.max_new_tokens,
            self.model_config.context_len - 1 - len(req.input_ids),
        )
        self.forward_queue.append(req)
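
    # Form a new prefill batch: match each queued request against the radix
    # cache, reorder the queue by the scheduling heuristic, then greedily admit
    # requests while (a) prompt + max_new_tokens fits in the estimated free KV
    # space (discounted by the running batch's projected decode growth) and
    # (b) the batch's new input tokens stay under max_prefill_num_token.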
    def get_new_fill_batch(self):
        if (
            self.running_batch is not None
            and len(self.running_batch.reqs) > self.max_num_running_seq
        ):
            return None

        for req in self.forward_queue:
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_normalized_logprob:
                prefix_indices = prefix_indices[: req.normalized_logprob_start_len]
            req.adjust_input_len = len(req.input_ids) - len(prefix_indices)
            req.prefix_indices = prefix_indices
            req.last_node = last_node

        # Get priority queue
        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0
        new_batch_prefix_tokens = 0

        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        new_ratio = self.scheduler.new_token_estimation_ratio()
        if self.running_batch:
            available_size -= sum(
                [
                    (r.max_new_tokens() - len(r.output_ids)) * new_ratio
                    for r in self.running_batch.reqs
                ]
            )

        for req in self.forward_queue:
            if req.return_normalized_logprob:
                # Need at least two tokens to compute normalized logprob
                if req.adjust_input_len < 2:
                    delta = 2 - req.adjust_input_len
                    req.adjust_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.adjust_input_len == 0 and req.max_new_tokens() > 0:
                # Need at least one token to compute logits
                req.adjust_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            if (
                req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
                and req.adjust_input_len + new_batch_input_tokens
                < self.max_prefill_num_token
            ):
                # Tentatively pin the cached prefix, then re-check the budget
                delta = self.tree_cache.inc_ref_counter(req.last_node)
                available_size += delta

                if not (
                    req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
                    < available_size
                ):
                    # Undo the pin: the request no longer fits
                    delta = self.tree_cache.dec_ref_counter(req.last_node)
                    available_size += delta
                else:
                    # Admit the request into the new fill batch
                    self.token_to_kv_pool.add_refs(req.prefix_indices)
                    can_run_list.append(req)
                    new_batch_total_tokens += (
                        req.adjust_input_len + req.max_new_tokens()
                    )
                    new_batch_input_tokens += req.adjust_input_len

        if len(can_run_list) == 0:
            return None

        if self.tp_rank == 0:
            logger.info(
                f"new fill batch. #seq: {len(can_run_list)}. "
                f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. "
                f"#new_token: {new_batch_input_tokens}. "
                f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
                f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )

        new_batch = Batch(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
        return new_batch
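
    # Prefill: one EXTEND forward over the whole new batch, sampling the first
    # output token of every request. If there is nothing to extend (all tokens
    # were cache hits), the batch is short-circuited with EOS outputs.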
    def forward_fill_batch(self, batch: Batch):
        # Build batch tensors
        batch.init_extend_batch(self.model_config.vocab_size, self.int_token_logit_bias)

        if batch.extend_num_tokens != 0:
            # Forward
            logits, normalized_logprobs = self.model_runner.forward(
                batch, ForwardMode.EXTEND, batch.return_normalized_logprob
            )
            if normalized_logprobs is not None:
                normalized_logprobs = normalized_logprobs.cpu().tolist()

            next_token_ids, next_token_probs = batch.sample(logits)
            next_token_ids = next_token_ids.cpu().tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
            normalized_logprobs = None

        # Check finish condition
        reqs = batch.reqs
        for i in range(len(reqs)):
            reqs[i].output_ids = [next_token_ids[i]]
            reqs[i].check_finished()

            if normalized_logprobs is not None:
                reqs[i].normalized_logprob = normalized_logprobs[i]

        self.handle_finished_requests(batch)
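
    # Decode: append exactly one token to every running request, then re-check
    # finish conditions.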
    def forward_decode_batch(self, batch: Batch):
        # Update batch tensors
        self.decode_forward_ct += 1
        batch.update_for_decode()

        # Forward
        logits = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids, next_token_probs = batch.sample(logits)
        next_token_ids = next_token_ids.cpu().tolist()

        # Check finish condition
        reqs = batch.reqs
        for i in range(len(reqs)):
            reqs[i].output_ids.append(next_token_ids[i])
            reqs[i].check_finished()

        self.handle_finished_requests(batch)
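
    # Emit outputs for finished requests (and periodic partial outputs for
    # streaming ones), then release finished requests: their KV indices are
    # inserted into the radix cache, the part that duplicates an existing
    # cached prefix is freed back to the token pool, and the request slot is
    # returned.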
    def handle_finished_requests(self, batch: Batch):
        output_rids = []
        output_tokens = []
        output_hit_stop_str = []
        output_skip_special_tokens = []
        output_meta_info = []
        output_finished = []
        finished_indices = []
        unfinished_indices = []
        for i, req in enumerate(batch.reqs):
            if req.finished:
                finished_indices.append(i)
            else:
                unfinished_indices.append(i)

            if req.finished or (
                req.stream and self.decode_forward_ct % self.stream_interval == 0
            ):
                output_rids.append(req.rid)
                output_tokens.append(req.output_ids)
                output_hit_stop_str.append(req.hit_stop_str)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )

                meta_info = {
                    "prompt_tokens": len(req.input_ids),
                    "completion_tokens": len(req.output_ids),
                }
                if req.return_normalized_logprob:
                    meta_info["normalized_logprob"] = req.normalized_logprob
                output_meta_info.append(meta_info)
                output_finished.append(req.finished)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    output_tokens,
                    output_hit_stop_str,
                    output_skip_special_tokens,
                    output_meta_info,
                    output_finished,
                )
            )

        # Remove finished reqs
        if finished_indices:
            # Update radix cache
            req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                req_pool_idx = req_pool_indices_cpu[i]
                token_ids = tuple(req.input_ids + req.output_ids)
                seq_len = len(token_ids) - 1
                indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
                prefix_len = self.tree_cache.insert(
                    token_ids[:seq_len], indices.clone()
                )

                self.token_to_kv_pool.free(indices[:prefix_len])
                self.req_to_token_pool.free(req_pool_idx)
                self.tree_cache.dec_ref_counter(req.last_node)

            # Update batch tensors
            if unfinished_indices:
                batch.filter_batch(unfinished_indices)
            else:
                batch.reqs = []
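

# ModelRpcClient is the router-side handle. With tp_size == 1 it wraps a local
# ModelRpcServer in trivial coroutines; with tp_size > 1 it launches one rpyc
# worker process per rank, broadcasts each step to all ranks, and returns
# rank 0's result.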
class ModelRpcClient:
    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
        tp_size = server_args.tp_size

        if tp_size == 1:
            # Init model
            self.model_server = ModelRpcServer()
            self.model_server.exposed_init_model(0, server_args, port_args)

            # Wrap functions
            def async_wrap(f):
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
            with ThreadPoolExecutor(tp_size) as executor:
                # Launch model processes
                rets = executor.map(start_model_process, port_args.model_rpc_ports)
                self.model_servers = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
                    return self.model_servers[i].init_model(i, server_args, port_args)

                rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]

            # Wrap functions
            def async_wrap(func_name):
                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                async def _func(*args, **kwargs):
                    tasks = [f(*args, **kwargs) for f in fs]
                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                    return obtain(tasks[0].value)

                return _func

            self.step = async_wrap("step")
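

# Spawn an rpyc ThreadedServer hosting a ModelRpcServer in a child process and
# connect to it, retrying for up to ~20 seconds while the server comes up.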
def start_model_process(port):
    def _init_service(port):
        t = ThreadedServer(
            ModelRpcServer(),
            port=port,
            protocol_config={"allow_pickle": True, "sync_request_timeout": 600},
        )
        t.start()

    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    time.sleep(1)

    repeat_count = 0
    while repeat_count < 20:
        try:
            con = rpyc.connect(
                "localhost",
                port,
                config={"allow_pickle": True, "sync_request_timeout": 600},
            )
            break
        except ConnectionRefusedError:
            time.sleep(1)
        repeat_count += 1
    if repeat_count == 20:
        raise RuntimeError("Failed to connect to the model RPC server.")

    assert proc.is_alive()
    return con.root, proc
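

# A minimal driver sketch (not part of the original file), showing how the
# pieces above are meant to compose. The ServerArgs/PortArgs constructor
# keywords below are assumptions inferred from the attributes referenced in
# this file, not a confirmed API, so the sketch is left commented out.
#
#     import asyncio
#
#     async def main():
#         server_args = ServerArgs(model_path="my-model")  # hypothetical kwargs
#         port_args = PortArgs(nccl_port=28888, model_rpc_ports=[29000])
#         client = ModelRpcClient(server_args, port_args)
#         while True:
#             # recv_reqs is a list of TokenizedGenerateReqInput from the router
#             for out in await client.step([]):
#                 ...  # forward each BatchTokenIDOut to the detokenizer
#
#     asyncio.run(main())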