Organize server_args (#277)

This commit is contained in:
Liangsheng Yin
2024-03-11 20:06:52 +08:00
committed by GitHub
parent faba293a0d
commit 1b35547927
12 changed files with 92 additions and 34 deletions

View File

@@ -43,6 +43,17 @@ def Runtime(*args, **kwargs):
def set_default_backend(backend: BaseBackend):
global_config.default_backend = backend
def flush_cache(backend: BaseBackend = None):
backend = backend or global_config.default_backend
if backend is None:
return False
return backend.flush_cache()
def get_server_args(backend: BaseBackend = None):
backend = backend or global_config.default_backend
if backend is None:
return None
return backend.get_server_args()
def gen(
name: Optional[str] = None,

View File

@@ -72,3 +72,9 @@ class BaseBackend:
def shutdown(self):
pass
def flush_cache(self):
pass
def get_server_args(self):
pass

View File

@@ -35,6 +35,22 @@ class RuntimeEndpoint(BaseBackend):
def get_model_name(self):
return self.model_info["model_path"]
def flush_cache(self):
res = http_request(
self.base_url + "/flush_cache",
auth_token=self.auth_token,
verify=self.verify,
)
return res.status_code == 200
def get_server_args(self):
res = http_request(
self.base_url + "/get_server_args",
auth_token=self.auth_token,
verify=self.verify,
)
return res.json()
def get_chat_template(self):
return self.chat_template

View File

@@ -15,11 +15,9 @@ class RadixAttention(nn.Module):
self.head_dim = head_dim
self.layer_id = layer_id
from sglang.srt.managers.router.model_runner import global_server_args
from sglang.srt.managers.router.model_runner import global_server_args_dict
self.use_flashinfer = "flashinfer" in global_server_args.model_mode
if self.use_flashinfer:
if global_server_args_dict["enable_flashinfer"]:
self.prefill_forward = self.prefill_forward_flashinfer
self.extend_forward = self.prefill_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer

View File

@@ -4,10 +4,10 @@
import torch
import triton
import triton.language as tl
from sglang.srt.managers.router.model_runner import global_server_args
from sglang.srt.managers.router.model_runner import global_server_args_dict
from sglang.srt.utils import wrap_kernel_launcher
if global_server_args.attention_reduce_in_fp32:
if global_server_args_dict["attention_reduce_in_fp32"]:
REDUCE_TRITON_TYPE = tl.float32
REDUCE_TORCH_TYPE = torch.float32
else:

View File

@@ -46,7 +46,6 @@ class ModelRpcServer(rpyc.Service):
server_args, port_args = [obtain(x) for x in [server_args, port_args]]
# Copy arguments
self.model_mode = server_args.model_mode
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.schedule_heuristic = server_args.schedule_heuristic
@@ -61,15 +60,22 @@ class ModelRpcServer(rpyc.Service):
server_args.trust_remote_code,
context_length=server_args.context_length,
)
# for model end global settings
server_args_dict = {
"enable_flashinfer": server_args.enable_flashinfer,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
}
self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
nccl_port=port_args.nccl_port,
server_args=server_args,
load_format=server_args.load_format,
trust_remote_code=server_args.trust_remote_code,
server_args_dict=server_args_dict,
)
if is_multimodal_model(server_args.model_path):
self.processor = get_processor(
@@ -104,11 +110,11 @@ class ModelRpcServer(rpyc.Service):
f"max_total_num_token={self.max_total_num_token}, "
f"max_prefill_num_token={self.max_prefill_num_token}, "
f"context_len={self.model_config.context_len}, "
f"model_mode={self.model_mode}"
)
logger.info(server_args.get_optional_modes_logging())
# Init cache
self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
self.tree_cache = RadixCache(server_args.disable_radix_cache)
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.scheduler = Scheduler(
self.schedule_heuristic,

View File

@@ -23,7 +23,7 @@ logger = logging.getLogger("model_runner")
# for server args in model endpoints
global_server_args = None
global_server_args_dict: dict = None
@lru_cache()
@@ -222,7 +222,7 @@ class InputMetadata:
if forward_mode == ForwardMode.EXTEND:
ret.init_extend_args()
if "flashinfer" in global_server_args.model_mode:
if global_server_args_dict["enable_flashinfer"]:
ret.init_flashinfer_args(tp_size)
return ret
@@ -236,9 +236,9 @@ class ModelRunner:
tp_rank,
tp_size,
nccl_port,
server_args,
load_format="auto",
trust_remote_code=True,
server_args_dict: dict = {},
):
self.model_config = model_config
self.mem_fraction_static = mem_fraction_static
@@ -248,8 +248,8 @@ class ModelRunner:
self.load_format = load_format
self.trust_remote_code = trust_remote_code
global global_server_args
global_server_args = server_args
global global_server_args_dict
global_server_args_dict = server_args_dict
# Init torch distributed
torch.cuda.set_device(self.tp_rank)

