Files
sglang/python/sglang/srt/managers/router/model_rpc.py

785 lines
30 KiB
Python

import asyncio
import logging
import multiprocessing
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple, Union
import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
try:
from vllm.logger import _default_handler as vllm_default_logger
except ImportError:
from vllm.logger import logger as vllm_default_logger
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
AbortReq,
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
get_int_token_logit_bias,
is_multimodal_model,
set_random_seed,
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger("model_rpc")

# Raise vLLM's default log sink (a handler or logger, depending on the vLLM
# version imported above) and the "vllm.utils" logger to WARNING so vLLM's
# informational output does not drown out the scheduler logs.
for _vllm_sink in (vllm_default_logger, logging.getLogger("vllm.utils")):
    _vllm_sink.setLevel(logging.WARNING)
del _vllm_sink
class ModelRpcServer:
    """Per-tensor-parallel-rank scheduler and model executor.

    Owns the model runner, the radix prefix cache, the request queue and the
    running decode batch. The client drives it through ``exposed_step``
    (in-process when ``tp_size == 1``, otherwise over rpyc).
    """

    def __init__(
        self,
        tp_rank: int,
        server_args: ServerArgs,
        port_args: PortArgs,
        model_overide_args: Optional[dict] = None,
    ):
        """Load the model and build the caches, scheduler and bookkeeping state.

        Args:
            tp_rank: tensor-parallel rank served by this instance.
            server_args: server configuration (may be an rpyc proxy).
            port_args: port configuration, e.g. the NCCL port (may be an rpyc proxy).
            model_overide_args: optional overrides forwarded to ``ModelConfig``.
        """
        # When constructed remotely the arguments are rpyc proxies; take local copies.
        server_args = obtain(server_args)
        port_args = obtain(port_args)

        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_heuristic = server_args.schedule_heuristic
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

        # ---- model ---------------------------------------------------------
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_overide_args=model_overide_args,
        )
        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=port_args.nccl_port,
            load_format=server_args.load_format,
            trust_remote_code=server_args.trust_remote_code,
            # Global settings consumed on the model side.
            server_args_dict={
                "enable_flashinfer": server_args.enable_flashinfer,
                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
            },
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # ---- tokenizer (via the processor for multimodal models) -----------
        tokenizer_kwargs = dict(
            tokenizer_mode=server_args.tokenizer_mode,
            trust_remote_code=server_args.trust_remote_code,
        )
        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(server_args.tokenizer_path, **tokenizer_kwargs)
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(server_args.tokenizer_path, **tokenizer_kwargs)

        # ---- token budgets -------------------------------------------------
        self.max_total_num_token = self.model_runner.max_total_num_token
        self.max_num_running_seq = self.max_total_num_token // 2
        prefill_budget = server_args.max_prefill_num_token
        if prefill_budget is None:
            prefill_budget = self.max_total_num_token // 6
        # Never smaller than one full context.
        self.max_prefill_num_token = max(self.model_config.context_len, prefill_budget)

        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)

        logger.info(
            f"Rank {self.tp_rank}: "
            f"max_total_num_token={self.max_total_num_token}, "
            f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
        )
        if self.tp_rank == 0:
            logger.info(f"server_args: {server_args.print_mode_args()}")

        # ---- prefix cache and scheduler ------------------------------------
        self.tree_cache = RadixCache(
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool=self.token_to_kv_pool,
            disable=server_args.disable_radix_cache,
        )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.scheduler = Scheduler(
            self.schedule_heuristic,
            self.max_num_running_seq,
            self.max_prefill_num_token,
            self.max_total_num_token,
            self.tree_cache,
        )

        # ---- runtime state -------------------------------------------------
        self.forward_queue: List[Req] = []
        self.running_batch: Optional[Batch] = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        # ---- constrained-generation caches ---------------------------------
        self.regex_fsm_cache = FSMCache(
            server_args.tokenizer_path,
            {
                "tokenizer_mode": server_args.tokenizer_mode,
                "trust_remote_code": server_args.trust_remote_code,
            },
        )
        self.jump_forward_cache = JumpForwardCache()

        # ---- new-token estimation ------------------------------------------
        # Fraction of each running request's remaining max_new_tokens reserved
        # when admitting new work; adapted in forward_decode_batch.
        conservativeness = server_args.schedule_conservativeness
        self.new_token_ratio = min(0.4 * conservativeness, 1.0)
        self.min_new_token_ratio = min(0.2 * conservativeness, 1.0)
        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
