512 lines
19 KiB
Python
512 lines
19 KiB
Python
import asyncio
|
|
import logging
|
|
import multiprocessing
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from enum import Enum, auto
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
import warnings
|
|
|
|
import numpy as np
|
|
import rpyc
|
|
import torch
|
|
from rpyc.utils.classic import obtain
|
|
from rpyc.utils.server import ThreadedServer
|
|
from sglang.srt.constrained.fsm_cache import FSMCache
|
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
|
from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
|
|
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
|
|
from sglang.srt.managers.router.model_runner import ModelRunner
|
|
from sglang.srt.managers.router.radix_cache import RadixCache
|
|
from sglang.srt.managers.router.scheduler import Scheduler
|
|
from sglang.srt.model_config import ModelConfig
|
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
|
from sglang.srt.utils import (
|
|
get_exception_traceback,
|
|
get_int_token_logit_bias,
|
|
is_multimodal_model,
|
|
set_random_seed,
|
|
)
|
|
|
|
# Module-level logger (named "model_rpc") used for the status and error
# messages emitted by the code below.
logger = logging.getLogger("model_rpc")
|
|
|
|
|
|
class ModelRpcServer(rpyc.Service):
|
|
def exposed_init_model(
    self,
    tp_rank: int,
    server_args: ServerArgs,
    port_args: PortArgs,
):
    """Build the model runner, tokenizer, caches and scheduling state.

    Args:
        tp_rank: Tensor-parallel rank handled by this server; forwarded to
            ``ModelRunner`` and used to gate logging elsewhere in the class.
        server_args: Server configuration (model path, tokenizer settings,
            memory fraction, scheduling heuristic, random seed, ...).
        port_args: Port configuration; only ``nccl_port`` is read here.
    """
    # The arguments may be rpyc proxies; obtain() brings them over by value.
    server_args = obtain(server_args)
    port_args = obtain(port_args)

    # Keep the settings that later methods consult.
    self.model_mode = server_args.model_mode
    self.tp_rank = tp_rank
    self.tp_size = server_args.tp_size
    self.schedule_heuristic = server_args.schedule_heuristic

    # Model configuration and runner.
    self.model_config = ModelConfig(
        server_args.model_path, server_args.trust_remote_code
    )
    self.model_runner = ModelRunner(
        self.model_config,
        server_args.mem_fraction_static,
        tp_rank,
        server_args.tp_size,
        port_args.nccl_port,
        server_args.load_format,
        server_args.trust_remote_code,
        server_args.model_mode,
    )

    # Tokenizer: multimodal models go through a processor that carries one.
    multimodal = is_multimodal_model(server_args.model_path)
    tokenizer_kwargs = {
        "tokenizer_mode": server_args.tokenizer_mode,
        "trust_remote_code": server_args.trust_remote_code,
    }
    if multimodal:
        self.processor = get_processor(
            server_args.tokenizer_path, **tokenizer_kwargs
        )
        self.tokenizer = self.processor.tokenizer
    else:
        self.tokenizer = get_tokenizer(
            server_args.tokenizer_path, **tokenizer_kwargs
        )
    self.eos_token_id = self.tokenizer.eos_token_id

    # Capacity limits, all derived from the runner's total token budget.
    total_tokens = self.model_runner.max_total_num_token
    self.max_total_num_token = total_tokens
    self.max_num_running_seq = total_tokens // 2
    self.max_prefill_num_token = max(
        self.model_config.context_len, total_tokens // 6
    )

    logit_bias = get_int_token_logit_bias(
        self.tokenizer, self.model_config.vocab_size
    )
    self.int_token_logit_bias = torch.tensor(logit_bias)
    set_random_seed(server_args.random_seed)
    logger.info(
        f"Rank {self.tp_rank}: "
        f"max_total_num_token={self.max_total_num_token}, "
        f"max_prefill_num_token={self.max_prefill_num_token}, "
        f"context_len={self.model_config.context_len}, "
        f"model_mode={self.model_mode}"
    )

    # Prefix cache and scheduler; the cache is disabled when "no-cache"
    # appears in the model mode.
    cache_disabled = "no-cache" in self.model_mode
    self.tree_cache = RadixCache(disable=cache_disabled)
    self.scheduler = Scheduler(
        self.schedule_heuristic,
        self.max_num_running_seq,
        self.max_prefill_num_token,
        self.max_total_num_token,
        self.tree_cache,
    )
    self.req_to_token_pool = self.model_runner.req_to_token_pool
    self.token_to_kv_pool = self.model_runner.token_to_kv_pool

    # Running status.
    self.forward_queue: List[Req] = []  # requests waiting for prefill
    self.running_batch: Optional[Batch] = None  # batch currently decoding
    self.out_pyobjs = []  # outputs collected until the next exposed_step return
    self.decode_forward_ct = 0  # decode steps since the last stats log
    self.stream_interval = 2  # streaming output is emitted every N decode steps

    # FSM cache for regex-constrained generation.
    self.regex_fsm_cache = FSMCache(self.tokenizer)
