diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py
index df7f1f098..567e97fcb 100644
--- a/python/sglang/srt/layers/quantization/utils.py
+++ b/python/sglang/srt/layers/quantization/utils.py
@@ -74,6 +74,11 @@ def convert_to_channelwise(
         (sum(logical_widths), 1), dtype=torch.float32, device=weight_scale.device
     )
 
+    # Handle scalar tensor case: broadcast same scale to all channels
+    if weight_scale.dim() == 0:
+        weight_scale_channel.fill_(weight_scale.item())
+        return weight_scale_channel
+
     # Expand each scale to match the size of each logical matrix.
     start = 0
     for idx, logical_width in enumerate(logical_widths):
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index a88698b96..87426729a 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -33,6 +33,15 @@ DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST = "neuralmagic/Meta-Llama-3-8B-Instruct
 DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST = (
     "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic"
 )
+DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST = (
+    "nvidia/Llama-3.1-8B-Instruct-FP8"
+)
+# TODO(yundai424): right now specifying to an older revision since the latest one
+# carries kv cache quantization which doesn't work yet
+DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION = (
+    "13858565416dbdc0b4e7a4a677fadfbd5b9e5bb9"
+)
+
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
diff --git a/test/srt/test_eval_fp8_accuracy.py b/test/srt/test_eval_fp8_accuracy.py
index 25aa8a50d..d36216dd2 100644
--- a/test/srt/test_eval_fp8_accuracy.py
+++ b/test/srt/test_eval_fp8_accuracy.py
@@ -6,6 +6,8 @@ from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST,
     DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST,
+    DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST,
+    DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION,
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
@@ -105,5 +107,47 @@ class TestEvalFP8DynamicQuantAccuracy(unittest.TestCase):
         )
 
 
+class TestEvalFP8ModelOptQuantAccuracy(unittest.TestCase):
+
+    def _run_test(self, model, other_args, expected_score):
+        base_url = DEFAULT_URL_FOR_TEST
+        other_args = other_args or []
+
+        process = popen_launch_server(
+            model,
+            base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+        try:
+            args = SimpleNamespace(
+                base_url=base_url,
+                model=model,
+                eval_name="mmlu",
+                num_examples=64,
+                num_threads=32,
+                temperature=0.1,
+            )
+
+            metrics = run_eval(args)
+            self.assertGreaterEqual(metrics["score"], expected_score)
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_mmlu_offline_only(self):
+        """Test with offline quantization only."""
+        self._run_test(
+            model=DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST,
+            other_args=[
+                "--quantization",
+                "modelopt",
+                "--revision",
+                DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION,
+            ],
+            expected_score=0.64,
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
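
Note (reviewer sketch, not part of the patch): ModelOpt FP8 checkpoints use per-tensor
weight scales, which reach convert_to_channelwise as 0-dim tensors rather than as one
scale per fused partition; without the new branch, the existing loop would raise an
IndexError on weight_scale[idx] for a 0-dim tensor. A minimal self-contained sketch of
the patched behavior follows; the loop body below the shown context lines is an
assumption about the surrounding upstream code, not quoted from it:

    import torch

    def convert_to_channelwise_sketch(weight_scale, logical_widths):
        # One scale per output channel, across all fused logical partitions.
        weight_scale_channel = torch.empty(
            (sum(logical_widths), 1), dtype=torch.float32, device=weight_scale.device
        )

        # New branch from the patch: a 0-dim (per-tensor) scale is broadcast
        # to every channel.
        if weight_scale.dim() == 0:
            weight_scale_channel.fill_(weight_scale.item())
            return weight_scale_channel

        # Pre-existing branch (assumed body): one scale per logical partition.
        start = 0
        for idx, logical_width in enumerate(logical_widths):
            end = start + logical_width
            weight_scale_channel[start:end, :] = weight_scale[idx]
            start = end
        return weight_scale_channel

    # E.g. a fused QKV projection with a single per-tensor scale:
    out = convert_to_channelwise_sketch(torch.tensor(0.5), [4096, 1024, 1024])
    assert out.shape == (6144, 1) and torch.all(out == 0.5)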
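
Note (usage, not part of the patch): since the test module ends in unittest.main(), the
new class can be run on its own with unittest's CLI, e.g.
`python3 test/srt/test_eval_fp8_accuracy.py TestEvalFP8ModelOptQuantAccuracy`. The
`--quantization modelopt` / `--revision <sha>` pair in other_args is forwarded by
popen_launch_server to the launched server, so the MMLU eval runs against the pinned
pre-quantized checkpoint; the TODO above the revision constant explains the pin, which
can be dropped once ModelOpt KV-cache quantization is supported.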