diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index b74dcc39d..fa4a49ce8 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -101,6 +101,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
@@ -487,7 +488,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             },
         )
 
-    @torch.no_grad()
+    @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
         while True:
@@ -507,7 +508,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
         self.last_batch = batch
 
-    @torch.no_grad()
+    @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         self.result_queue = deque()
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 4a1f2d5c1..deb2fbe59 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -115,7 +115,7 @@ class TpModelWorkerClient:
             logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
             self.parent_process.send_signal(signal.SIGQUIT)
 
-    @torch.no_grad()
+    @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
         batch_lists = [None] * 2
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 85bf35967..2522ee324 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -61,6 +61,7 @@ from torch import nn
 from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
+from torch.utils._contextlib import _DecoratorContextManager
 from torch.utils.cpp_extension import CUDA_HOME
 from triton.runtime.cache import (
     FileCacheManager,
@@ -127,6 +128,63 @@ def is_cuda_available():
     return is_cuda()
 
 
+_ENABLE_TORCH_INFERENCE_MODE = os.getenv(
+    "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
+).lower() in ("true", "1")
+
+
+class DynamicGradMode(_DecoratorContextManager):
+    """
+    A drop-in replacement for torch.no_grad / torch.inference_mode, choosing
+    between them at runtime via SGLANG_ENABLE_TORCH_INFERENCE_MODE.
+    """
+
+    @staticmethod
+    def set_inference_mode(mode: bool):
+        if isinstance(mode, bool):
+            global _ENABLE_TORCH_INFERENCE_MODE
+
+            _ENABLE_TORCH_INFERENCE_MODE = mode
+        else:
+            logger.warning("mode is not a boolean object")
+
+    def __init__(self, mode=True):
+        if not torch._jit_internal.is_scripting():
+            super().__init__()
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self.mode = mode
+        else:
+            self.prev = False
+
+    def __new__(cls, mode_or_orig_func=True if _ENABLE_TORCH_INFERENCE_MODE else None):
+        if mode_or_orig_func is None or isinstance(mode_or_orig_func, bool):
+            return super().__new__(cls)
+        return cls()(mode_or_orig_func)
+
+    def __enter__(self) -> None:
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self._inference_mode_context = torch._C._InferenceMode(self.mode)
+            self._inference_mode_context.__enter__()
+        else:
+            self.prev = torch.is_grad_enabled()
+            torch.set_grad_enabled(False)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
+        else:
+            torch.set_grad_enabled(self.prev)
+
+    def clone(self) -> "DynamicGradMode":
+        r"""
+        Create a copy of this context manager
+        """
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            return self.__class__(self.mode)
+        else:
+            return self.__class__()
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
diff --git a/python/sglang/test/test_dynamic_grad_mode.py b/python/sglang/test/test_dynamic_grad_mode.py
new file mode 100644
index 000000000..c0287ec3d
--- /dev/null
+++ b/python/sglang/test/test_dynamic_grad_mode.py
@@ -0,0 +1,57 @@
+import unittest
+
+import torch
+
+from sglang.srt.utils import DynamicGradMode
+
+
+class TestDynamicGradMode(unittest.TestCase):
+    def test_inference(self):
+        # Test inference_mode
+        DynamicGradMode.set_inference_mode(True)
+
+        @DynamicGradMode()
+        def create_tensor_x():
+            return torch.empty(0)
+
+        X = create_tensor_x()
+        self.assertTrue(not X.requires_grad and X.is_inference())
+
+    def test_no_grad(self):
+        # Test no_grad
+        DynamicGradMode.set_inference_mode(False)
+
+        @DynamicGradMode()
+        def create_tensor_y():
+            return torch.empty(0)
+
+        Y = create_tensor_y()
+        self.assertTrue(not Y.requires_grad and not Y.is_inference())
+
+    def test_nested_inference(self):
+        # inference_mode nested inside no_grad; inference_mode should take priority
+        DynamicGradMode.set_inference_mode(False)
+
+        @DynamicGradMode()
+        def create_tensor_z():
+            with torch.inference_mode():
+                return torch.empty(0)
+
+        Z = create_tensor_z()
+        self.assertTrue(not Z.requires_grad and Z.is_inference())
+
+    def test_nested_no_grad(self):
+        # no_grad nested inside inference_mode; inference_mode should take priority
+        DynamicGradMode.set_inference_mode(True)
+
+        @DynamicGradMode()
+        def create_tensor_w():
+            with torch.no_grad():
+                return torch.empty(0)
+
+        W = create_tensor_w()
+        self.assertTrue(not W.requires_grad and W.is_inference())
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)