[Fix] use torch.inference_mode() instead of torch.no_grad() (#4372)
This commit is contained in:
57
python/sglang/test/test_dynamic_grad_mode.py
Normal file
57
python/sglang/test/test_dynamic_grad_mode.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import DynamicGradMode
|
||||
|
||||
|
||||
class TestDynamicGradMode(unittest.TestCase):
|
||||
def test_inference(self):
|
||||
# Test inference_mode
|
||||
DynamicGradMode.set_inference_mode(True)
|
||||
|
||||
@DynamicGradMode()
|
||||
def create_tensor_x():
|
||||
return torch.empty(0)
|
||||
|
||||
X = create_tensor_x()
|
||||
self.assertTrue(not X.requires_grad and X.is_inference())
|
||||
|
||||
def test_no_grad(self):
|
||||
# Test no_grad
|
||||
DynamicGradMode.set_inference_mode(False)
|
||||
|
||||
@DynamicGradMode()
|
||||
def create_tensor_y():
|
||||
return torch.empty(0)
|
||||
|
||||
Y = create_tensor_y()
|
||||
self.assertTrue(not Y.requires_grad and not Y.is_inference())
|
||||
|
||||
def test_nested_inference(self):
|
||||
# Test no_grad nested inference_mode, inference_mode should has higher priority
|
||||
DynamicGradMode.set_inference_mode(False)
|
||||
|
||||
@DynamicGradMode()
|
||||
def create_tensor_z():
|
||||
with torch.inference_mode():
|
||||
return torch.empty(0)
|
||||
|
||||
Z = create_tensor_z()
|
||||
self.assertTrue(not Z.requires_grad and Z.is_inference())
|
||||
|
||||
def test_nested_no_grad(self):
|
||||
# Test inference_mode nested no_grad, inference_mode should has higher priority
|
||||
DynamicGradMode.set_inference_mode(True)
|
||||
|
||||
@DynamicGradMode()
|
||||
def create_tensor_w():
|
||||
with torch.no_grad():
|
||||
return torch.empty(0)
|
||||
|
||||
W = create_tensor_w()
|
||||
self.assertTrue(not W.requires_grad and W.is_inference())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
Reference in New Issue
Block a user