diff --git a/docs/flashinfer.md b/docs/flashinfer.md
index 7ea1e1efc..2f1fd2dc1 100644
--- a/docs/flashinfer.md
+++ b/docs/flashinfer.md
@@ -16,10 +16,10 @@ please build it from source (the compilation takes a long time).
 
 ### Run a Server With Flashinfer Mode
 
-Add `--model-mode flashinfer` argument to enable flashinfer when launching a server.
+Add `--enable-flashinfer` argument to enable flashinfer when launching a server.
 
 Example:
 
 ```bash
-python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --model-mode flashinfer
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --enable-flashinfer
 ```
diff --git a/python/sglang/api.py b/python/sglang/api.py
index 80cf080a9..f1337a67e 100644
--- a/python/sglang/api.py
+++ b/python/sglang/api.py
@@ -43,6 +43,17 @@ def Runtime(*args, **kwargs):
 def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
 
+def flush_cache(backend: BaseBackend = None):
+    backend = backend or global_config.default_backend
+    if backend is None:
+        return False
+    return backend.flush_cache()
+
+def get_server_args(backend: BaseBackend = None):
+    backend = backend or global_config.default_backend
+    if backend is None:
+        return None
+    return backend.get_server_args()
 
 def gen(
     name: Optional[str] = None,
diff --git a/python/sglang/backend/base_backend.py b/python/sglang/backend/base_backend.py
index 0bbf3ef3e..cb504f51b 100644
--- a/python/sglang/backend/base_backend.py
+++ b/python/sglang/backend/base_backend.py
@@ -72,3 +72,9 @@ class BaseBackend:
 
     def shutdown(self):
         pass
+
+    def flush_cache(self):
+        pass
+
+    def get_server_args(self):
+        pass
diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py
index eb9dc3264..1c0d44540 100644
--- a/python/sglang/backend/runtime_endpoint.py
+++ b/python/sglang/backend/runtime_endpoint.py
@@ -35,6 +35,22 @@ class RuntimeEndpoint(BaseBackend):
     def get_model_name(self):
         return self.model_info["model_path"]
 
+    def flush_cache(self):
+        res = http_request(
+            self.base_url + "/flush_cache",
+            auth_token=self.auth_token,
+            verify=self.verify,
+        )
+        return res.status_code == 200
+
+    def get_server_args(self):
+        res = http_request(
+            self.base_url + "/get_server_args",
+            auth_token=self.auth_token,
+            verify=self.verify,
+        )
+        return res.json()
+
     def get_chat_template(self):
         return self.chat_template
 
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 457cabc88..464327eed 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -15,11 +15,9 @@ class RadixAttention(nn.Module):
         self.head_dim = head_dim
         self.layer_id = layer_id
 
-        from sglang.srt.managers.router.model_runner import global_server_args
+        from sglang.srt.managers.router.model_runner import global_server_args_dict
 
-        self.use_flashinfer = "flashinfer" in global_server_args.model_mode
-
-        if self.use_flashinfer:
+        if global_server_args_dict["enable_flashinfer"]:
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py
index fc316b6e7..a4a57fbe7 100644
--- a/python/sglang/srt/layers/token_attention.py
+++ b/python/sglang/srt/layers/token_attention.py
@@ -4,10 +4,10 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.managers.router.model_runner import global_server_args
+from sglang.srt.managers.router.model_runner import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher
 
-if global_server_args.attention_reduce_in_fp32:
+if global_server_args_dict["attention_reduce_in_fp32"]:
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
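The api.py and runtime_endpoint.py changes above give the frontend a small cache-control surface over the two new HTTP routes. A minimal usage sketch (assumptions: a server is already listening on port 30000 as in the docs example, and `flush_cache`/`get_server_args` are re-exported at the package top level the way `set_default_backend` is):

```python
import sglang as sgl
from sglang.backend.runtime_endpoint import RuntimeEndpoint

# Point the default backend at a running server
# (python -m sglang.launch_server ... --port 30000).
sgl.set_default_backend(RuntimeEndpoint("http://localhost:30000"))

# get_server_args() falls back to the default backend and returns the
# JSON body of GET /get_server_args (None when no backend is set).
args = sgl.get_server_args()
print(args["enable_flashinfer"], args["disable_radix_cache"])

# flush_cache() returns True iff GET /flush_cache answered 200.
assert sgl.flush_cache()
```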
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index a651f2c7a..68f8423ed 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -46,7 +46,6 @@ class ModelRpcServer(rpyc.Service):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]
 
         # Copy arguments
-        self.model_mode = server_args.model_mode
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
@@ -61,15 +60,22 @@
             server_args.trust_remote_code,
             context_length=server_args.context_length,
         )
+
+        # for model end global settings
+        server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
             nccl_port=port_args.nccl_port,
-            server_args=server_args,
             load_format=server_args.load_format,
             trust_remote_code=server_args.trust_remote_code,
+            server_args_dict=server_args_dict,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
@@ -104,11 +110,11 @@
             f"max_total_num_token={self.max_total_num_token}, "
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
-            f"model_mode={self.model_mode}"
         )
+        logger.info(server_args.get_optional_modes_logging())
 
         # Init cache
-        self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
+        self.tree_cache = RadixCache(server_args.disable_radix_cache)
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index 20d163947..4ec7946c6 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -23,7 +23,7 @@ logger = logging.getLogger("model_runner")
 
 
 # for server args in model endpoints
-global_server_args = None
+global_server_args_dict: dict = None
 
 
 @lru_cache()
@@ -222,7 +222,7 @@ class InputMetadata:
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()
 
-        if "flashinfer" in global_server_args.model_mode:
+        if global_server_args_dict["enable_flashinfer"]:
             ret.init_flashinfer_args(tp_size)
 
         return ret
@@ -236,9 +236,9 @@ class ModelRunner:
         tp_rank,
         tp_size,
         nccl_port,
-        server_args,
         load_format="auto",
         trust_remote_code=True,
+        server_args_dict: dict = {},
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
@@ -248,8 +248,8 @@ class ModelRunner:
         self.load_format = load_format
         self.trust_remote_code = trust_remote_code
 
-        global global_server_args
-        global_server_args = server_args
+        global global_server_args_dict
+        global_server_args_dict = server_args_dict
 
         # Init torch distributed
         torch.cuda.set_device(self.tp_rank)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
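model_rpc.py now distills `ServerArgs` into a small plain dict and hands it to `ModelRunner`, which publishes it as the module-level `global_server_args_dict`; layer code (radix_attention.py, token_attention.py) then reads flags from that dict instead of importing the whole dataclass. A single-file sketch of the pattern (simplified: the real code splits these pieces across the modules above, and `pick_attention_backend` is a hypothetical stand-in for the layer-side check):

```python
# Module-level settings dict, assigned exactly once when the runner is built.
global_server_args_dict: dict = {}


class ModelRunner:
    def __init__(self, server_args_dict: dict):
        # Rebind the module global so later reads by layer code see the flags.
        global global_server_args_dict
        global_server_args_dict = server_args_dict


def pick_attention_backend() -> str:
    # Layers check the flag at construction time, after ModelRunner has run.
    if global_server_args_dict["enable_flashinfer"]:
        return "flashinfer"
    return "triton"


if __name__ == "__main__":
    ModelRunner({"enable_flashinfer": True, "attention_reduce_in_fp32": False})
    print(pick_attention_backend())  # -> flashinfer
```

One caveat the diff inherits: token_attention.py reads `attention_reduce_in_fp32` at module import time, so that module must not be imported before `ModelRunner` fills the dict.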
index 084ac791e..814141e1f 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -82,6 +82,8 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        self.server_args = server_args
+
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 2e604c2c7..95b7b439c 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -1,6 +1,7 @@
 """SRT: SGLang Runtime"""
 
 import asyncio
+import dataclasses
 import json
 import multiprocessing as mp
 import os
@@ -86,6 +87,11 @@ async def get_model_info():
     return result
 
 
+@app.get("/get_server_args")
+async def get_server_args():
+    return dataclasses.asdict(tokenizer_manager.server_args)
+
+
 @app.get("/flush_cache")
 async def flush_cache():
     await tokenizer_manager.flush_cache()
@@ -548,7 +554,6 @@ class Runtime:
         max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
         context_length: int = ServerArgs.context_length,
         tp_size: int = 1,
-        model_mode: List[str] = (),
         schedule_heuristic: str = "lpm",
         attention_reduce_in_fp32: bool = False,
         random_seed: int = 42,
@@ -571,7 +576,6 @@ class Runtime:
             max_prefill_num_token=max_prefill_num_token,
             context_length=context_length,
             tp_size=tp_size,
-            model_mode=model_mode,
             schedule_heuristic=schedule_heuristic,
             attention_reduce_in_fp32=attention_reduce_in_fp32,
             random_seed=random_seed,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 596b584ab..f236a9ae1 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -18,7 +18,6 @@ class ServerArgs:
     max_prefill_num_token: Optional[int] = None
     context_length: Optional[int] = None
     tp_size: int = 1
-    model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
     attention_reduce_in_fp32: bool = False
@@ -27,6 +26,10 @@ class ServerArgs:
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
+
+    # optional modes
+    disable_radix_cache: bool = False
+    enable_flashinfer: bool = False
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
@@ -131,14 +134,6 @@
             default=ServerArgs.tp_size,
             help="Tensor parallelism degree.",
         )
-        parser.add_argument(
-            "--model-mode",
-            type=str,
-            default=[],
-            nargs="+",
-            choices=["flashinfer", "no-cache"],
-            help="Model mode: [flashinfer, no-cache]",
-        )
         parser.add_argument(
             "--schedule-heuristic",
             type=str,
@@ -185,6 +180,17 @@
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
+        # optional modes
+        parser.add_argument(
+            "--disable-radix-cache",
+            action="store_true",
+            help="Disable RadixAttention",
+        )
+        parser.add_argument(
+            "--enable-flashinfer",
+            action="store_true",
+            help="Enable flashinfer inference kernels",
+        )
         parser.add_argument(
             "--disable-regex-jump-forward",
             action="store_true",
@@ -204,6 +210,15 @@
     def url(self):
         return f"http://{self.host}:{self.port}"
 
+    def get_optional_modes_logging(self):
+        return (
+            f"disable_radix_cache={self.disable_radix_cache}, "
+            f"enable_flashinfer={self.enable_flashinfer}, "
+            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
+            f"disable_disk_cache={self.disable_disk_cache}, "
+            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
+        )
+
 
 @dataclasses.dataclass
 class PortArgs:
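The old `--model-mode` list maps onto the new flags one for one: `flashinfer` becomes `--enable-flashinfer` and `no-cache` becomes `--disable-radix-cache`. A quick sketch of the dataclass plus the new log helper (assuming `model_path` is the only `ServerArgs` field without a default):

```python
from sglang.srt.server_args import ServerArgs

# Equivalent of launching with --enable-flashinfer (radix cache stays on).
args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    enable_flashinfer=True,
)
print(args.get_optional_modes_logging())
# -> disable_radix_cache=False, enable_flashinfer=True,
#    disable_regex_jump_forward=False, disable_disk_cache=False,
#    attention_reduce_in_fp32=False
```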
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}" + ) + @dataclasses.dataclass class PortArgs: diff --git a/test/srt/model/bench_llama_low_api.py b/test/srt/model/bench_llama_low_api.py index 3e3534709..34c64cd6c 100644 --- a/test/srt/model/bench_llama_low_api.py +++ b/test/srt/model/bench_llama_low_api.py @@ -151,7 +151,7 @@ def bench_generate_worker( shared_len, unique_len, decode_len, - model_mode, + server_args_dict, ): assert unique_num % shared_num == 0 @@ -162,7 +162,7 @@ def bench_generate_worker( tp_rank=tp_rank, tp_size=tp_size, nccl_port=28888, - model_mode=model_mode, + server_args_dict=server_args_dict, ) batch = BenchBatch(model_runner) @@ -227,7 +227,7 @@ def bench_generate( shared_len, unique_len, decode_len, - model_mode, + server_args_dict, ): print( f"tp_size: {tp_size}, " @@ -236,7 +236,7 @@ def bench_generate( f"shared_len: {shared_len}, " f"unique_len: {unique_len}, " f"decode_len: {decode_len}, " - f"model_mode: {model_mode}" + f"server_args: {server_args_dict}" ) workers = [] for tp_rank in range(tp_size): @@ -251,7 +251,7 @@ def bench_generate( shared_len, unique_len, decode_len, - model_mode, + server_args_dict, ), ) proc.start() @@ -270,5 +270,5 @@ if __name__ == "__main__": shared_len=256, unique_len=256, decode_len=8, - model_mode=[], + server_args_dict={}, )