|
|
|
|
def exposed_step(self, recv_reqs):
    """Enqueue newly received requests, run one forward step, return outputs.

    Args:
        recv_reqs: Iterable of incoming requests. Only
            ``TokenizedGenerateReqInput`` items are accepted; anything else
            raises ``ValueError`` internally, which is logged (see below).

    Returns:
        The list of output objects accumulated in ``self.out_pyobjs`` since
        the previous call. The internal list is reset on every call.

    Any exception raised while enqueuing or running the forward step is
    logged and swallowed so that the caller still receives whatever output
    was produced before the failure.
    """
    if self.tp_size != 1:
        # With several ranks the requests arrive through rpyc; fetch them
        # by value before touching them.
        recv_reqs = obtain(recv_reqs)

    try:
        # Recv requests
        for recv_req in recv_reqs:
            if isinstance(recv_req, TokenizedGenerateReqInput):
                self.handle_generate_request(recv_req)
            else:
                raise ValueError(f"Invalid request: {recv_req}")

        # Forward
        self.forward_step()
    except Exception:
        # This code runs inside ModelRpcServer, so label the log accordingly
        # (the previous message blamed ModelRpcClient).
        logger.error("Exception in ModelRpcServer:\n" + get_exception_traceback())

    # Return results and start a fresh accumulator for the next call.
    ret = self.out_pyobjs
    self.out_pyobjs = []
    return ret
|
|
|
|
@torch.inference_mode()
def forward_step(self):
    """Run one scheduling step: a prefill batch if possible, else decoding.

    * If ``get_new_fill_batch`` yields a batch, it is prefilled and whatever
      remains unfinished is merged into ``self.running_batch``.
    * Otherwise, if a running batch exists, up to 10 decode iterations are
      executed back to back, stopping early once the batch is empty.
    * Otherwise the server is idle and the KV pool accounting is checked;
      a warning is issued if free + evictable tokens differ from the total.

    On rank 0, usage statistics are logged once at least 20 decode steps
    have been counted since the previous log.
    """
    fill_batch = self.get_new_fill_batch()

    if fill_batch is not None:
        # Run new fill batch
        self.forward_fill_batch(fill_batch)

        if not fill_batch.is_empty():
            if self.running_batch is not None:
                self.running_batch.merge(fill_batch)
            else:
                self.running_batch = fill_batch
    elif self.running_batch is not None:
        # Run a few decode batches continuously for reducing overhead
        for _ in range(10):
            self.forward_decode_batch(self.running_batch)

            if self.running_batch.is_empty():
                self.running_batch = None
                break
    else:
        # Idle: everything should be either free or evictable.
        available_size = (
            self.token_to_kv_pool.available_size()
            + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_token:
            warnings.warn(
                "Warning: "
                f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
                "KV cache pool leak detected!"
            )

    # Periodic usage log, rank 0 only.
    if self.running_batch is None or self.tp_rank != 0:
        return
    if self.decode_forward_ct < 20:
        return

    self.decode_forward_ct = 0
    num_used = self.max_total_num_token - (
        self.token_to_kv_pool.available_size()
        + self.tree_cache.evictable_size()
    )
    logger.info(
        f"#running-req: {len(self.running_batch.reqs)}, "
        f"#token: {num_used}, "
        f"token usage: {num_used / self.max_total_num_token:.2f}, "
        f"#queue-req: {len(self.forward_queue)}"
    )
|
|
|
|
def handle_generate_request(
    self,
    recv_req: TokenizedGenerateReqInput,
):
    """Convert a tokenized generate request into a ``Req`` and enqueue it.

    Copies the request fields onto a new ``Req``, pads the input ids for
    image inputs, attaches the regex FSM when a regex constraint is given,
    truncates the prompt to ``context_len - 1`` tokens, clamps
    ``max_new_tokens`` to the remaining context, and appends the request to
    ``self.forward_queue``.
    """
    req = Req(recv_req.rid)
    req.input_ids = recv_req.input_ids
    req.pixel_values = recv_req.pixel_values

    if req.pixel_values is not None:
        # Derive four padding token ids from the image hash.
        # NOTE(review): the shifts are 0/16/32/64 -- if image_hash is a
        # 64-bit value the last entry carries no hash bits; confirm that
        # 64 (rather than 48) is intended.
        vocab_size = self.model_config.vocab_size
        image_hash = recv_req.image_hash
        pad_value = [
            (image_hash >> shift) % vocab_size for shift in (0, 16, 32, 64)
        ]
        padded = self.model_runner.model.pad_input_ids(req.input_ids, pad_value)
        req.input_ids, req.image_offset = padded

    req.sampling_params = recv_req.sampling_params
    req.return_normalized_logprob = recv_req.return_normalized_logprob
    req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len
    req.stream = recv_req.stream
    req.tokenizer = self.tokenizer

    # Init the regex FSM for constrained generation.
    regex = req.sampling_params.regex
    if regex is not None:
        req.regex_fsm_state = 0
        req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)

    # Truncate long prompts and keep generation within the context window.
    prompt_limit = self.model_config.context_len - 1
    req.input_ids = req.input_ids[:prompt_limit]
    remaining = prompt_limit - len(req.input_ids)
    req.sampling_params.max_new_tokens = min(
        req.sampling_params.max_new_tokens, remaining
    )

    self.forward_queue.append(req)
