From 5ec610e832769b0da0b0f99db07e71047004cdf8 Mon Sep 17 00:00:00 2001
From: Cao Yi
Date: Fri, 13 Mar 2026 22:53:25 +0800
Subject: [PATCH] [Feature][Quant] Reapply auto-detect quantization format and
 support remote model ID (#7111)

### What this PR does / why we need it?

Reapply the auto-detect quantization format feature (originally added in #6645 and reverted in #6873) and extend it to support remote model identifiers (e.g., `org/model-name`).

Changes:
- Reapply auto-detection of the quantization method from model files (`quant_model_description.json` for ModelSlim, `config.json` for compressed-tensors)
- Add a `get_model_file()` utility to handle file retrieval from both local paths and remote repos (HuggingFace Hub / ModelScope)
- Update `detect_quantization_method()` to accept remote repo IDs with an optional `revision` parameter
- Update `maybe_update_config()` to work with remote model identifiers
- Add platform-level `auto_detect_quantization` support
- Add unit tests and e2e tests for both local and remote model ID scenarios

Closes #6836

### Does this PR introduce _any_ user-facing change?

Yes. When `--quantization` is not explicitly specified, vllm-ascend will now automatically detect the quantization format from the model files, for both local directories and remote model IDs (see the usage sketch after the e2e test diff below).

- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

---------

Signed-off-by: SlightwindSec
---
 tests/e2e/singlecard/test_quantization.py     |  37 ++++
 .../ut/quantization/test_modelslim_config.py  |  91 +++++++-
 tests/ut/quantization/test_quant_utils.py     | 192 +++++++++++++++++
 tests/ut/test_platform.py                     |  24 ++-
 vllm_ascend/platform.py                       |   5 +
 vllm_ascend/quantization/modelslim_config.py  | 120 ++++++++++-
 vllm_ascend/quantization/utils.py             | 201 ++++++++++++++++++
 7 files changed, 658 insertions(+), 12 deletions(-)
 create mode 100644 tests/ut/quantization/test_quant_utils.py
 create mode 100644 vllm_ascend/quantization/utils.py

diff --git a/tests/e2e/singlecard/test_quantization.py b/tests/e2e/singlecard/test_quantization.py
index b50ac3cf..19ee42e0 100644
--- a/tests/e2e/singlecard/test_quantization.py
+++ b/tests/e2e/singlecard/test_quantization.py
@@ -49,6 +49,43 @@ def test_qwen3_w8a8_quant():
         name_1="vllm_quant_w8a8_outputs",
     )
 
+# fmt: off
+def test_qwen3_w8a8_quant_auto_detect():
+    """Test that ModelSlim quantization is auto-detected without --quantization.
+
+    Uses the same W8A8 model as test_qwen3_w8a8_quant but omits the
+    quantization parameter, verifying that the auto-detection in
+    maybe_auto_detect_quantization() picks up quant_model_description.json
+    and produces identical results.
+    """
+    max_tokens = 5
+    example_prompts = [
+        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
+    ]
+    vllm_target_outputs = [([
+        85, 4086, 44, 374, 264, 1550, 42747, 628, 323, 4938, 72816, 44378, 323,
+        13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 311, 387
+    ], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be')]
+# fmt: on
+
+    with VllmRunner(
+            "vllm-ascend/Qwen3-0.6B-W8A8",
+            max_model_len=8192,
+            gpu_memory_utilization=0.7,
+            cudagraph_capture_sizes=[1, 2, 4, 8],
+    ) as vllm_model:
+        vllm_quant_auto_detect_outputs = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_target_outputs,
+        outputs_1_lst=vllm_quant_auto_detect_outputs,
+        name_0="vllm_target_outputs",
+        name_1="vllm_quant_auto_detect_outputs",
+    )
+
 
 # fmt: off
 def test_qwen3_dense_w8a16():
     max_tokens = 5
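Usage sketch for the behavior exercised above, written against vLLM's offline `LLM` API. The prompt and model id mirror the e2e test; the key point is that no `quantization` argument is passed. This is a minimal sketch, not part of the test suite:

```python
# Minimal sketch: auto-detection fills in quantization when it is omitted.
from vllm import LLM, SamplingParams

# No quantization= argument; vllm-ascend is expected to detect the ModelSlim
# format from quant_model_description.json inside the repo.
llm = LLM(model="vllm-ascend/Qwen3-0.6B-W8A8", max_model_len=8192)
outputs = llm.generate(
    ["vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."],
    SamplingParams(temperature=0.0, max_tokens=5),
)
print(outputs[0].outputs[0].text)  # expected: " It is designed to be"
```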
diff --git a/tests/ut/quantization/test_modelslim_config.py b/tests/ut/quantization/test_modelslim_config.py
index f71238aa..556c8a4a 100644
--- a/tests/ut/quantization/test_modelslim_config.py
+++ b/tests/ut/quantization/test_modelslim_config.py
@@ -1,3 +1,6 @@
+import json
+import os
+import tempfile
 from unittest.mock import MagicMock, patch
 
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -7,6 +10,7 @@ from vllm.model_executor.layers.linear import LinearBase
 from tests.ut.base import TestBase
 from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
 from vllm_ascend.quantization.modelslim_config import (
+    MODELSLIM_CONFIG_FILENAME,
     AscendModelSlimConfig,
 )
 from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
@@ -53,7 +57,7 @@ class TestAscendModelSlimConfig(TestBase):
 
     def test_get_config_filenames(self):
         filenames = AscendModelSlimConfig.get_config_filenames()
-        self.assertEqual(filenames, ["quant_model_description.json"])
+        self.assertEqual(filenames, [])
 
     def test_from_config(self):
         config = AscendModelSlimConfig.from_config(self.sample_config)
@@ -161,5 +165,90 @@ class TestAscendModelSlimConfig(TestBase):
         with self.assertRaises(ValueError):
             config.is_layer_skipped_ascend("fused_layer", fused_mapping)
 
+    def test_init_with_none_config(self):
+        config = AscendModelSlimConfig(None)
+        self.assertEqual(config.quant_description, {})
+
+    def test_init_with_default_config(self):
+        config = AscendModelSlimConfig()
+        self.assertEqual(config.quant_description, {})
+
+    def test_maybe_update_config_already_populated(self):
+        # When quant_description is already populated, should be a no-op
+        self.assertTrue(len(self.ascend_config.quant_description) > 0)
+        self.ascend_config.maybe_update_config("/some/model/path")
+        # quant_description should remain unchanged
+        self.assertEqual(self.ascend_config.quant_description,
+                         self.sample_config)
+
+    def test_maybe_update_config_loads_from_file(self):
+        config = AscendModelSlimConfig()
+        self.assertEqual(config.quant_description, {})
+
+        quant_data = {"layer1.weight": "INT8", "layer2.weight": "FLOAT"}
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config_path = os.path.join(tmpdir, MODELSLIM_CONFIG_FILENAME)
+            with open(config_path, "w") as f:
+                json.dump(quant_data, f)
+
+            config.maybe_update_config(tmpdir)
+
+        self.assertEqual(config.quant_description, quant_data)
+
+    def test_maybe_update_config_raises_when_file_missing(self):
+        config = AscendModelSlimConfig()
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            with self.assertRaises(ValueError) as ctx:
+                config.maybe_update_config(tmpdir)
+
+        error_msg = str(ctx.exception)
+        self.assertIn("ModelSlim Quantization Config Not Found", error_msg)
+        self.assertIn(MODELSLIM_CONFIG_FILENAME, error_msg)
+
+    def test_maybe_update_config_raises_with_json_files_listed(self):
+        config = AscendModelSlimConfig()
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Create a dummy json file that is NOT the config file
+            dummy_path = os.path.join(tmpdir, "config.json")
+            with open(dummy_path, "w") as f:
+                json.dump({"dummy": True}, f)
+
+            with self.assertRaises(ValueError) as ctx:
+                config.maybe_update_config(tmpdir)
+
+        error_msg = str(ctx.exception)
+        self.assertIn("config.json", error_msg)
+
+    def test_maybe_update_config_non_directory_raises(self):
+        config = AscendModelSlimConfig()
+
+        with self.assertRaises(ValueError) as ctx:
+            config.maybe_update_config("not_a_real_directory_path")
+
+        error_msg = str(ctx.exception)
+        self.assertIn("ModelSlim Quantization Config Not Found", error_msg)
+
+    def test_apply_extra_quant_adaptations_shared_head(self):
+        config = AscendModelSlimConfig()
+        config.quant_description = {
+            "model.layers.0.shared_head.weight": "INT8",
+        }
+        config._apply_extra_quant_adaptations()
+        self.assertIn("model.layers.0.weight", config.quant_description)
+        self.assertEqual(config.quant_description["model.layers.0.weight"],
+                         "INT8")
+
+    def test_apply_extra_quant_adaptations_weight_packed(self):
+        config = AscendModelSlimConfig()
+        config.quant_description = {
+            "model.layers.0.weight_packed": "INT8",
+        }
+        config._apply_extra_quant_adaptations()
+        self.assertIn("model.layers.0.weight", config.quant_description)
+        self.assertEqual(config.quant_description["model.layers.0.weight"],
+                         "INT8")
+
     def test_get_scaled_act_names(self):
         self.assertEqual(self.ascend_config.get_scaled_act_names(), [])
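For reference, the `quant_model_description.json` fixture these tests build is a flat JSON mapping from parameter names to quantization type strings. A minimal sketch of creating one by hand, with key names that are illustrative and `INT8`/`FLOAT` values mirroring the tests:

```python
import json
import os
import tempfile

# Illustrative ModelSlim description: parameter name -> quant type string.
quant_description = {
    "layer1.weight": "INT8",   # quantized weight
    "layer2.weight": "FLOAT",  # kept in floating point
}

tmpdir = tempfile.mkdtemp()
with open(os.path.join(tmpdir, "quant_model_description.json"), "w") as f:
    json.dump(quant_description, f)
# A directory like tmpdir is what maybe_update_config(tmpdir) expects to find.
```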
"config.json") + with open(dummy_path, "w") as f: + json.dump({"dummy": True}, f) + + with self.assertRaises(ValueError) as ctx: + config.maybe_update_config(tmpdir) + + error_msg = str(ctx.exception) + self.assertIn("config.json", error_msg) + + def test_maybe_update_config_non_directory_raises(self): + config = AscendModelSlimConfig() + + with self.assertRaises(ValueError) as ctx: + config.maybe_update_config("not_a_real_directory_path") + + error_msg = str(ctx.exception) + self.assertIn("ModelSlim Quantization Config Not Found", error_msg) + + def test_apply_extra_quant_adaptations_shared_head(self): + config = AscendModelSlimConfig() + config.quant_description = { + "model.layers.0.shared_head.weight": "INT8", + } + config._apply_extra_quant_adaptations() + self.assertIn("model.layers.0.weight", config.quant_description) + self.assertEqual(config.quant_description["model.layers.0.weight"], + "INT8") + + def test_apply_extra_quant_adaptations_weight_packed(self): + config = AscendModelSlimConfig() + config.quant_description = { + "model.layers.0.weight_packed": "INT8", + } + config._apply_extra_quant_adaptations() + self.assertIn("model.layers.0.weight", config.quant_description) + self.assertEqual(config.quant_description["model.layers.0.weight"], + "INT8") + def test_get_scaled_act_names(self): self.assertEqual(self.ascend_config.get_scaled_act_names(), []) diff --git a/tests/ut/quantization/test_quant_utils.py b/tests/ut/quantization/test_quant_utils.py new file mode 100644 index 00000000..70c82ec8 --- /dev/null +++ b/tests/ut/quantization/test_quant_utils.py @@ -0,0 +1,192 @@ +import json +import logging +import os +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +from tests.ut.base import TestBase +from vllm_ascend.quantization.modelslim_config import MODELSLIM_CONFIG_FILENAME +from vllm_ascend.quantization.utils import ( + detect_quantization_method, + maybe_auto_detect_quantization, +) +from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD + + +class TestDetectQuantizationMethod(TestBase): + + def test_returns_none_for_non_existent_path(self): + result = detect_quantization_method("/non/existent/path") + self.assertIsNone(result) + + def test_detects_modelslim(self): + with tempfile.TemporaryDirectory() as tmpdir: + config_path = os.path.join(tmpdir, MODELSLIM_CONFIG_FILENAME) + with open(config_path, "w") as f: + json.dump({"layer.weight": "INT8"}, f) + + result = detect_quantization_method(tmpdir) + self.assertEqual(result, ASCEND_QUANTIZATION_METHOD) + + def test_detects_compressed_tensors(self): + with tempfile.TemporaryDirectory() as tmpdir: + config_path = os.path.join(tmpdir, "config.json") + with open(config_path, "w") as f: + json.dump({ + "quantization_config": { + "quant_method": "compressed-tensors" + } + }, f) + + result = detect_quantization_method(tmpdir) + self.assertEqual(result, COMPRESSED_TENSORS_METHOD) + + def test_returns_none_for_no_quant(self): + with tempfile.TemporaryDirectory() as tmpdir: + result = detect_quantization_method(tmpdir) + self.assertIsNone(result) + + def test_returns_none_for_non_compressed_tensors_quant_method(self): + with tempfile.TemporaryDirectory() as tmpdir: + config_path = os.path.join(tmpdir, "config.json") + with open(config_path, "w") as f: + json.dump({ + "quantization_config": { + "quant_method": "gptq" + } + }, f) + + result = detect_quantization_method(tmpdir) + self.assertIsNone(result) + + def test_returns_none_for_config_without_quant_config(self): + 
+    def test_returns_none_for_config_without_quant_config(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config_path = os.path.join(tmpdir, "config.json")
+            with open(config_path, "w") as f:
+                json.dump({"model_type": "llama"}, f)
+
+            result = detect_quantization_method(tmpdir)
+            self.assertIsNone(result)
+
+    def test_returns_none_for_malformed_config_json(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config_path = os.path.join(tmpdir, "config.json")
+            with open(config_path, "w") as f:
+                f.write("not valid json{{{")
+
+            result = detect_quantization_method(tmpdir)
+            self.assertIsNone(result)
+
+    def test_modelslim_takes_priority_over_compressed_tensors(self):
+        """When both ModelSlim config and compressed-tensors config exist,
+        ModelSlim should take priority."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            modelslim_path = os.path.join(tmpdir, MODELSLIM_CONFIG_FILENAME)
+            with open(modelslim_path, "w") as f:
+                json.dump({"layer.weight": "INT8"}, f)
+
+            config_path = os.path.join(tmpdir, "config.json")
+            with open(config_path, "w") as f:
+                json.dump({
+                    "quantization_config": {
+                        "quant_method": "compressed-tensors"
+                    }
+                }, f)
+
+            result = detect_quantization_method(tmpdir)
+            self.assertEqual(result, ASCEND_QUANTIZATION_METHOD)
+
+
+class TestMaybeAutoDetectQuantization(TestBase):
+
+    def _make_vllm_config(self, model_path="/fake/model",
+                          quantization=None, revision=None):
+        vllm_config = MagicMock()
+        vllm_config.model_config.model = model_path
+        vllm_config.model_config.quantization = quantization
+        vllm_config.model_config.revision = revision
+        return vllm_config
+
+    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
+           return_value=None)
+    def test_no_detection_does_nothing(self, mock_detect):
+        vllm_config = self._make_vllm_config()
+        maybe_auto_detect_quantization(vllm_config)
+        self.assertIsNone(vllm_config.model_config.quantization)
+
+    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
+           return_value=ASCEND_QUANTIZATION_METHOD)
+    def test_user_specified_same_method_no_change(self, mock_detect):
+        vllm_config = self._make_vllm_config(
+            quantization=ASCEND_QUANTIZATION_METHOD)
+        maybe_auto_detect_quantization(vllm_config)
+        self.assertEqual(vllm_config.model_config.quantization,
+                         ASCEND_QUANTIZATION_METHOD)
+
+    @patch("vllm.config.VllmConfig._get_quantization_config",
+           return_value=MagicMock())
+    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
+           return_value=ASCEND_QUANTIZATION_METHOD)
+    def test_auto_detect_sets_quantization_and_logs_info(
+            self, mock_detect, mock_get_quant_config):
+        """When no --quantization is specified but ModelSlim config is found,
+        the method should auto-set quantization and emit an INFO log."""
+        vllm_config = self._make_vllm_config(
+            model_path="/fake/quant_model", quantization=None)
+
+        with self.assertLogs("vllm_ascend.quantization.utils",
+                             level=logging.INFO) as cm:
+            maybe_auto_detect_quantization(vllm_config)
+
+        self.assertEqual(vllm_config.model_config.quantization,
+                         ASCEND_QUANTIZATION_METHOD)
+        log_output = "\n".join(cm.output)
+        self.assertIn("Auto-detected quantization method", log_output)
+        self.assertIn(ASCEND_QUANTIZATION_METHOD, log_output)
+        self.assertIn("/fake/quant_model", log_output)
+
+    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
+           return_value=ASCEND_QUANTIZATION_METHOD)
+    def test_user_mismatch_logs_warning(self, mock_detect):
+        """When user specifies a different method than auto-detected,
+        a WARNING should be emitted and user's choice should be respected."""
+        vllm_config = self._make_vllm_config(
+            model_path="/fake/quant_model",
+            quantization=COMPRESSED_TENSORS_METHOD)
+
+        with self.assertLogs("vllm_ascend.quantization.utils",
+                             level=logging.WARNING) as cm:
+            maybe_auto_detect_quantization(vllm_config)
+
+        self.assertEqual(vllm_config.model_config.quantization,
+                         COMPRESSED_TENSORS_METHOD)
+        log_output = "\n".join(cm.output)
+        self.assertIn("Auto-detected quantization method", log_output)
+        self.assertIn(ASCEND_QUANTIZATION_METHOD, log_output)
+        self.assertIn(COMPRESSED_TENSORS_METHOD, log_output)
+
+    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
+           return_value=None)
+    def test_no_detection_emits_no_log(self, mock_detect):
+        """When no quantization is detected, no log should be emitted."""
+        vllm_config = self._make_vllm_config(quantization=None)
+        logger_name = "vllm_ascend.quantization.utils"
+
+        with self.assertRaises(AssertionError):
+            # assertLogs raises AssertionError when no logs are emitted
+            with self.assertLogs(logger_name, level=logging.DEBUG):
+                maybe_auto_detect_quantization(vllm_config)
+
+        self.assertIsNone(vllm_config.model_config.quantization)
+
+    @patch("vllm.config.VllmConfig._get_quantization_config",
+           return_value=MagicMock())
+    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
+           return_value=ASCEND_QUANTIZATION_METHOD)
+    def test_passes_revision_to_detect(self, mock_detect, mock_get_quant):
+        """Verify that model revision is forwarded to detect_quantization_method."""
+        vllm_config = self._make_vllm_config(
+            model_path="org/model-name", revision="v1.0", quantization=None)
+        maybe_auto_detect_quantization(vllm_config)
+        mock_detect.assert_called_once_with("org/model-name", revision="v1.0")
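Outside the mocked tests, the detector can be driven the same way on a real directory. A minimal sketch, assuming only the public functions added by this PR:

```python
import json
import os
import tempfile

from vllm_ascend.quantization.utils import detect_quantization_method

# compressed-tensors models advertise themselves in config.json under
# quantization_config.quant_method, which is what the detector checks.
with tempfile.TemporaryDirectory() as tmpdir:
    with open(os.path.join(tmpdir, "config.json"), "w") as f:
        json.dump(
            {"quantization_config": {"quant_method": "compressed-tensors"}}, f)
    assert detect_quantization_method(tmpdir) == "compressed-tensors"

# An empty directory carries no quantization signature.
with tempfile.TemporaryDirectory() as tmpdir:
    assert detect_quantization_method(tmpdir) is None
```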
diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py
index 3b75d32e..0f91ca4d 100644
--- a/tests/ut/test_platform.py
+++ b/tests/ut/test_platform.py
@@ -125,13 +125,14 @@ class TestNPUPlatform(TestBase):
         self.assertIsNone(self.platform.inference_mode())
         mock_inference_mode.assert_called_once()
 
+    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
     @patch("vllm_ascend.ascend_config.init_ascend_config")
     @patch("vllm_ascend.utils.update_aclgraph_sizes")
     @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
     @patch("os.environ", {})
     @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
     def test_check_and_update_config_basic_config_update(
-        self, mock_init_recompute, mock_soc_version, mock_update_acl, mock_init_ascend
+        self, mock_init_recompute, mock_soc_version, mock_update_acl, mock_init_ascend, mock_auto_detect
     ):
         mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
         vllm_config = TestNPUPlatform.mock_vllm_config()
@@ -155,11 +156,12 @@ class TestNPUPlatform(TestBase):
 
         mock_init_ascend.assert_called_once_with(vllm_config)
 
+    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
     @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
     @patch("vllm_ascend.ascend_config.init_ascend_config")
     @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
     def test_check_and_update_config_no_model_config_warning(
-        self, mock_init_recompute, mock_init_ascend, mock_soc_version
+        self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
     ):
         mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
         vllm_config = TestNPUPlatform.mock_vllm_config()
@@ -181,10 +183,11 @@ class TestNPUPlatform(TestBase):
 
         self.assertTrue("Model config is missing" in cm.output[0])
 
+    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
     @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
     @patch("vllm_ascend.ascend_config.init_ascend_config")
     @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
-    def test_check_and_update_config_enforce_eager_mode(self, mock_init_recompute, mock_init_ascend, mock_soc_version):
+    def test_check_and_update_config_enforce_eager_mode(self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect):
         mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
         vllm_config = TestNPUPlatform.mock_vllm_config()
         vllm_config.model_config.enforce_eager = True
@@ -215,11 +218,12 @@ class TestNPUPlatform(TestBase):
             CUDAGraphMode.NONE,
         )
 
+    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
     @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
     @patch("vllm_ascend.ascend_config.init_ascend_config")
     @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
     def test_check_and_update_config_unsupported_compilation_level(
-        self, mock_init_recompute, mock_init_ascend, mock_soc_version
+        self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
     ):
         mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
         vllm_config = TestNPUPlatform.mock_vllm_config()
@@ -253,9 +257,10 @@ class TestNPUPlatform(TestBase):
         )
 
     @pytest.mark.skip("Revert me when vllm support setting cudagraph_mode on oot platform")
+    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
     @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
     @patch("vllm_ascend.ascend_config.init_ascend_config")
-    def test_check_and_update_config_unsupported_cudagraph_mode(self, mock_init_ascend, mock_soc_version):
+    def test_check_and_update_config_unsupported_cudagraph_mode(self, mock_init_ascend, mock_soc_version, mock_auto_detect):
         mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
         vllm_config = TestNPUPlatform.mock_vllm_config()
         vllm_config.model_config.enforce_eager = False
@@ -277,11 +282,12 @@ class TestNPUPlatform(TestBase):
             CUDAGraphMode.NONE,
         )
 
+    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
     @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
     @patch("vllm_ascend.ascend_config.init_ascend_config")
     @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
     def test_check_and_update_config_cache_config_block_size(
-        self, mock_init_recompute, mock_init_ascend, mock_soc_version
+        self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
     ):
         mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
         vllm_config = TestNPUPlatform.mock_vllm_config()
@@ -301,11 +307,12 @@ class TestNPUPlatform(TestBase):
 
         self.assertEqual(vllm_config.cache_config.block_size, 128)
 
+    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
     @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
     @patch("vllm_ascend.ascend_config.init_ascend_config")
     @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
     def test_check_and_update_config_v1_worker_class_selection(
-        self, mock_init_recompute, mock_init_ascend, mock_soc_version
+        self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
     ):
         mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
         vllm_config = TestNPUPlatform.mock_vllm_config()
@@ -336,10 +343,11 @@ class TestNPUPlatform(TestBase):
             "vllm_ascend.xlite.xlite_worker.XliteWorker",
         )
 
+    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
     @patch("vllm_ascend.ascend_config.init_ascend_config")
     @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType._310P)
     @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
-    def test_check_and_update_config_310p_no_custom_ops(self, mock_init_recompute, mock_soc_version, mock_init_ascend):
+    def test_check_and_update_config_310p_no_custom_ops(self, mock_init_recompute, mock_soc_version, mock_init_ascend, mock_auto_detect):
         mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
         vllm_config = TestNPUPlatform.mock_vllm_config()
         vllm_config.compilation_config.custom_ops = []
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index bdf12811..55e1408c 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -178,6 +178,11 @@ class NPUPlatform(Platform):
 
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        from vllm_ascend.quantization.utils import maybe_auto_detect_quantization
+
+        if vllm_config.model_config is not None:
+            maybe_auto_detect_quantization(vllm_config)
+
         # initialize ascend config from vllm additional_config
         cls._fix_incompatible_config(vllm_config)
         ascend_config = init_ascend_config(vllm_config)
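Sequencing note, paraphrased rather than quoted from vllm source: the hook has to run in `check_and_update_config` because `VllmConfig.__post_init__` resolves the quant config before the platform ever sees it.

```python
# Paraphrased initialisation order, assuming current vllm behaviour:
#
#   VllmConfig.__post_init__()
#       -> _get_quantization_config()       # quantization is still None here
#   NPUPlatform.check_and_update_config(vllm_config)
#       -> maybe_auto_detect_quantization(vllm_config)
#          sets model_config.quantization and rebuilds vllm_config.quant_config
#   model loading                           # sees the detected quant method
#
# The `model_config is not None` guard covers runs without a model config.
```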
diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py
index 0945d1a3..eb9dddb1 100644
--- a/vllm_ascend/quantization/modelslim_config.py
+++ b/vllm_ascend/quantization/modelslim_config.py
@@ -21,6 +21,9 @@ This module provides the AscendModelSlimConfig class for parsing quantization
 configs generated by the ModelSlim tool, along with model-specific mappings.
 """
 
+import glob
+import json
+import os
 from collections.abc import Mapping
 from types import MappingProxyType
 from typing import Any, Optional
@@ -39,6 +42,9 @@ from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
 
 from .methods import get_scheme_class
 
+# The config filename that ModelSlim generates after quantizing a model.
+MODELSLIM_CONFIG_FILENAME = "quant_model_description.json"
+
 logger = init_logger(__name__)
 
 # key: model_type
@@ -397,9 +403,9 @@ class AscendModelSlimConfig(QuantizationConfig):
     quantized using the ModelSlim tool.
     """
 
-    def __init__(self, quant_config: dict[str, Any]):
+    def __init__(self, quant_config: dict[str, Any] | None = None):
         super().__init__()
-        self.quant_description = quant_config
+        self.quant_description = quant_config if quant_config is not None else {}
         # TODO(whx): remove this adaptation after adding "shared_head"
         # to prefix of DeepSeekShareHead in vLLM.
         extra_quant_dict = {}
@@ -433,7 +439,12 @@ class AscendModelSlimConfig(QuantizationConfig):
 
     @classmethod
     def get_config_filenames(cls) -> list[str]:
-        return ["quant_model_description.json"]
+        # Return an empty list so that vllm's get_quant_config() skips the
+        # file-based lookup (which raises an unfriendly "Cannot find the
+        # config file for ascend" error when the model is not quantized).
+        # Instead, the config file is loaded in maybe_update_config(),
+        # which can provide a user-friendly error message.
+        return []
 
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "AscendModelSlimConfig":
@@ -604,5 +615,108 @@ class AscendModelSlimConfig(QuantizationConfig):
         assert is_skipped is not None
         return is_skipped
 
+    def maybe_update_config(self, model_name: str, revision: str | None = None) -> None:
+        """Load the ModelSlim quantization config from the model directory.
+
+        This method is called by vllm after get_quant_config() returns
+        successfully. Since we return an empty list from get_config_filenames()
+        to bypass vllm's built-in file lookup, we do the actual config loading
+        here and provide user-friendly error messages when the config is missing.
+
+        Works with both local directories (``/path/to/model``) and remote
+        repository identifiers (``org/model-name``). For remote repos the
+        lookup goes through the HuggingFace / ModelScope cache via
+        ``get_model_file``, fetching the config if it is not already cached.
+
+        Args:
+            model_name: Path to the model directory or HuggingFace /
+                ModelScope repo id.
+            revision: Optional revision (branch, tag, or commit hash) for
+                remote repos.
+        """
+        from vllm_ascend.quantization.utils import get_model_file
+
+        # If quant_description is already populated (e.g. from from_config()),
+        # there is nothing to do.
+        if self.quant_description:
+            return
+
+        # Try to get the config file (local or remote)
+        config_path = get_model_file(model_name, MODELSLIM_CONFIG_FILENAME, revision=revision)
+
+        if config_path is not None:
+            with open(config_path) as f:
+                self.quant_description = json.load(f)
+            self._apply_extra_quant_adaptations()
+            return
+
+        # Collect diagnostic info for the error message
+        json_names: list[str] = []
+        if os.path.isdir(model_name):
+            json_files = glob.glob(os.path.join(model_name, "*.json"))
+            json_names = [os.path.basename(f) for f in json_files]
+
+        # Config file not found - raise a friendly error message
+        raise ValueError(
+            "\n"
+            + "=" * 80 + "\n"
+            + "ERROR: ModelSlim Quantization Config Not Found\n"
+            + "=" * 80 + "\n"
+            + "\n"
+            + f"You have enabled '--quantization {ASCEND_QUANTIZATION_METHOD}' "
+            + "(ModelSlim quantization),\n"
+            + f"but the model '{model_name}' does not contain the required\n"
+            + f"quantization config file ('{MODELSLIM_CONFIG_FILENAME}').\n"
+            + "\n"
+            + "This usually means the model weights are NOT quantized by "
+            + "ModelSlim.\n"
+            + "\n"
+            + "Please choose one of the following solutions:\n"
+            + "\n"
+            + "  Solution 1: Remove the quantization option "
+            + "(for float/unquantized models)\n"
+            + "  " + "-" * 58 + "\n"
+            + f"  Remove '--quantization {ASCEND_QUANTIZATION_METHOD}' from "
+            + "your command if you want to\n"
+            + "  run the model with the original (float) weights.\n"
+            + "\n"
+            + "  Example:\n"
+            + f"    vllm serve {model_name}\n"
+            + "\n"
+            + "  Solution 2: Quantize your model weights with ModelSlim first\n"
+            + "  " + "-" * 58 + "\n"
+            + "  Use the ModelSlim tool to quantize your model weights "
+            + "before deployment.\n"
+            + "  After quantization, the model directory should contain "
+            + f"'{MODELSLIM_CONFIG_FILENAME}'.\n"
+            + "  For more information, please refer to:\n"
+            + "    https://gitee.com/ascend/msit/tree/master/msmodelslim\n"
+            + "\n"
+            + (f"  (Found JSON files in model directory: {json_names})\n" if json_names else "")
+            + "=" * 80
+        )
+
+    def _apply_extra_quant_adaptations(self) -> None:
+        """Apply extra adaptations to the quant_description dict.
+
+        This handles known key transformations such as the shared_head and
+        weight_packed mappings.
+        """
+        extra_quant_dict = {}
+        for k in self.quant_description:
+            if "shared_head" in k:
+                new_k = k.replace(".shared_head.", ".")
+                extra_quant_dict[new_k] = self.quant_description[k]
+            if "weight_packed" in k:
+                new_k = k.replace("weight_packed", "weight")
+                extra_quant_dict[new_k] = self.quant_description[k]
+        self.quant_description.update(extra_quant_dict)
+
     def get_scaled_act_names(self) -> list[str]:
         return []
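A minimal usage sketch of the new loading path; the local path and repo id are illustrative:

```python
from vllm_ascend.quantization.modelslim_config import AscendModelSlimConfig

config = AscendModelSlimConfig()               # starts with an empty description
config.maybe_update_config("/path/to/model")   # loads quant_model_description.json
# Remote checkpoints work too, resolved through get_model_file():
# config.maybe_update_config("org/model-name", revision="main")
```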
+ """ + extra_quant_dict = {} + for k in self.quant_description: + if "shared_head" in k: + new_k = k.replace(".shared_head.", ".") + extra_quant_dict[new_k] = self.quant_description[k] + if "weight_packed" in k: + new_k = k.replace("weight_packed", "weight") + extra_quant_dict[new_k] = self.quant_description[k] + self.quant_description.update(extra_quant_dict) + def get_scaled_act_names(self) -> list[str]: return [] diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py new file mode 100644 index 00000000..c1042761 --- /dev/null +++ b/vllm_ascend/quantization/utils.py @@ -0,0 +1,201 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import json +from pathlib import Path + +from vllm import envs +from vllm.logger import init_logger + +from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD + +logger = init_logger(__name__) + + +def get_model_file( + model: str | Path, + filename: str, + revision: str | None = None, +) -> Path | None: + """Get a file from local model directory or download from remote repo. + + This function handles both local paths and remote repository IDs, + automatically downloading files from HuggingFace Hub or ModelScope + if they are not already cached. + + Args: + model: Local directory path or HuggingFace/ModelScope repo id. + filename: Name of the file to retrieve (e.g., "config.json"). + revision: Optional revision (branch, tag, or commit hash) for remote repos. + + Returns: + Path to the file if found, None otherwise. + """ + # Check if it's a local path + model_path = Path(model) if isinstance(model, str) else model + if model_path.exists(): + file_path = model_path / filename + return file_path if file_path.exists() else None + + # Remote repo: try to download from HF Hub or ModelScope + try: + if envs.VLLM_USE_MODELSCOPE: + from modelscope.hub.file_download import model_file_download # type: ignore[import-untyped] + + downloaded_path = model_file_download( + model_id=str(model), + file_path=filename, + revision=revision, + ) + return Path(downloaded_path) + else: + from huggingface_hub import hf_hub_download + + downloaded_path = hf_hub_download( + repo_id=str(model), + filename=filename, + revision=revision, + ) + return Path(downloaded_path) + except Exception as e: + logger.debug(f"Could not download {filename} from {model}: {e}") + return None + + +def detect_quantization_method(model: str, revision: str | None = None) -> str | None: + """Auto-detect the quantization method from model files. + + This function performs a lightweight check (JSON files only — no + .safetensors or .bin inspection) to determine which quantization + method was used to produce the weights in *model*. + + Works with both local directories (``/path/to/model``) and remote + repository identifiers (``org/model-name``). 
+
+
+def detect_quantization_method(model: str, revision: str | None = None) -> str | None:
+    """Auto-detect the quantization method from model files.
+
+    This function performs a lightweight check (JSON files only — no
+    .safetensors or .bin inspection) to determine which quantization
+    method was used to produce the weights in *model*.
+
+    Works with both local directories (``/path/to/model``) and remote
+    repository identifiers (``org/model-name``). For remote repos the
+    lookup goes through the HuggingFace / ModelScope cache, downloading
+    config files if not already cached.
+
+    Detection priority:
+        1. **ModelSlim (Ascend)** – ``quant_model_description.json`` exists.
+        2. **LLM-Compressor (compressed-tensors)** – ``config.json`` contains
+           a ``quantization_config`` section with
+           ``"quant_method": "compressed-tensors"``.
+        3. **None** – neither condition is met; the caller should fall back to
+           the default (float) behaviour.
+
+    Args:
+        model: Local directory path **or** HuggingFace / ModelScope repo id.
+        revision: Optional model revision (branch, tag, or commit id).
+
+    Returns:
+        ``"ascend"`` for ModelSlim models,
+        ``"compressed-tensors"`` for LLM-Compressor models,
+        or ``None`` if no quantization signature is found.
+    """
+    from vllm_ascend.quantization.modelslim_config import MODELSLIM_CONFIG_FILENAME
+
+    # Case 1: ModelSlim — look for quant_model_description.json
+    modelslim_path = get_model_file(model, MODELSLIM_CONFIG_FILENAME, revision=revision)
+    if modelslim_path is not None:
+        return ASCEND_QUANTIZATION_METHOD
+
+    # Case 2: LLM-Compressor — look for compressed-tensors in config.json
+    config_path = get_model_file(model, "config.json", revision=revision)
+    if config_path is not None:
+        try:
+            with open(config_path) as f:
+                config = json.load(f)
+            quant_cfg = config.get("quantization_config")
+            if isinstance(quant_cfg, dict):
+                quant_method = quant_cfg.get("quant_method", "")
+                if quant_method == COMPRESSED_TENSORS_METHOD:
+                    return COMPRESSED_TENSORS_METHOD
+        except (json.JSONDecodeError, OSError):
+            pass
+
+    # Case 3: No quantization signature found.
+    return None
+
+
+def maybe_auto_detect_quantization(vllm_config) -> None:
+    """Auto-detect and apply the quantization method on *vllm_config*.
+
+    This should be called during engine initialisation (from
+    ``NPUPlatform.check_and_update_config``) **after** ``VllmConfig`` has been
+    created but **before** heavy weights are loaded.
+
+    Because ``check_and_update_config`` runs *after*
+    ``VllmConfig.__post_init__`` has already evaluated
+    ``_get_quantization_config`` (which returned ``None`` when
+    ``model_config.quantization`` was not set), we must:
+
+    1. Set ``model_config.quantization`` to the detected value.
+    2. Recreate ``vllm_config.quant_config`` so that the quantization
+       pipeline (``get_quant_config`` → ``QuantizationConfig`` →
+       ``get_quant_method`` for every layer) is properly initialised.
+
+    Rules:
+        * If the user explicitly set ``--quantization``, that value is
+          respected. A warning is emitted when the detected method differs.
+        * If no ``--quantization`` was given, the detected method (if any) is
+          applied automatically.
+
+    Args:
+        vllm_config: A ``vllm.config.VllmConfig`` instance (mutable).
+    """
+    model_config = vllm_config.model_config
+    model = model_config.model
+    revision = model_config.revision
+    user_quant = model_config.quantization
+    detected = detect_quantization_method(model, revision=revision)
+
+    if detected is None:
+        # No quantization signature found — nothing to do.
+        return
+
+    if user_quant is not None:
+        # User explicitly specified a quantization method.
+        if user_quant != detected:
+            logger.warning(
+                "Auto-detected quantization method '%s' from model "
+                "files for '%s', but user explicitly specified "
+                "'--quantization %s'. Respecting the user-specified "
+                "value. If you encounter errors during model loading, "
+                "consider using '--quantization %s' instead.",
+                detected,
+                model,
+                user_quant,
+                detected,
+            )
+        return
+
+    # No user-specified quantization — apply auto-detected value.
+    model_config.quantization = detected
+    logger.info(
+        "Auto-detected quantization method '%s' from model files "
+        "for '%s'. To override, pass '--quantization <method>' explicitly.",
+        detected,
+        model,
+    )
+
+    # Recreate quant_config on VllmConfig. The original __post_init__
+    # already ran _get_quantization_config(), but at that point
+    # model_config.quantization was None so it returned None. Now that
+    # we've set it, we need to build the actual QuantizationConfig so the
+    # downstream model-loading code can use it.
+    from vllm.config import VllmConfig as _VllmConfig
+
+    vllm_config.quant_config = _VllmConfig._get_quantization_config(
+        model_config, vllm_config.load_config)
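Finally, an end-to-end sketch of the three utilities on a remote repo id; the repo id and revision are illustrative, and downloads go through the HuggingFace Hub cache unless `VLLM_USE_MODELSCOPE` is set:

```python
from vllm_ascend.quantization.utils import (
    detect_quantization_method,
    get_model_file,
)

repo = "org/model-name"  # illustrative repo id

# Fetch a single file from the hub (cached after the first download).
config_path = get_model_file(repo, "config.json", revision="main")

# Returns "ascend", "compressed-tensors", or None.
method = detect_quantization_method(repo, revision="main")

# In-engine, NPUPlatform.check_and_update_config() performs the same
# detection through maybe_auto_detect_quantization(vllm_config) and then
# rebuilds vllm_config.quant_config accordingly.
print(config_path, method)
```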