def exposed_step(self, recv_reqs):
if self.tp_size != 1:
recv_reqs = obtain(recv_reqs)
try:
# Recv requests
for recv_req in recv_reqs:
if isinstance(recv_req, TokenizedGenerateReqInput):
self.handle_generate_request(recv_req)
elif isinstance(recv_req, FlushCacheReq):
self.flush_cache()
elif isinstance(recv_req, AbortReq):
self.abort_request(recv_req)
else:
raise ValueError(f"Invalid request: {recv_req}")
# Forward
self.forward_step()
except Exception:
logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())
# Return results
ret = self.out_pyobjs
self.out_pyobjs = []
return ret
@torch.inference_mode()
def forward_step(self):
new_batch = self.get_new_fill_batch()
if new_batch is not None:
# Run a new fill batch
self.forward_fill_batch(new_batch)
self.cache_filled_batch(new_batch)
if not new_batch.is_empty():
if self.running_batch is None:
self.running_batch = new_batch
else:
self.running_batch.merge(new_batch)
else:
# Run decode batch
if self.running_batch is not None:
# Run a few decode batches continuously for reducing overhead
for _ in range(10):
self.num_generated_tokens += len(self.running_batch.reqs)
self.forward_decode_batch(self.running_batch)
# Print stats
if self.tp_rank == 0:
if self.decode_forward_ct % 40 == 0:
num_used = self.max_total_num_token - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
throuhgput = self.num_generated_tokens / (
time.time() - self.last_stats_tic
)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_token:.2f}, "
f"gen throughput (token/s): {throuhgput:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)
if self.running_batch.is_empty():
self.running_batch = None
break
if self.out_pyobjs and self.running_batch.reqs[0].stream:
break
else:
# Check the available size
available_size = (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
if available_size != self.max_total_num_token:
warnings.warn(
"Warning: "
f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
"KV cache pool leak detected!"
)
def handle_generate_request(
self,
recv_req: TokenizedGenerateReqInput,
):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
req.pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size,
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
]
req.image_size = recv_req.image_size
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
)
req.sampling_params = recv_req.sampling_params
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
req.tokenizer = self.tokenizer
# Init regex fsm
if req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
req.sampling_params.regex
)
# Truncate prompts that are too long
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
req.sampling_params.max_new_tokens = min(
req.sampling_params.max_new_tokens,
self.model_config.context_len - 1 - len(req.input_ids),
self.max_total_num_token - 128 - len(req.input_ids),
)
self.forward_queue.append(req)
def get_new_fill_batch(self):
if (
self.running_batch is not None
and len(self.running_batch.reqs) > self.max_num_running_seq
):
return None
# Compute matched prefix length
for req in self.forward_queue:
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
if req.return_logprob:
prefix_indices = prefix_indices[: req.logprob_start_len]
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
req.prefix_indices = prefix_indices
req.last_node = last_node
# Get priority queue
self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
# Add requests if there is available space
can_run_list = []
new_batch_total_tokens = 0
new_batch_input_tokens = 0
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
if self.running_batch:
available_size -= sum(
[
(r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
for r in self.running_batch.reqs
]
)
for req in self.forward_queue:
if req.return_logprob:
# Need at least two tokens to compute normalized logprob
if req.extend_input_len < 2:
delta = 2 - req.extend_input_len
req.extend_input_len += delta
req.prefix_indices = req.prefix_indices[:-delta]
if req.image_offset is not None:
req.image_offset += delta
if req.extend_input_len == 0 and req.max_new_tokens() > 0:
# Need at least one token to compute logits
req.extend_input_len = 1
req.prefix_indices = req.prefix_indices[:-1]
if req.image_offset is not None:
req.image_offset += 1
if (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
and req.extend_input_len + new_batch_input_tokens
< self.max_prefill_num_token
):
delta = self.tree_cache.inc_lock_ref(req.last_node)
available_size += delta
if not (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
):
# Undo locking
delta = self.tree_cache.dec_lock_ref(req.last_node)
available_size += delta
break
else:
# Add this request to the running batch
can_run_list.append(req)
new_batch_total_tokens += (
req.extend_input_len + req.max_new_tokens()
)
new_batch_input_tokens += req.extend_input_len
else:
break
if len(can_run_list) == 0:
return None
# Print stats
if self.tp_rank == 0:
running_req = (
0 if self.running_batch is None else len(self.running_batch.reqs)
)
hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
self.tree_cache_metrics["total"] += (
hit_tokens + new_batch_input_tokens
) / 10**9
self.tree_cache_metrics["hit"] += hit_tokens / 10**9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
logger.info(
f"new fill batch. #seq: {len(can_run_list)}. "
f"#cached_token: {hit_tokens}. "
f"#new_token: {new_batch_input_tokens}. "
f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
f"#running_req: {running_req}. "
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
)
# logger.debug(
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
# f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
# f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
# )
# Return the new batch
new_batch = Batch.init_new(
can_run_list,
self.req_to_token_pool,
self.token_to_kv_pool,
self.tree_cache,
)
self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
return new_batch
def forward_fill_batch(self, batch: Batch):
# Build batch tensors
batch.prepare_for_extend(
self.model_config.vocab_size, self.int_token_logit_bias
)
if batch.extend_num_tokens != 0:
# Forward
logits, (
prefill_token_logprobs,
normalized_prompt_logprobs,
prefill_top_logprobs,
decode_top_logprobs,
last_logprobs,
) = self.model_runner.forward(batch, ForwardMode.EXTEND)
if prefill_token_logprobs is not None:
prefill_token_logprobs = prefill_token_logprobs.tolist()
normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
next_token_ids, _ = batch.sample(logits)
# Only transfer the selected logprobs of the next token to CPU to reduce overhead.
if last_logprobs is not None:
last_token_logprobs = last_logprobs[
torch.arange(len(batch.reqs), device=next_token_ids.device),
next_token_ids,
].tolist()
next_token_ids = next_token_ids.tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
# Check finish condition
pt = 0
for i, req in enumerate(batch.reqs):
req.completion_tokens_wo_jump_forward += 1
req.output_ids = [next_token_ids[i]]
req.check_finished()
if req.return_logprob:
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
req.prefill_token_logprobs = list(
zip(
prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
req.input_ids[-req.extend_input_len + 1 :],
)
)
if req.logprob_start_len == 0:
req.prefill_token_logprobs = [
(None, req.input_ids[0])
] + req.prefill_token_logprobs
req.decode_token_logprobs = [
(last_token_logprobs[i], next_token_ids[i])
]
if req.top_logprobs_num > 0:
req.prefill_top_logprobs = prefill_top_logprobs[i]
if req.logprob_start_len == 0:
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
req.decode_top_logprobs = [decode_top_logprobs[i]]
pt += req.extend_input_len
self.handle_finished_requests(batch)
def cache_filled_batch(self, batch: Batch):
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
for i, req in enumerate(batch.reqs):
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
del_in_memory_pool=False,
old_last_node=req.last_node,
)
req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
def forward_decode_batch(self, batch: Batch):
# check if decode out of memory
if not batch.check_decode_mem():
old_ratio = self.new_token_ratio
self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
retracted_reqs = batch.retract_decode()
logger.info(
"decode out of memory happened, "
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
)
self.forward_queue.extend(retracted_reqs)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_step[0],
self.min_new_token_ratio,
)
if not self.disable_regex_jump_forward:
# check for jump-forward
jump_forward_reqs = batch.check_for_jump_forward()
# check for image jump-forward
for req in jump_forward_reqs:
if req.pixel_values is not None:
(
req.input_ids,
req.image_offset,
) = self.model_runner.model.pad_input_ids(
req.input_ids,
req.pad_value,
req.pixel_values.shape,
req.image_size,
)
self.forward_queue.extend(jump_forward_reqs)
if batch.is_empty():
return
# Update batch tensors
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
batch.prepare_for_decode()
# Forward
logits, (
_,
_,
_,
decode_top_logprobs,
last_logprobs,
) = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids, _ = batch.sample(logits)
next_token_ids = next_token_ids.tolist()
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
if last_logprobs is not None:
new_token_logprobs = last_logprobs[
torch.arange(len(batch.reqs)), next_token_ids
].tolist()
# Check finish condition
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()
if req.return_logprob:
req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
if req.top_logprobs_num > 0:
req.decode_top_logprobs.append(decode_top_logprobs[i])
self.handle_finished_requests(batch)
def handle_finished_requests(self, batch: Batch):
output_rids = []
output_tokens = []
output_and_jump_forward_strs = []
output_hit_stop_str = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_meta_info = []
output_finished = []
finished_indices = []
unfinished_indices = []
for i, req in enumerate(batch.reqs):
if req.finished:
finished_indices.append(i)
else:
unfinished_indices.append(i)
if req.finished or (
(
req.stream
and (
self.decode_forward_ct % self.stream_interval == 0
or len(req.output_ids) == 1
)
)
):
output_rids.append(req.rid)
output_tokens.append(req.output_ids)
output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
output_hit_stop_str.append(req.hit_stop_str)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
output_spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
)
meta_info = {
"prompt_tokens": req.prompt_tokens,
"completion_tokens": len(req.input_ids)
+ len(req.output_ids)
- req.prompt_tokens,
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": FinishReason.to_str(req.finish_reason),
"hit_stop_str": req.hit_stop_str,
}
if req.return_logprob:
(
meta_info["prefill_token_logprobs"],
meta_info["decode_token_logprobs"],
meta_info["prefill_top_logprobs"],
meta_info["decode_top_logprobs"],
meta_info["normalized_prompt_logprob"],
) = (
req.prefill_token_logprobs,
req.decode_token_logprobs,
req.prefill_top_logprobs,
req.decode_top_logprobs,
req.normalized_prompt_logprob,
)
output_meta_info.append(meta_info)
output_finished.append(req.finished)
# Send to detokenizer
if output_rids:
self.out_pyobjs.append(
BatchTokenIDOut(
output_rids,
output_tokens,
output_and_jump_forward_strs,
output_hit_stop_str,
output_skip_special_tokens,
output_spaces_between_special_tokens,
output_meta_info,
output_finished,
)
)
# Remove finished reqs
if finished_indices:
# Update radix cache
req_pool_indices_cpu = batch.req_pool_indices.tolist()
for i in finished_indices:
req = batch.reqs[i]
self.tree_cache.cache_req(
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
)
self.tree_cache.dec_lock_ref(req.last_node)
# Update batch tensors
if unfinished_indices:
batch.filter_batch(unfinished_indices)
else:
batch.reqs = []
def flush_cache(self):
if len(self.forward_queue) == 0 and (
self.running_batch is None or len(self.running_batch.reqs) == 0
):
self.tree_cache.reset()
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.regex_fsm_cache.reset()
self.req_to_token_pool.clear()
self.token_to_kv_pool.clear()
torch.cuda.empty_cache()
logger.info("Cache flushed successfully!")
else:
warnings.warn(
f"Cache not flushed because there are pending requests. "
f"#queue-req: {len(self.forward_queue)}, "
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
)
def abort_request(self, recv_req):
to_del = None
for i, req in enumerate(self.forward_queue):
if req.rid == recv_req.rid:
to_del = i
break
if to_del is not None:
del self.forward_queue[to_del]
class ModelRpcService(rpyc.Service):
    """rpyc service through which a remote client constructs a ModelRpcServer.

    Remotely it is reached as ``conn.root.ModelRpcServer(...)``; in-process
    the same class is available as ``exposed_ModelRpcServer``.
    """

    exposed_ModelRpcServer = ModelRpcServer
