diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py index 26daf4fa5..e8cafa15d 100644 --- a/test/srt/test_torch_compile.py +++ b/test/srt/test_torch_compile.py @@ -1,6 +1,8 @@ import unittest from types import SimpleNamespace +import requests + from sglang.srt.utils import kill_child_process from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( @@ -20,7 +22,7 @@ class TestTorchCompile(unittest.TestCase): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--enable-torch-compile"], + other_args=["--enable-torch-compile", "--disable-radix-cache"], ) @classmethod @@ -39,6 +41,33 @@ class TestTorchCompile(unittest.TestCase): metrics = run_eval(args) assert metrics["score"] >= 0.6 + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + "ignore_eos": True, + }, + ) + return response.json() + + def test_throughput(self): + import time + + max_tokens = 256 + + tic = time.time() + res = self.run_decode(max_tokens) + tok = time.time() + print(res["text"]) + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + assert throughput >= 152 + if __name__ == "__main__": unittest.main()