diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 658b3d2f8..ad5aa6aa5 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -47,6 +47,7 @@ suites = {
         "test_update_weights_from_tensor.py",
         "test_vision_chunked_prefill.py",
         "test_vision_openai_server.py",
+        "test_w8a8_quantization.py",
         "test_session_control.py",
     ],
     "nightly": [
diff --git a/test/srt/test_w8a8_quantization.py b/test/srt/test_w8a8_quantization.py
new file mode 100644
index 000000000..78579d5e2
--- /dev/null
+++ b/test/srt/test_w8a8_quantization.py
@@ -0,0 +1,81 @@
+import time
+import unittest
+from types import SimpleNamespace
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestW8A8(unittest.TestCase):
+    """End-to-end tests for a W8A8 (int8 weight/activation) quantized model."""
+
+    @classmethod
+    def setUpClass(cls):
+        # Launch one shared server with int8 quantization for all tests.
+        cls.model = "neuralmagic/Meta-Llama-3-8B-Instruct-quantized.w8a8"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--quantization", "w8a8_int8"],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        # Accuracy gate: the quantized model must stay above 70% on GSM8K.
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.7)
+
+    def run_decode(self, max_new_tokens):
+        # Greedy decode; ignore_eos (a SamplingParams field) forces the full
+        # max_new_tokens, which test_throughput relies on for tokens/s math.
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": max_new_tokens,
+                    "ignore_eos": True,
+                },
+            },
+        )
+        return response.json()
+
+    def test_throughput(self):
+        max_tokens = 256
+
+        # perf_counter is monotonic; time.time() can jump with clock changes.
+        tic = time.perf_counter()
+        res = self.run_decode(max_tokens)
+        tok = time.perf_counter()
+        print(res["text"])
+        throughput = max_tokens / (tok - tic)
+        print(f"Throughput: {throughput} tokens/s")
+        self.assertGreaterEqual(throughput, 140)
+
+
+if __name__ == "__main__":
+    unittest.main()