|
|
|
|
def get_new_fill_batch(self):
    """Select queued requests that fit in memory and build a prefill batch.

    Returns:
        A new ``Batch`` with the admitted requests, or ``None`` when the
        running batch already exceeds ``max_num_running_seq`` or when no
        queued request can be admitted.

    Side effects: updates each queued request's prefix-match fields,
    reorders ``self.forward_queue`` through the scheduler, takes references
    on cached prefixes for admitted requests, and removes admitted requests
    from the queue.
    """
    running = self.running_batch
    if running is not None and len(running.reqs) > self.max_num_running_seq:
        return None

    # Match every queued prompt against the prefix cache.
    for req in self.forward_queue:
        matched, node = self.tree_cache.match_prefix(req.input_ids)
        if req.return_normalized_logprob:
            # Only reuse the cache up to where logprobs must be computed.
            matched = matched[: req.normalized_logprob_start_len]
        req.adjust_input_len = len(req.input_ids) - len(matched)
        req.prefix_indices = matched
        req.last_node = node

    # Get priority queue
    self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

    # Add requests if there is available space
    can_run_list = []
    new_batch_total_tokens = 0
    new_batch_input_tokens = 0

    available_size = (
        self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
    )
    new_ratio = self.scheduler.new_token_estimation_ratio()
    if self.running_batch:
        # Reserve room for the tokens the running requests may still emit,
        # scaled by the scheduler's estimation ratio.
        available_size -= sum(
            (r.max_new_tokens() - len(r.output_ids)) * new_ratio
            for r in self.running_batch.reqs
        )

    for req in self.forward_queue:
        if req.return_normalized_logprob and req.adjust_input_len < 2:
            # Need at least two tokens to compute normalized logprob
            delta = 2 - req.adjust_input_len
            req.adjust_input_len += delta
            req.prefix_indices = req.prefix_indices[:-delta]
            if req.image_offset is not None:
                req.image_offset += delta

        if req.adjust_input_len == 0 and req.max_new_tokens() > 0:
            # Need at least one token to compute logits
            req.adjust_input_len = 1
            req.prefix_indices = req.prefix_indices[:-1]
            if req.image_offset is not None:
                req.image_offset += 1

        needed = req.adjust_input_len + req.max_new_tokens()
        fits_memory = needed + new_batch_total_tokens < available_size
        if not fits_memory:
            continue
        fits_prefill = (
            req.adjust_input_len + new_batch_input_tokens
            < self.max_prefill_num_token
        )
        if not fits_prefill:
            continue

        # Taking a reference on the matched prefix changes the budget by
        # the returned delta, so the memory check is repeated afterwards.
        available_size += self.tree_cache.inc_ref_counter(req.last_node)

        if needed + new_batch_total_tokens < available_size:
            self.token_to_kv_pool.add_refs(req.prefix_indices)
            can_run_list.append(req)
            new_batch_total_tokens += needed
            new_batch_input_tokens += req.adjust_input_len
        else:
            # Does not fit after all: give the reference back.
            available_size += self.tree_cache.dec_ref_counter(req.last_node)

    if not can_run_list:
        return None

    if self.tp_rank == 0:
        logger.info(
            f"new fill batch. #seq: {len(can_run_list)}. "
            f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. "
            f"#new_token: {new_batch_input_tokens}. "
            f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
            f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
        )

    new_batch = Batch(
        can_run_list,
        self.req_to_token_pool,
        self.token_to_kv_pool,
        self.tree_cache,
    )
    self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
    return new_batch
