diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 654095c85..860568097 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -15,6 +15,7 @@ import json import logging import math +import os from enum import IntEnum, auto from typing import List, Optional, Set, Union @@ -234,6 +235,20 @@ class ModelConfig: if quant_cfg is None: # compressed-tensors uses a "compression_config" key quant_cfg = getattr(self.hf_config, "compression_config", None) + if quant_cfg is None: + # check if this is a modelopt model -- modelopt doesn't have a corresponding field + # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory + # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main + is_local = os.path.isdir(self.model_path) + modelopt_quant_config = {"quant_method": "modelopt"} + if not is_local: + from huggingface_hub import HfApi + + hf_api = HfApi() + if hf_api.file_exists(self.model_path, "hf_quant_config.json"): + quant_cfg = modelopt_quant_config + elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")): + quant_cfg = modelopt_quant_config return quant_cfg # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py diff --git a/test/srt/test_eval_fp8_accuracy.py b/test/srt/test_eval_fp8_accuracy.py index 12b2499d9..a9d126fdf 100644 --- a/test/srt/test_eval_fp8_accuracy.py +++ b/test/srt/test_eval_fp8_accuracy.py @@ -1,15 +1,11 @@ import unittest from types import SimpleNamespace -import torch - from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST, DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST, - DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST, - DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION, DEFAULT_MODEL_NAME_FOR_TEST, 
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -110,50 +106,5 @@ class TestEvalFP8DynamicQuantAccuracy(CustomTestCase): ) -class TestEvalFP8ModelOptQuantAccuracy(CustomTestCase): - - def _run_test(self, model, other_args, expected_score): - base_url = DEFAULT_URL_FOR_TEST - other_args = other_args or [] - - process = popen_launch_server( - model, - base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=other_args, - ) - - try: - args = SimpleNamespace( - base_url=base_url, - model=model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - temperature=0.1, - ) - - metrics = run_eval(args) - self.assertGreaterEqual(metrics["score"], expected_score) - finally: - kill_process_tree(process.pid) - - @unittest.skipIf( - torch.version.hip is not None, "modelopt quantization unsupported on ROCm" - ) - def test_mmlu_offline_only(self): - """Test with offline quantization only.""" - self._run_test( - model=DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST, - other_args=[ - "--quantization", - "modelopt", - "--revision", - DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION, - ], - expected_score=0.64, - ) - - if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_modelopt.py b/test/srt/test_modelopt.py new file mode 100644 index 000000000..166af22a5 --- /dev/null +++ b/test/srt/test_modelopt.py @@ -0,0 +1,58 @@ +import unittest +from types import SimpleNamespace + +import torch + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST, + DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestEvalFP8ModelOptQuantAccuracy(CustomTestCase): + + def _run_test(self, model, other_args, expected_score): + base_url = DEFAULT_URL_FOR_TEST + other_args = 
other_args or [] + + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + try: + args = SimpleNamespace( + base_url=base_url, + model=model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + temperature=0.1, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], expected_score) + finally: + kill_process_tree(process.pid) + + @unittest.skipIf( + torch.version.hip is not None, "modelopt quantization unsupported on ROCm" + ) + def test_mmlu_offline_only(self): + """Test with offline quantization only.""" + self._run_test( + model=DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST, + other_args=[ + "--revision", + DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION, + ], + expected_score=0.64, + )