class ModelRpcClient:
    """Router-side handle that drives one ModelRpcServer per tensor-parallel rank.

    With ``tp_size == 1`` the server runs in-process. Otherwise one process is
    launched per rank (see ``start_model_process``) and driven over rpyc.
    Either way ``self.step`` is an async callable returning rank 0's outputs.
    """

    def __init__(
        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
    ):
        tp_size = server_args.tp_size

        if tp_size == 1:
            # Init model
            self.model_server = ModelRpcService().exposed_ModelRpcServer(
                0, server_args, port_args, model_overide_args
            )

            # Wrap functions. Note: the wrapped call runs synchronously inside
            # the coroutine.
            def async_wrap(f):
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
            with ThreadPoolExecutor(tp_size) as executor:
                # Launch model processes. Executor.map returns a one-shot
                # iterator, so materialize it: it is read twice below (the
                # old code left self.procs always empty).
                rets = list(
                    executor.map(start_model_process, port_args.model_rpc_ports)
                )
                self.remote_services = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
                    return self.remote_services[i].ModelRpcServer(
                        i, server_args, port_args, model_overide_args
                    )

                # Keep a real list so model_servers stays usable after async_wrap reads it.
                self.model_servers = list(executor.map(init_model, range(tp_size)))

            # Wrap functions: call every rank asynchronously, wait for all of
            # them off the event loop, and return rank 0's result.
            def async_wrap(func_name):
                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                async def _func(*args, **kwargs):
                    tasks = [f(*args, **kwargs) for f in fs]
                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                    return obtain(tasks[0].value)

                return _func

            self.step = async_wrap("step")
