diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 355960367..7c26657aa 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -7,8 +7,11 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
-from sglang.srt.server import global_server_args_dict
+from sglang.srt.managers.controller.model_runner import (
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 
 
 class RadixAttention(nn.Module):
diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py
index 688e1ddd2..3d069775f 100644
--- a/python/sglang/srt/layers/token_attention.py
+++ b/python/sglang/srt/layers/token_attention.py
@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.server import global_server_args_dict
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 47dd043a5..006490a38 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -17,6 +17,13 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
+# Put some global args for easy access
+global_server_args_dict = {
+    "disable_flashinfer": False,
+    "disable_flashinfer_sampling": False,
+    "attention_reduce_in_fp32": False,
+}
+
 
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -687,7 +694,7 @@ class Batch:
 
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-        if True:
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index a23f26aa4..219f2f692 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -25,7 +25,12 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -60,7 +65,13 @@ class ModelRunner:
         self.nccl_port = nccl_port
         self.server_args = server_args
         self.is_multimodal_model = is_multimodal_model(self.model_config)
-        monkey_patch_vllm_dummy_weight_loader()
+        global_server_args_dict.update(
+            {
+                "disable_flashinfer": server_args.disable_flashinfer,
+                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+            }
+        )
 
         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
@@ -108,6 +119,7 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
+        monkey_patch_vllm_dummy_weight_loader()
         device_config = DeviceConfig()
         load_config = LoadConfig(load_format=self.server_args.load_format)
         vllm_model_config = VllmModelConfig(
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index d65379b94..dd88d9a99 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -65,9 +65,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
 tokenizer_manager = None
 
-# Put some args for easily access
-global_server_args_dict = {}
-
 
 @app.get("/health")
 async def health() -> Response:
@@ -150,14 +147,6 @@ def available_models():
     return ModelList(data=model_cards)
 
 
-def _set_global_server_args(server_args: ServerArgs):
-    global global_server_args_dict
-    global_server_args_dict = {
-        "disable_flashinfer": server_args.disable_flashinfer,
-        "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-    }
-
-
 def _set_torch_compile_config():
     # The following configurations are for torch compile optimizations
     import torch._dynamo.config
@@ -213,8 +202,6 @@ def launch_server(
     if server_args.enable_torch_compile:
         _set_torch_compile_config()
 
-    _set_global_server_args(server_args)
-
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index d487cd7b8..cb59ec986 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -52,13 +52,14 @@ class ServerArgs:
 
     # Optimization/debug options
     disable_flashinfer: bool = False
+    disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
     enable_torch_compile: bool = False
-    attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
+    attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False
 
     # Distributed args
@@ -303,7 +304,12 @@ class ServerArgs:
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
-            help="Disable flashinfer inference kernels.",
+            help="Disable flashinfer attention kernels.",
+        )
+        parser.add_argument(
+            "--disable-flashinfer-sampling",
+            action="store_true",
+            help="Disable flashinfer sampling kernels.",
         )
         parser.add_argument(
             "--disable-radix-cache",
@@ -330,17 +336,17 @@ class ServerArgs:
             action="store_true",
             help="Optimize the model with torch.compile, experimental feature.",
         )
+        parser.add_argument(
+            "--enable-p2p-check",
+            action="store_true",
+            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
+        )
         parser.add_argument(
             "--attention-reduce-in-fp32",
             action="store_true",
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels",
         )
-        parser.add_argument(
-            "--enable-p2p-check",
-            action="store_true",
-            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
-        )
         parser.add_argument(
             "--efficient-weight-load",
             action="store_true",
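
Note for reviewers (not part of the patch): the sketch below is a minimal, self-contained illustration of the configuration flow this diff sets up. global_server_args_dict is defined with defaults in infer_batch.py, overwritten from ServerArgs at the top of ModelRunner.__init__, and read by consumers such as Batch.sample and token_attention.py. The helpers init_model_runner and pick_sampling_backend are illustrative stand-ins, not functions in the codebase, and the ServerArgs shown keeps only the three flags this diff touches.

from dataclasses import dataclass

# Defaults, mirroring the module-level dict added to infer_batch.py.
global_server_args_dict = {
    "disable_flashinfer": False,
    "disable_flashinfer_sampling": False,
    "attention_reduce_in_fp32": False,
}


@dataclass
class ServerArgs:
    # Only the three flags this diff touches; the real class has many more fields.
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    attention_reduce_in_fp32: bool = False


def init_model_runner(server_args: ServerArgs) -> None:
    # Mirrors the update added at the top of ModelRunner.__init__: copy the CLI
    # flags into the shared dict so modules that should not import server.py
    # (e.g. token_attention.py) can still read them.
    global_server_args_dict.update(
        {
            "disable_flashinfer": server_args.disable_flashinfer,
            "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
        }
    )


def pick_sampling_backend() -> str:
    # Mirrors the branch added in Batch.sample.
    if not global_server_args_dict["disable_flashinfer_sampling"]:
        return "flashinfer top-k/top-p sampling kernel"
    return "plain torch sampling fallback"


if __name__ == "__main__":
    init_model_runner(ServerArgs(disable_flashinfer_sampling=True))
    print(pick_sampling_backend())  # -> plain torch sampling fallback

Since token_attention.py reads attention_reduce_in_fp32 at module import time to choose REDUCE_TRITON_TYPE, the update() presumably has to run before the model modules are first imported; that is consistent with placing it at the very start of ModelRunner.__init__ while monkey_patch_vllm_dummy_weight_loader() moves down to just before weight loading.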