diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index c68591fc6..808e3e822 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter +from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.quantization.base_config import ( @@ -27,7 +28,6 @@ from sglang.srt.layers.quantization.fp8_utils import ( from sglang.srt.utils import set_weight_attrs if TYPE_CHECKING: - from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe.token_dispatcher import ( CombineInput, StandardDispatchOutput, diff --git a/test/srt/quant/test_w8a8_quantization.py b/test/srt/quant/test_w8a8_quantization.py index acb7f5c7d..cef51f0f0 100644 --- a/test/srt/quant/test_w8a8_quantization.py +++ b/test/srt/quant/test_w8a8_quantization.py @@ -14,23 +14,39 @@ from sglang.test.test_utils import ( ) -class TestW8A8(CustomTestCase): +class BaseW8A8Test(CustomTestCase): + model: str = None + quantization: str = None + gsm8k_accuracy_threshold: float = None + throughput_threshold: float = None + @classmethod def setUpClass(cls): - cls.model = "neuralmagic/Meta-Llama-3-8B-Instruct-quantized.w8a8" + if cls is BaseW8A8Test: + raise unittest.SkipTest("Skip base test class") + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = [] + if cls.quantization: + other_args.extend(["--quantization", cls.quantization]) + cls.process = popen_launch_server( cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--quantization", "w8a8_int8"], + other_args=other_args, ) @classmethod def tearDownClass(cls): + if cls is BaseW8A8Test: + return kill_process_tree(cls.process.pid) def test_gsm8k(self): + if self.gsm8k_accuracy_threshold is None: + self.skipTest("gsm8k_accuracy_threshold not set for this test") + args = SimpleNamespace( num_shots=5, data_path=None, @@ -42,8 +58,7 @@ class TestW8A8(CustomTestCase): ) metrics = run_eval(args) print(metrics) - - self.assertGreater(metrics["accuracy"], 0.69) + self.assertGreater(metrics["accuracy"], self.gsm8k_accuracy_threshold) def run_decode(self, max_new_tokens): response = requests.post( @@ -60,15 +75,36 @@ class TestW8A8(CustomTestCase): return response.json() def test_throughput(self): - max_tokens = 256 + max_tokens = 256 tic = time.perf_counter() res = self.run_decode(max_tokens) tok = time.perf_counter() print(res["text"]) throughput = max_tokens / (tok - tic) print(f"Throughput: {throughput} tokens/s") - assert throughput >= 140 + self.assertGreaterEqual(throughput, self.throughput_threshold) + + +class TestW8A8Int8(BaseW8A8Test): + model = "neuralmagic/Meta-Llama-3-8B-Instruct-quantized.w8a8" + quantization = "w8a8_int8" + gsm8k_accuracy_threshold = 0.69 + throughput_threshold = 200 + + +class TestW8A8Fp8(BaseW8A8Test): + model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic" + quantization = "w8a8_fp8" + gsm8k_accuracy_threshold = 0.69 + throughput_threshold = 200 + + +class TestW8A8Fp8MoE(BaseW8A8Test): + model = "RedHatAI/Qwen3-30B-A3B-FP8-dynamic" + quantization = "w8a8_fp8" + gsm8k_accuracy_threshold = 0.88 + throughput_threshold = 180 if __name__ == "__main__":