def _init_service(port):
    """Host ModelRpcService on ``port`` with a ThreadedServer.

    Used as the child-process target in ``start_model_process``. ``start()``
    presumably blocks for the life of the server (rpyc accept loop) — confirm
    against rpyc docs.
    """
    protocol_config = {
        "allow_public_attrs": True,
        "allow_pickle": True,
        "sync_request_timeout": 1800,
    }
    server = ThreadedServer(
        ModelRpcService(), port=port, protocol_config=protocol_config
    )
    server.start()
def start_model_process(port):
    """Spawn a process serving ModelRpcService on ``port`` and connect to it.

    After a 1 s grace period, connection is retried once per second for up
    to 20 attempts.

    Returns:
        ``(root, proc)``: the rpyc connection's root object and the child
        ``multiprocessing.Process``.

    Raises:
        RuntimeError: if no connection could be established (the child process
            is terminated first so it is not orphaned), or if the child process
            is no longer alive after connecting.
    """
    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    time.sleep(1)

    max_attempts = 20
    for _ in range(max_attempts):
        try:
            con = rpyc.connect(
                "localhost",
                port,
                config={
                    "allow_public_attrs": True,
                    "allow_pickle": True,
                    "sync_request_timeout": 1800,
                },
            )
            break
        except ConnectionRefusedError:
            time.sleep(1)
    else:
        # Give up, but do not leave an orphaned server process behind.
        proc.terminate()
        proc.join(timeout=5)
        raise RuntimeError("init rpc env error!")

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not proc.is_alive():
        raise RuntimeError(f"model process for port {port} exited unexpectedly")
    return con.root, proc