|
|
|
|
def forward_fill_batch(self, batch: Batch):
    """Run the prefill (extend) pass for ``batch`` and record first tokens.

    When the batch has tokens to extend, the model is run in EXTEND mode and
    one token per request is sampled. When there is nothing to extend, every
    request receives the tokenizer's EOS token instead. Each request's
    ``output_ids`` is reset to that single token, finish conditions are
    checked, and finished requests are handed to
    ``handle_finished_requests``.
    """
    # Build batch tensors
    batch.init_extend_batch(self.model_config.vocab_size, self.int_token_logit_bias)

    if batch.extend_num_tokens != 0:
        # Forward
        logits, normalized_logprobs = self.model_runner.forward(
            batch, ForwardMode.EXTEND, batch.return_normalized_logprob
        )
        if normalized_logprobs is not None:
            normalized_logprobs = normalized_logprobs.cpu().tolist()

        sampled_ids, _ = batch.sample(logits)
        next_token_ids = sampled_ids.cpu().tolist()
    else:
        eos = self.tokenizer.eos_token_id
        next_token_ids = [eos] * len(batch.reqs)
        normalized_logprobs = None

    # Check finish condition
    has_logprobs = normalized_logprobs is not None
    for idx, req in enumerate(batch.reqs):
        req.output_ids = [next_token_ids[idx]]
        req.check_finished()

        if has_logprobs:
            req.normalized_logprob = normalized_logprobs[idx]

    self.handle_finished_requests(batch)
|
|
|
|
def forward_decode_batch(self, batch: Batch):
    """Run one decode step for ``batch`` and append the sampled tokens.

    Increments ``self.decode_forward_ct``, runs the model in DECODE mode,
    samples one token per request, appends it to the request's
    ``output_ids``, checks finish conditions, and passes the batch to
    ``handle_finished_requests``.
    """
    # Update batch tensors
    self.decode_forward_ct += 1
    batch.update_for_decode()

    # Forward
    logits = self.model_runner.forward(batch, ForwardMode.DECODE)
    sampled_ids, _ = batch.sample(logits)
    next_token_ids = sampled_ids.cpu().tolist()

    # Check finish condition
    for idx, req in enumerate(batch.reqs):
        req.output_ids.append(next_token_ids[idx])
        req.check_finished()

    self.handle_finished_requests(batch)
|
|
|
|
def handle_finished_requests(self, batch: Batch):
    """Emit outputs for finished/streaming requests and retire finished ones.

    A request contributes to the output when it is finished, or when it is
    streaming and the decode counter is a multiple of ``stream_interval``.
    The collected fields are wrapped in one ``BatchTokenIDOut`` appended to
    ``self.out_pyobjs``. Finished requests are then inserted into the radix
    cache, their pool slots are released, and the batch is reduced to the
    unfinished requests.
    """
    rids = []
    token_lists = []
    hit_stop_flags = []
    skip_special_flags = []
    meta_infos = []
    finished_flags = []

    finished_indices = []
    unfinished_indices = []

    for idx, req in enumerate(batch.reqs):
        (finished_indices if req.finished else unfinished_indices).append(idx)

        emit = req.finished or (
            req.stream and self.decode_forward_ct % self.stream_interval == 0
        )
        if not emit:
            continue

        meta_info = {
            "prompt_tokens": len(req.input_ids),
            "completion_tokens": len(req.output_ids),
        }
        if req.return_normalized_logprob:
            meta_info["normalized_logprob"] = req.normalized_logprob

        rids.append(req.rid)
        token_lists.append(req.output_ids)
        hit_stop_flags.append(req.hit_stop_str)
        skip_special_flags.append(req.sampling_params.skip_special_tokens)
        meta_infos.append(meta_info)
        finished_flags.append(req.finished)

    # Send to detokenizer
    if rids:
        out = BatchTokenIDOut(
            rids,
            token_lists,
            hit_stop_flags,
            skip_special_flags,
            meta_infos,
            finished_flags,
        )
        self.out_pyobjs.append(out)

    # Nothing finished: the batch stays as it is.
    if not finished_indices:
        return

    # Update radix cache
    pool_indices = batch.req_pool_indices.cpu().tolist()
    for idx in finished_indices:
        req = batch.reqs[idx]
        pool_idx = pool_indices[idx]

        # All tokens except the last generated one are inserted.
        all_ids = tuple(req.input_ids + req.output_ids)
        seq_len = len(all_ids) - 1
        kv_indices = self.req_to_token_pool.req_to_token[pool_idx, :seq_len]
        prefix_len = self.tree_cache.insert(all_ids[:seq_len], kv_indices.clone())

        # The first `prefix_len` slots are released -- presumably the part
        # the tree already held; confirm against RadixCache.insert.
        self.token_to_kv_pool.free(kv_indices[:prefix_len])
        self.req_to_token_pool.free(pool_idx)
        self.tree_cache.dec_ref_counter(req.last_node)

    # Update batch tensors
    if unfinished_indices:
        batch.filter_batch(unfinished_indices)
    else:
        batch.reqs = []
