diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index caeaa7736..7f57c6a96 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -7,8 +7,8 @@ from torch import nn from sglang.global_config import global_config from sglang.srt.layers.extend_attention import extend_attention_fwd from sglang.srt.layers.token_attention import token_attention_fwd -from sglang.srt.managers.controller.infer_batch import global_server_args_dict from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata +from sglang.srt.server import global_server_args_dict class RadixAttention(nn.Module): diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py index 2d0250114..d8db1e01d 100644 --- a/python/sglang/srt/layers/token_attention.py +++ b/python/sglang/srt/layers/token_attention.py @@ -5,7 +5,7 @@ import torch import triton import triton.language as tl -from sglang.srt.managers.controller.model_runner import global_server_args_dict +from sglang.srt.server import global_server_args_dict from sglang.srt.utils import wrap_kernel_launcher if global_server_args_dict.get("attention_reduce_in_fp32", False): diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py index 6f0a08f37..db0af09da 100644 --- a/python/sglang/srt/managers/controller/infer_batch.py +++ b/python/sglang/srt/managers/controller/infer_batch.py @@ -16,9 +16,6 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 -# Store some global server args -global_server_args_dict = {} - class ForwardMode(IntEnum): # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case. 
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py index ae1f555a1..b98ae32c8 100644 --- a/python/sglang/srt/managers/controller/model_runner.py +++ b/python/sglang/srt/managers/controller/model_runner.py @@ -20,12 +20,7 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config -from sglang.srt.managers.controller.infer_batch import ( - Batch, - ForwardMode, - InputMetadata, - global_server_args_dict, -) +from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( @@ -91,12 +86,6 @@ class ModelRunner: "The memory capacity is unbalanced. Some GPUs may be occupied by other processes." ) - # Set some global args - global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer - global_server_args_dict["attention_reduce_in_fp32"] = ( - server_args.attention_reduce_in_fp32 - ) - # Load the model and create memory pool self.load_model() self.init_memory_pool(total_gpu_memory) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 3e52bfcdd..cef04bc4b 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -64,6 +64,9 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) app = FastAPI() tokenizer_manager = None +# Put some args here for easy access +global_server_args_dict = {} + @app.get("/health") async def health() -> Response: @@ -135,6 +138,14 @@ async def openai_v1_chat_completions(raw_request: Request): return await v1_chat_completions(tokenizer_manager, raw_request) +def _set_global_server_args(server_args: ServerArgs): + global global_server_args_dict + global_server_args_dict = { + "disable_flashinfer": server_args.disable_flashinfer, + "attention_reduce_in_fp32": 
server_args.attention_reduce_in_fp32, + } + + def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None): global tokenizer_manager @@ -163,6 +174,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg # TODO: replace this with huggingface transformers template load_chat_template_for_openai_api(server_args.chat_template) + _set_global_server_args(server_args) + # Allocate ports assert server_args.tp_size % server_args.nnodes == 0 tp_size_local = server_args.tp_size // server_args.nnodes