Organize server_args (#277)
@@ -43,6 +43,17 @@ def Runtime(*args, **kwargs):
def set_default_backend(backend: BaseBackend):
    global_config.default_backend = backend


def flush_cache(backend: BaseBackend = None):
    backend = backend or global_config.default_backend
    if backend is None:
        return False
    return backend.flush_cache()


def get_server_args(backend: BaseBackend = None):
    backend = backend or global_config.default_backend
    if backend is None:
        return None
    return backend.get_server_args()


def gen(
    name: Optional[str] = None,
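
A minimal usage sketch of the two new frontend helpers above. It assumes they are re-exported from the top-level sglang package like the existing set_default_backend, and the server URL is only an example:

import sglang as sgl
from sglang.backend.runtime_endpoint import RuntimeEndpoint

# Attach a running SRT server as the default backend (example address).
sgl.set_default_backend(RuntimeEndpoint("http://127.0.0.1:30000"))

# Returns the server's ServerArgs as a dict, or None when no backend is set.
print(sgl.get_server_args())

# Returns True when the server flushed its cache (HTTP 200), False otherwise.
print(sgl.flush_cache())
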
@@ -72,3 +72,9 @@ class BaseBackend:

    def shutdown(self):
        pass

    def flush_cache(self):
        pass

    def get_server_args(self):
        pass
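
BaseBackend only gains no-op defaults, so existing backends keep working and only backends that actually hold a cache or server configuration need to override them. A toy subclass, purely for illustration (the class is not part of the commit, and it assumes BaseBackend.__init__ takes no arguments):

from sglang.backend.base_backend import BaseBackend

class InMemoryBackend(BaseBackend):
    """Illustrative backend with a local cache and static 'server args'."""

    def __init__(self):
        super().__init__()
        self._cache = {}

    def flush_cache(self):
        # A real backend would clear server-side state here.
        self._cache.clear()
        return True

    def get_server_args(self):
        # A real backend would report its server configuration here.
        return {"backend": "in-memory", "cache_entries": len(self._cache)}
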
@@ -35,6 +35,22 @@ class RuntimeEndpoint(BaseBackend):
    def get_model_name(self):
        return self.model_info["model_path"]

    def flush_cache(self):
        res = http_request(
            self.base_url + "/flush_cache",
            auth_token=self.auth_token,
            verify=self.verify,
        )
        return res.status_code == 200

    def get_server_args(self):
        res = http_request(
            self.base_url + "/get_server_args",
            auth_token=self.auth_token,
            verify=self.verify,
        )
        return res.json()

    def get_chat_template(self):
        return self.chat_template
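
Both wrappers are plain GET requests, so the same endpoints can be queried directly; a sketch using requests (the server address is an example, and no auth token or custom TLS verification is assumed):

import requests

base_url = "http://127.0.0.1:30000"  # example SRT server address

# Mirrors RuntimeEndpoint.get_server_args(): the server returns ServerArgs as JSON.
args = requests.get(base_url + "/get_server_args").json()
print(args.get("tp_size"), args.get("enable_flashinfer"))

# Mirrors RuntimeEndpoint.flush_cache(): success is simply an HTTP 200.
ok = requests.get(base_url + "/flush_cache").status_code == 200
print("cache flushed:", ok)
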
@@ -15,11 +15,9 @@ class RadixAttention(nn.Module):
        self.head_dim = head_dim
        self.layer_id = layer_id

        from sglang.srt.managers.router.model_runner import global_server_args
        from sglang.srt.managers.router.model_runner import global_server_args_dict

        self.use_flashinfer = "flashinfer" in global_server_args.model_mode

        if self.use_flashinfer:
        if global_server_args_dict["enable_flashinfer"]:
            self.prefill_forward = self.prefill_forward_flashinfer
            self.extend_forward = self.prefill_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
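
The kernel choice now happens once at construction time by looking up enable_flashinfer in the shared dict, instead of parsing the old model_mode string list. A condensed, self-contained sketch of that dispatch pattern (names simplified; this is not the real RadixAttention code):

class AttentionDispatchSketch:
    """Pick a forward implementation once, based on a config dict."""

    def __init__(self, server_args_dict: dict):
        if server_args_dict.get("enable_flashinfer", False):
            self.decode_forward = self._decode_flashinfer
        else:
            self.decode_forward = self._decode_triton

    def _decode_flashinfer(self, x):
        return ("flashinfer", x)

    def _decode_triton(self, x):
        return ("triton", x)

layer = AttentionDispatchSketch({"enable_flashinfer": True})
print(layer.decode_forward("tokens"))  # ('flashinfer', 'tokens')
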
@@ -4,10 +4,10 @@
import torch
import triton
import triton.language as tl
from sglang.srt.managers.router.model_runner import global_server_args
from sglang.srt.managers.router.model_runner import global_server_args_dict
from sglang.srt.utils import wrap_kernel_launcher

if global_server_args.attention_reduce_in_fp32:
if global_server_args_dict["attention_reduce_in_fp32"]:
    REDUCE_TRITON_TYPE = tl.float32
    REDUCE_TORCH_TYPE = torch.float32
else:
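
attention_reduce_in_fp32 selects the accumulation dtype used inside the attention kernels. A tiny torch-only illustration of the numerical drift it guards against (this is not the Triton kernel itself):

import torch

torch.manual_seed(0)
scores = torch.randn(4096)

acc_fp32 = scores.float().sum()          # accumulate in fp32
acc_fp16 = scores.half().sum().float()   # accumulate in fp16

# The fp16 accumulation drifts from the fp32 reference; the flag above
# keeps these reductions in fp32 inside the kernels.
print(acc_fp32.item(), acc_fp16.item())
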
@@ -46,7 +46,6 @@ class ModelRpcServer(rpyc.Service):
        server_args, port_args = [obtain(x) for x in [server_args, port_args]]

        # Copy arguments
        self.model_mode = server_args.model_mode
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_heuristic = server_args.schedule_heuristic
@@ -61,15 +60,22 @@ class ModelRpcServer(rpyc.Service):
            server_args.trust_remote_code,
            context_length=server_args.context_length,
        )

        # for model end global settings
        server_args_dict = {
            "enable_flashinfer": server_args.enable_flashinfer,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
        }

        self.model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            tp_rank=tp_rank,
            tp_size=server_args.tp_size,
            nccl_port=port_args.nccl_port,
            server_args=server_args,
            load_format=server_args.load_format,
            trust_remote_code=server_args.trust_remote_code,
            server_args_dict=server_args_dict,
        )
        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
@@ -104,11 +110,11 @@ class ModelRpcServer(rpyc.Service):
            f"max_total_num_token={self.max_total_num_token}, "
            f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
            f"model_mode={self.model_mode}"
        )
        logger.info(server_args.get_optional_modes_logging())

        # Init cache
        self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
        self.tree_cache = RadixCache(server_args.disable_radix_cache)
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.scheduler = Scheduler(
            self.schedule_heuristic,
@@ -23,7 +23,7 @@ logger = logging.getLogger("model_runner")


# for server args in model endpoints
global_server_args = None
global_server_args_dict: dict = None


@lru_cache()
@@ -222,7 +222,7 @@ class InputMetadata:
        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        if "flashinfer" in global_server_args.model_mode:
        if global_server_args_dict["enable_flashinfer"]:
            ret.init_flashinfer_args(tp_size)

        return ret
@@ -236,9 +236,9 @@ class ModelRunner:
        tp_rank,
        tp_size,
        nccl_port,
        server_args,
        load_format="auto",
        trust_remote_code=True,
        server_args_dict: dict = {},
    ):
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
@@ -248,8 +248,8 @@ class ModelRunner:
        self.load_format = load_format
        self.trust_remote_code = trust_remote_code

        global global_server_args
        global_server_args = server_args
        global global_server_args_dict
        global_server_args_dict = server_args_dict

        # Init torch distributed
        torch.cuda.set_device(self.tp_rank)
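
ModelRunner now publishes the per-model flags through a module-level dict (global_server_args_dict) instead of the whole ServerArgs object, so attention layers and Triton kernels read only the flags they need. A self-contained sketch of this pattern (toy names, not the real modules):

# Set once by the runner, read everywhere that imports this module.
global_server_flags: dict = {}

def init_model_runner(enable_flashinfer: bool, attention_reduce_in_fp32: bool):
    """Plays the role of ModelRunner.__init__ publishing server_args_dict."""
    global global_server_flags
    global_server_flags = {
        "enable_flashinfer": enable_flashinfer,
        "attention_reduce_in_fp32": attention_reduce_in_fp32,
    }

def build_attention_backend() -> str:
    """Plays the role of RadixAttention choosing its kernels at init time."""
    return "flashinfer" if global_server_flags["enable_flashinfer"] else "triton"

init_model_runner(enable_flashinfer=False, attention_reduce_in_fp32=True)
print(build_attention_backend())  # -> "triton"
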
@@ -82,6 +82,8 @@ class TokenizerManager:
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        self.server_args = server_args

        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = context.socket(zmq.PULL)
        self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
@@ -1,6 +1,7 @@
"""SRT: SGLang Runtime"""

import asyncio
import dataclasses
import json
import multiprocessing as mp
import os
@@ -86,6 +87,11 @@ async def get_model_info():
    return result


@app.get("/get_server_args")
async def get_server_args():
    return dataclasses.asdict(tokenizer_manager.server_args)


@app.get("/flush_cache")
async def flush_cache():
    await tokenizer_manager.flush_cache()
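
/get_server_args can return the configuration directly because ServerArgs is a dataclass, so dataclasses.asdict (hence the new import dataclasses above) turns it into a JSON-serializable dict. A tiny standalone illustration with a made-up dataclass:

import dataclasses

@dataclasses.dataclass
class ToyServerArgs:
    tp_size: int = 1
    enable_flashinfer: bool = False
    disable_radix_cache: bool = False

# What the endpoint hands back to FastAPI for JSON encoding.
print(dataclasses.asdict(ToyServerArgs()))
# {'tp_size': 1, 'enable_flashinfer': False, 'disable_radix_cache': False}
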
@@ -548,7 +554,6 @@ class Runtime:
        max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
        context_length: int = ServerArgs.context_length,
        tp_size: int = 1,
        model_mode: List[str] = (),
        schedule_heuristic: str = "lpm",
        attention_reduce_in_fp32: bool = False,
        random_seed: int = 42,
@@ -571,7 +576,6 @@ class Runtime:
            max_prefill_num_token=max_prefill_num_token,
            context_length=context_length,
            tp_size=tp_size,
            model_mode=model_mode,
            schedule_heuristic=schedule_heuristic,
            attention_reduce_in_fp32=attention_reduce_in_fp32,
            random_seed=random_seed,
@@ -18,7 +18,6 @@ class ServerArgs:
    max_prefill_num_token: Optional[int] = None
    context_length: Optional[int] = None
    tp_size: int = 1
    model_mode: List[str] = ()
    schedule_heuristic: str = "lpm"
    schedule_conservativeness: float = 1.0
    attention_reduce_in_fp32: bool = False
@@ -27,6 +26,10 @@ class ServerArgs:
    disable_log_stats: bool = False
    log_stats_interval: int = 10
    log_level: str = "info"

    # optional modes
    disable_radix_cache: bool = False
    enable_flashinfer: bool = False
    disable_regex_jump_forward: bool = False
    disable_disk_cache: bool = False
@@ -131,14 +134,6 @@ class ServerArgs:
            default=ServerArgs.tp_size,
            help="Tensor parallelism degree.",
        )
        parser.add_argument(
            "--model-mode",
            type=str,
            default=[],
            nargs="+",
            choices=["flashinfer", "no-cache"],
            help="Model mode: [flashinfer, no-cache]",
        )
        parser.add_argument(
            "--schedule-heuristic",
            type=str,
@@ -185,6 +180,17 @@ class ServerArgs:
            default=ServerArgs.log_stats_interval,
            help="Log stats interval in second.",
        )
        # optional modes
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention",
        )
        parser.add_argument(
            "--enable-flashinfer",
            action="store_true",
            help="Enable flashinfer inference kernels",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
@@ -204,6 +210,15 @@ class ServerArgs:
    def url(self):
        return f"http://{self.host}:{self.port}"

    def get_optional_modes_logging(self):
        return (
            f"disable_radix_cache={self.disable_radix_cache}, "
            f"enable_flashinfer={self.enable_flashinfer}, "
            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
            f"disable_disk_cache={self.disable_disk_cache}, "
            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
        )


@dataclasses.dataclass
class PortArgs:
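
Taken together, the old "--model-mode flashinfer no-cache" list is replaced by individual boolean flags. A sketch of parsing them through this class; it assumes the add_cli_args/from_cli_args helpers that the parser.add_argument calls above live in (those helpers are not shown in this diff), and the model path is only an example:

import argparse
from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)  # assumed helper wrapping the add_argument calls above
args = parser.parse_args([
    "--model-path", "meta-llama/Llama-2-7b-chat-hf",  # example model
    "--enable-flashinfer",
    "--disable-radix-cache",
])

server_args = ServerArgs.from_cli_args(args)  # assumed companion constructor
print(server_args.get_optional_modes_logging())
# disable_radix_cache=True, enable_flashinfer=True, ...
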