diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py index b1f9d090d..42037be7c 100644 --- a/test/srt/test_mla.py +++ b/test/srt/test_mla.py @@ -20,7 +20,12 @@ class TestMLA(unittest.TestCase): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--trust-remote-code"], + other_args=[ + "--trust-remote-code", + "--enable-torch-compile", + "--cuda-graph-max-bs", + "2", + ], ) @classmethod