|
|
|
|
|
|
class ModelRpcClient:
    """Client-side entry point exposing an awaitable ``step`` call.

    With ``tp_size == 1`` the model server lives in this process and
    ``step`` calls it directly. Otherwise one server process per port in
    ``port_args.model_rpc_ports`` is launched and ``step`` fans the call out
    to all of them over rpyc, returning the first server's result.

    Attributes set by ``__init__``:
        step: ``async`` callable forwarding to the server(s)' ``step``.
        model_server: the in-process server (only when ``tp_size == 1``).
        model_servers / procs: remote service roots and their
            ``multiprocessing.Process`` handles (only when ``tp_size > 1``).
    """

    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
        """Start the model server(s), initialise the model, build ``step``."""
        tp_size = server_args.tp_size

        if tp_size == 1:
            # Init model
            self.model_server = ModelRpcServer()
            self.model_server.exposed_init_model(0, server_args, port_args)

            # Wrap functions: the call runs inline inside the coroutine.
            def async_wrap(f):
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
            with ThreadPoolExecutor(tp_size) as executor:
                # Launch model processes.
                # executor.map() returns a one-shot iterator, so the results
                # must be materialised before being read twice; otherwise
                # the second comprehension sees nothing and self.procs ends
                # up empty.
                launched = list(
                    executor.map(start_model_process, port_args.model_rpc_ports)
                )
                self.model_servers = [pair[0] for pair in launched]
                self.procs = [pair[1] for pair in launched]

                # Init model
                def init_model(i):
                    return self.model_servers[i].init_model(i, server_args, port_args)

                # Consume every result so that a failure on any rank
                # surfaces here.
                for result in executor.map(init_model, range(tp_size)):
                    obtain(result)

            # Wrap functions
            def async_wrap(func_name):
                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                async def _func(*args, **kwargs):
                    tasks = [f(*args, **kwargs) for f in fs]
                    # Wait for each async result in a worker thread so the
                    # event loop is not blocked.
                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                    return obtain(tasks[0].value)

                return _func

            self.step = async_wrap("step")
|
|
|
|
|
|
def start_model_process(port):
    """Launch a ``ModelRpcServer`` in a child process and connect to it.

    Args:
        port: TCP port on localhost that the rpyc ``ThreadedServer`` binds
            to and that this function connects to.

    Returns:
        A tuple ``(root, proc)``: the root of the rpyc connection to the
        child's service, and the ``multiprocessing.Process`` running it.

    Raises:
        RuntimeError: if the connection is still refused after 20 attempts
            spaced one second apart. The child process is terminated before
            raising so it is not left behind.
        AssertionError: if the child is not alive after connecting.
    """

    def _init_service(port):
        t = ThreadedServer(
            ModelRpcServer(),
            port=port,
            protocol_config={"allow_pickle": True, "sync_request_timeout": 600},
        )
        t.start()

    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    # Give the server a moment to bind before the first connection attempt.
    time.sleep(1)

    con = None
    for _ in range(20):
        try:
            con = rpyc.connect(
                "localhost",
                port,
                config={"allow_pickle": True, "sync_request_timeout": 600},
            )
            break
        except ConnectionRefusedError:
            time.sleep(1)

    if con is None:
        # Every attempt was refused: stop the child instead of leaking it.
        if proc.is_alive():
            proc.terminate()
        proc.join(timeout=5)
        raise RuntimeError("init rpc env error!")

    assert proc.is_alive()
    return con.root, proc
|