[Fix] use torch.inference_mode() instead of torch.no_grad() (#4372)
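For context (not part of the diff): torch.inference_mode() is a stricter, cheaper sibling of torch.no_grad(). Both disable gradient tracking, but inference mode additionally skips autograd's view tracking and version-counter bumps, and marks the tensors it creates as inference tensors. A minimal sketch of the observable difference:

```python
import torch

# no_grad: gradients are off, but tensors are still ordinary autograd tensors
with torch.no_grad():
    a = torch.empty(0)
print(a.requires_grad, a.is_inference())  # False False

# inference_mode: gradients are off AND the tensor is an inference tensor,
# letting autograd skip view/version-counter bookkeeping entirely
with torch.inference_mode():
    b = torch.empty(0)
print(b.requires_grad, b.is_inference())  # False True
```

The trade-off is that inference tensors cannot later participate in autograd-recorded computation, which is presumably why this commit gates the switch behind an environment variable instead of swapping the decorators unconditionally.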
@@ -101,6 +101,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
@@ -487,7 +488,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             },
         )

-    @torch.no_grad()
+    @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
         while True:
@@ -507,7 +508,7 @@ class Scheduler(SchedulerOutputProcessorMixin):

         self.last_batch = batch

-    @torch.no_grad()
+    @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         self.result_queue = deque()
@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -115,7 +115,7 @@ class TpModelWorkerClient:
             logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
             self.parent_process.send_signal(signal.SIGQUIT)

-    @torch.no_grad()
+    @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
         batch_lists = [None] * 2
@@ -61,6 +61,7 @@ from torch import nn
 from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
+from torch.utils._contextlib import _DecoratorContextManager
 from torch.utils.cpp_extension import CUDA_HOME
 from triton.runtime.cache import (
     FileCacheManager,
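For context (not part of the diff): torch.utils._contextlib._DecoratorContextManager is the private torch base class that gives torch.no_grad and torch.inference_mode their dual context-manager/decorator behavior, and DynamicGradMode below subclasses it for the same reason. A rough standalone sketch of the contract it provides, assuming this private API behaves as in current torch releases:

```python
from torch.utils._contextlib import _DecoratorContextManager

class Verbose(_DecoratorContextManager):
    # Subclasses supply __enter__/__exit__; the base class adds __call__,
    # which wraps a function so each invocation runs inside a clone() of
    # this context manager (the default clone() is self.__class__()).
    def __enter__(self):
        print("enter")

    def __exit__(self, exc_type, exc_value, traceback):
        print("exit")

@Verbose()          # usable as a decorator...
def f():
    print("body")

f()                 # prints: enter / body / exit
with Verbose():     # ...and as a plain context manager
    pass
```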
@@ -127,6 +128,63 @@ def is_cuda_available():
     return is_cuda()


+_ENABLE_TORCH_INFERENCE_MODE = os.getenv(
+    "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
+).lower() in ("true", "1")
+
+
+class DynamicGradMode(_DecoratorContextManager):
+    """
+    A combination of torch.no_grad and torch.inference_mode: which of the
+    two is applied is controlled by the SGLANG_ENABLE_TORCH_INFERENCE_MODE
+    environment variable. See the documentation of those two modes for details.
+    """
+
+    @staticmethod
+    def set_inference_mode(mode: bool):
+        if isinstance(mode, bool):
+            global _ENABLE_TORCH_INFERENCE_MODE
+
+            _ENABLE_TORCH_INFERENCE_MODE = mode
+        else:
+            logger.warning("mode is not a boolean object")
+
+    def __init__(self, mode=True):
+        if not torch._jit_internal.is_scripting():
+            super().__init__()
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self.mode = mode
+        else:
+            self.prev = False
+
+    def __new__(cls, mode_or_orig_func=True if _ENABLE_TORCH_INFERENCE_MODE else None):
+        if mode_or_orig_func is None or isinstance(mode_or_orig_func, bool):
+            return super().__new__(cls)
+        return cls()(mode_or_orig_func)
+
+    def __enter__(self) -> None:
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self._inference_mode_context = torch._C._InferenceMode(self.mode)
+            self._inference_mode_context.__enter__()
+        else:
+            self.prev = torch.is_grad_enabled()
+            torch.set_grad_enabled(False)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
+        else:
+            torch.set_grad_enabled(self.prev)
+
+    def clone(self) -> "DynamicGradMode":
+        r"""
+        Create a copy of this context manager.
+        """
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            return self.__class__(self.mode)
+        else:
+            return self.__class__()
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
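A hedged usage sketch (not part of the diff): by default nothing changes, since SGLANG_ENABLE_TORCH_INFERENCE_MODE defaults to "false" and DynamicGradMode then falls back to plain no_grad semantics. Inference mode is opted into either by setting the environment variable before sglang is imported, or at runtime via set_inference_mode, as the new tests below do:

```python
import torch

from sglang.srt.utils import DynamicGradMode

# Flip the same switch the SGLANG_ENABLE_TORCH_INFERENCE_MODE env var controls.
DynamicGradMode.set_inference_mode(True)

@DynamicGradMode()  # now equivalent to @torch.inference_mode()
def forward_only():
    return torch.empty(0)

assert forward_only().is_inference()

DynamicGradMode.set_inference_mode(False)  # revert to plain no_grad semantics
```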
python/sglang/test/test_dynamic_grad_mode.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+import unittest
+
+import torch
+
+from sglang.srt.utils import DynamicGradMode
+
+
+class TestDynamicGradMode(unittest.TestCase):
+    def test_inference(self):
+        # Test inference_mode
+        DynamicGradMode.set_inference_mode(True)
+
+        @DynamicGradMode()
+        def create_tensor_x():
+            return torch.empty(0)
+
+        X = create_tensor_x()
+        self.assertTrue(not X.requires_grad and X.is_inference())
+
+    def test_no_grad(self):
+        # Test no_grad
+        DynamicGradMode.set_inference_mode(False)
+
+        @DynamicGradMode()
+        def create_tensor_y():
+            return torch.empty(0)
+
+        Y = create_tensor_y()
+        self.assertTrue(not Y.requires_grad and not Y.is_inference())
+
+    def test_nested_inference(self):
+        # Test inference_mode nested inside no_grad; inference_mode should have higher priority
+        DynamicGradMode.set_inference_mode(False)
+
+        @DynamicGradMode()
+        def create_tensor_z():
+            with torch.inference_mode():
+                return torch.empty(0)
+
+        Z = create_tensor_z()
+        self.assertTrue(not Z.requires_grad and Z.is_inference())
+
+    def test_nested_no_grad(self):
+        # Test no_grad nested inside inference_mode; inference_mode should have higher priority
+        DynamicGradMode.set_inference_mode(True)
+
+        @DynamicGradMode()
+        def create_tensor_w():
+            with torch.no_grad():
+                return torch.empty(0)
+
+        W = create_tensor_w()
+        self.assertTrue(not W.requires_grad and W.is_inference())
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
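Worth noting about the two nesting tests: torch's InferenceMode guard applies at the dispatcher level, so tensors created inside an inference-mode scope are inference tensors regardless of nesting order with no_grad; the inner torch.inference_mode() in test_nested_inference and the inner torch.no_grad() in test_nested_no_grad therefore both yield inference tensors. The file can be run directly with `python python/sglang/test/test_dynamic_grad_mode.py` in an environment where sglang and torch are installed.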