View File

@@ -82,6 +82,8 @@ class TokenizerManager:
server_args: ServerArgs,
port_args: PortArgs,
):
self.server_args = server_args
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = context.socket(zmq.PULL)
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

View File

@@ -1,6 +1,7 @@
"""SRT: SGLang Runtime"""
import asyncio
import dataclasses
import json
import multiprocessing as mp
import os
@@ -86,6 +87,11 @@ async def get_model_info():
return result
@app.get("/get_server_args")
async def get_server_args():
return dataclasses.asdict(tokenizer_manager.server_args)
@app.get("/flush_cache")
async def flush_cache():
await tokenizer_manager.flush_cache()
@@ -548,7 +554,6 @@ class Runtime:
max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
context_length: int = ServerArgs.context_length,
tp_size: int = 1,
model_mode: List[str] = (),
schedule_heuristic: str = "lpm",
attention_reduce_in_fp32: bool = False,
random_seed: int = 42,
@@ -571,7 +576,6 @@ class Runtime:
max_prefill_num_token=max_prefill_num_token,
context_length=context_length,
tp_size=tp_size,
model_mode=model_mode,
schedule_heuristic=schedule_heuristic,
attention_reduce_in_fp32=attention_reduce_in_fp32,
random_seed=random_seed,

View File

@@ -18,7 +18,6 @@ class ServerArgs:
max_prefill_num_token: Optional[int] = None
context_length: Optional[int] = None
tp_size: int = 1
model_mode: List[str] = ()
schedule_heuristic: str = "lpm"
schedule_conservativeness: float = 1.0
attention_reduce_in_fp32: bool = False
@@ -27,6 +26,10 @@ class ServerArgs:
disable_log_stats: bool = False
log_stats_interval: int = 10
log_level: str = "info"
# optional modes
disable_radix_cache: bool = False
enable_flashinfer: bool = False
disable_regex_jump_forward: bool = False
disable_disk_cache: bool = False
@@ -131,14 +134,6 @@ class ServerArgs:
default=ServerArgs.tp_size,
help="Tensor parallelism degree.",
)
parser.add_argument(
"--model-mode",
type=str,
default=[],
nargs="+",
choices=["flashinfer", "no-cache"],
help="Model mode: [flashinfer, no-cache]",
)
parser.add_argument(
"--schedule-heuristic",
type=str,
@@ -185,6 +180,17 @@ class ServerArgs:
default=ServerArgs.log_stats_interval,
help="Log stats interval in second.",
)
# optional modes
parser.add_argument(
"--disable-radix-cache",
action="store_true",
help="Disable RadixAttention",
)
parser.add_argument(
"--enable-flashinfer",
action="store_true",
help="Enable flashinfer inference kernels",
)
parser.add_argument(
"--disable-regex-jump-forward",
action="store_true",
@@ -204,6 +210,15 @@ class ServerArgs:
def url(self):
return f"http://{self.host}:{self.port}"
def get_optional_modes_logging(self):
return (
f"disable_radix_cache={self.disable_radix_cache}, "
f"enable_flashinfer={self.enable_flashinfer}, "
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
f"disable_disk_cache={self.disable_disk_cache}, "
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
)
@dataclasses.dataclass
class PortArgs: