[Feature][Quant] Auto-detect quantization format from model files (#6645)

## Summary

- Add automatic quantization format detection, eliminating the need to
manually specify `--quantization` when serving quantized models.
- The detection inspects only lightweight JSON files
(`quant_model_description.json` and `config.json`) at engine
initialization time, with no `.safetensors` reads.
- User-explicit `--quantization` flags are always respected;
auto-detection only applies when the flag is omitted.
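As a minimal usage sketch (the model id below is the W8A8 checkpoint exercised in the e2e test further down; the offline `LLM` API mirrors the `vllm serve` flag):

```python
from vllm import LLM

# Before this change, a ModelSlim-quantized checkpoint had to name its
# backend explicitly:
#   llm = LLM(model="vllm-ascend/Qwen3-0.6B-W8A8", quantization="ascend")

# Now quant_model_description.json in the checkpoint directory is enough,
# so the flag can be omitted:
llm = LLM(model="vllm-ascend/Qwen3-0.6B-W8A8")
```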

## Details

**Detection priority:**
1. `quant_model_description.json` exists → `quantization="ascend"`
(ModelSlim)
2. `config.json` contains a `quantization_config` block with
`"quant_method": "compressed-tensors"` →
`quantization="compressed-tensors"` (LLM-Compressor)
3. Neither file matches → default float behavior
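The full logic lives in `detect_quantization_method()` in the diff below; condensed to its essentials, the probe amounts to (a sketch, not the shipped code):

```python
import json
import os


def detect(model_path: str) -> str | None:
    """Condensed sketch of the detection priority listed above."""
    # 1. ModelSlim: marker file left by the quantization tool.
    if os.path.isfile(os.path.join(model_path, "quant_model_description.json")):
        return "ascend"
    # 2. LLM-Compressor: quant_method recorded inside config.json.
    config_json = os.path.join(model_path, "config.json")
    if os.path.isfile(config_json):
        try:
            with open(config_json) as f:
                quant_cfg = json.load(f).get("quantization_config")
        except (json.JSONDecodeError, OSError):
            quant_cfg = None  # malformed config.json is treated as unquantized
        if isinstance(quant_cfg, dict) and \
                quant_cfg.get("quant_method") == "compressed-tensors":
            return "compressed-tensors"
    # 3. No signature found: caller falls back to float weights.
    return None
```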

**Technical approach:**
Hooked into `NPUPlatform.check_and_update_config()`, which runs after
`VllmConfig.__post_init__`. By that point `__post_init__` has already
called `_get_quantization_config()` and got back `None` (because
`model_config.quantization` was unset), so after setting the detected
method we explicitly recreate `quant_config` via
`VllmConfig._get_quantization_config()` to trigger the full quantization
initialization pipeline.
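For orientation, the resulting order of operations at engine startup (names as they appear in the diff below):

```python
# Startup call flow after this change (comment-only sketch):
#
# VllmConfig.__post_init__()
#   -> _get_quantization_config()        # quantization is None, so
#                                        # vllm_config.quant_config stays None
# NPUPlatform.check_and_update_config(vllm_config)
#   -> maybe_auto_detect_quantization(vllm_config)
#        -> detect_quantization_method(model_path)    # JSON-only probe
#        -> model_config.quantization = detected      # only if no user flag
#        -> vllm_config.quant_config = VllmConfig._get_quantization_config(
#               model_config, vllm_config.load_config)
```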

## Files Changed

| File | Description |
|------|-------------|
| `vllm_ascend/quantization/utils.py` | Added `detect_quantization_method()` and `maybe_auto_detect_quantization()` |
| `vllm_ascend/platform.py` | Integrated auto-detection in `check_and_update_config()` |
| `vllm_ascend/quantization/modelslim_config.py` | Improved error handling for weight loading |
- vLLM version: v0.15.0
- vLLM main: d7e17aaacd

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>


@@ -49,6 +49,43 @@ def test_qwen3_w8a8_quant():
        name_1="vllm_quant_w8a8_outputs",
    )


# fmt: off
def test_qwen3_w8a8_quant_auto_detect():
    """Test that ModelSlim quantization is auto-detected without --quantization.

    Uses the same W8A8 model as test_qwen3_w8a8_quant but omits the
    quantization parameter, verifying that the auto-detection in
    maybe_auto_detect_quantization() picks up quant_model_description.json
    and produces identical results.
    """
    max_tokens = 5
    example_prompts = [
        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
    ]
    vllm_target_outputs = [([
        85, 4086, 44, 374, 264, 1550, 42747, 628, 323, 4938, 72816, 44378, 323,
        13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 311, 387
    ], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be'
    )]
    # fmt: on
    with VllmRunner(
            "vllm-ascend/Qwen3-0.6B-W8A8",
            max_model_len=8192,
            gpu_memory_utilization=0.7,
            cudagraph_capture_sizes=[1, 2, 4, 8],
    ) as vllm_model:
        vllm_quant_auto_detect_outputs = vllm_model.generate_greedy(
            example_prompts, max_tokens)
    check_outputs_equal(
        outputs_0_lst=vllm_target_outputs,
        outputs_1_lst=vllm_quant_auto_detect_outputs,
        name_0="vllm_target_outputs",
        name_1="vllm_quant_auto_detect_outputs",
    )


# fmt: off
def test_qwen3_dense_w8a16():
    max_tokens = 5


@@ -1,3 +1,6 @@
import json
import os
import tempfile
from unittest.mock import MagicMock, patch

from vllm.model_executor.layers.fused_moe import FusedMoE

@@ -6,7 +9,10 @@ from vllm.model_executor.layers.linear import LinearBase

from tests.ut.base import TestBase
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
-from vllm_ascend.quantization.modelslim_config import AscendModelSlimConfig
+from vllm_ascend.quantization.modelslim_config import (
+    MODELSLIM_CONFIG_FILENAME,
+    AscendModelSlimConfig,
+)
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, vllm_version_is

if vllm_version_is("v0.15.0"):

@@ -54,7 +60,7 @@ class TestAscendModelSlimConfig(TestBase):

    def test_get_config_filenames(self):
        filenames = AscendModelSlimConfig.get_config_filenames()
-        self.assertEqual(filenames, ["quant_model_description.json"])
+        self.assertEqual(filenames, [])

    def test_from_config(self):
        config = AscendModelSlimConfig.from_config(self.sample_config)

@@ -162,5 +168,90 @@ class TestAscendModelSlimConfig(TestBase):

        with self.assertRaises(ValueError):
            config.is_layer_skipped_ascend("fused_layer", fused_mapping)

    def test_init_with_none_config(self):
        config = AscendModelSlimConfig(None)
        self.assertEqual(config.quant_description, {})

    def test_init_with_default_config(self):
        config = AscendModelSlimConfig()
        self.assertEqual(config.quant_description, {})

    def test_maybe_update_config_already_populated(self):
        # When quant_description is already populated, should be a no-op
        self.assertTrue(len(self.ascend_config.quant_description) > 0)
        self.ascend_config.maybe_update_config("/some/model/path")
        # quant_description should remain unchanged
        self.assertEqual(self.ascend_config.quant_description,
                         self.sample_config)

    def test_maybe_update_config_loads_from_file(self):
        config = AscendModelSlimConfig()
        self.assertEqual(config.quant_description, {})
        quant_data = {"layer1.weight": "INT8", "layer2.weight": "FLOAT"}
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = os.path.join(tmpdir, MODELSLIM_CONFIG_FILENAME)
            with open(config_path, "w") as f:
                json.dump(quant_data, f)
            config.maybe_update_config(tmpdir)
        self.assertEqual(config.quant_description, quant_data)

    def test_maybe_update_config_raises_when_file_missing(self):
        config = AscendModelSlimConfig()
        with tempfile.TemporaryDirectory() as tmpdir:
            with self.assertRaises(ValueError) as ctx:
                config.maybe_update_config(tmpdir)
        error_msg = str(ctx.exception)
        self.assertIn("ModelSlim Quantization Config Not Found", error_msg)
        self.assertIn(MODELSLIM_CONFIG_FILENAME, error_msg)

    def test_maybe_update_config_raises_with_json_files_listed(self):
        config = AscendModelSlimConfig()
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a dummy json file that is NOT the config file
            dummy_path = os.path.join(tmpdir, "config.json")
            with open(dummy_path, "w") as f:
                json.dump({"dummy": True}, f)
            with self.assertRaises(ValueError) as ctx:
                config.maybe_update_config(tmpdir)
        error_msg = str(ctx.exception)
        self.assertIn("config.json", error_msg)

    def test_maybe_update_config_non_directory_raises(self):
        config = AscendModelSlimConfig()
        with self.assertRaises(ValueError) as ctx:
            config.maybe_update_config("not_a_real_directory_path")
        error_msg = str(ctx.exception)
        self.assertIn("ModelSlim Quantization Config Not Found", error_msg)

    def test_apply_extra_quant_adaptations_shared_head(self):
        config = AscendModelSlimConfig()
        config.quant_description = {
            "model.layers.0.shared_head.weight": "INT8",
        }
        config._apply_extra_quant_adaptations()
        self.assertIn("model.layers.0.weight", config.quant_description)
        self.assertEqual(config.quant_description["model.layers.0.weight"],
                         "INT8")

    def test_apply_extra_quant_adaptations_weight_packed(self):
        config = AscendModelSlimConfig()
        config.quant_description = {
            "model.layers.0.weight_packed": "INT8",
        }
        config._apply_extra_quant_adaptations()
        self.assertIn("model.layers.0.weight", config.quant_description)
        self.assertEqual(config.quant_description["model.layers.0.weight"],
                         "INT8")

    def test_get_scaled_act_names(self):
        self.assertEqual(self.ascend_config.get_scaled_act_names(), [])


@@ -0,0 +1,182 @@
import json
import logging
import os
import tempfile
from unittest.mock import MagicMock, patch

from tests.ut.base import TestBase
from vllm_ascend.quantization.modelslim_config import MODELSLIM_CONFIG_FILENAME
from vllm_ascend.quantization.utils import (
    detect_quantization_method,
    maybe_auto_detect_quantization,
)
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD


class TestDetectQuantizationMethod(TestBase):

    def test_returns_none_for_non_directory(self):
        result = detect_quantization_method("/non/existent/path")
        self.assertIsNone(result)

    def test_detects_modelslim(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = os.path.join(tmpdir, MODELSLIM_CONFIG_FILENAME)
            with open(config_path, "w") as f:
                json.dump({"layer.weight": "INT8"}, f)
            result = detect_quantization_method(tmpdir)
            self.assertEqual(result, ASCEND_QUANTIZATION_METHOD)

    def test_detects_compressed_tensors(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = os.path.join(tmpdir, "config.json")
            with open(config_path, "w") as f:
                json.dump({
                    "quantization_config": {
                        "quant_method": "compressed-tensors"
                    }
                }, f)
            result = detect_quantization_method(tmpdir)
            self.assertEqual(result, COMPRESSED_TENSORS_METHOD)

    def test_returns_none_for_no_quant(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            result = detect_quantization_method(tmpdir)
            self.assertIsNone(result)

    def test_returns_none_for_non_compressed_tensors_quant_method(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = os.path.join(tmpdir, "config.json")
            with open(config_path, "w") as f:
                json.dump({
                    "quantization_config": {
                        "quant_method": "gptq"
                    }
                }, f)
            result = detect_quantization_method(tmpdir)
            self.assertIsNone(result)

    def test_returns_none_for_config_without_quant_config(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = os.path.join(tmpdir, "config.json")
            with open(config_path, "w") as f:
                json.dump({"model_type": "llama"}, f)
            result = detect_quantization_method(tmpdir)
            self.assertIsNone(result)

    def test_returns_none_for_malformed_config_json(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = os.path.join(tmpdir, "config.json")
            with open(config_path, "w") as f:
                f.write("not valid json{{{")
            result = detect_quantization_method(tmpdir)
            self.assertIsNone(result)

    def test_modelslim_takes_priority_over_compressed_tensors(self):
        """When both ModelSlim config and compressed-tensors config exist,
        ModelSlim should take priority."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create ModelSlim config
            modelslim_path = os.path.join(tmpdir, MODELSLIM_CONFIG_FILENAME)
            with open(modelslim_path, "w") as f:
                json.dump({"layer.weight": "INT8"}, f)
            # Create compressed-tensors config
            config_path = os.path.join(tmpdir, "config.json")
            with open(config_path, "w") as f:
                json.dump({
                    "quantization_config": {
                        "quant_method": "compressed-tensors"
                    }
                }, f)
            result = detect_quantization_method(tmpdir)
            self.assertEqual(result, ASCEND_QUANTIZATION_METHOD)


class TestMaybeAutoDetectQuantization(TestBase):

    def _make_vllm_config(self, model_path="/fake/model", quantization=None):
        vllm_config = MagicMock()
        vllm_config.model_config.model = model_path
        vllm_config.model_config.quantization = quantization
        return vllm_config

    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
           return_value=None)
    def test_no_detection_does_nothing(self, mock_detect):
        vllm_config = self._make_vllm_config()
        maybe_auto_detect_quantization(vllm_config)
        # quantization should remain unchanged
        self.assertIsNone(vllm_config.model_config.quantization)

    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
           return_value=ASCEND_QUANTIZATION_METHOD)
    def test_user_specified_same_method_no_change(self, mock_detect):
        vllm_config = self._make_vllm_config(
            quantization=ASCEND_QUANTIZATION_METHOD)
        maybe_auto_detect_quantization(vllm_config)
        self.assertEqual(vllm_config.model_config.quantization,
                         ASCEND_QUANTIZATION_METHOD)

    @patch("vllm.config.VllmConfig._get_quantization_config",
           return_value=MagicMock())
    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
           return_value=ASCEND_QUANTIZATION_METHOD)
    def test_auto_detect_sets_quantization_and_logs_info(
            self, mock_detect, mock_get_quant_config):
        """When no --quantization is specified but ModelSlim config is found,
        the method should auto-set quantization and emit an INFO log."""
        vllm_config = self._make_vllm_config(
            model_path="/fake/quant_model", quantization=None)
        with self.assertLogs("vllm_ascend.quantization.utils",
                             level=logging.INFO) as cm:
            maybe_auto_detect_quantization(vllm_config)
        self.assertEqual(vllm_config.model_config.quantization,
                         ASCEND_QUANTIZATION_METHOD)
        log_output = "\n".join(cm.output)
        self.assertIn("Auto-detected quantization method", log_output)
        self.assertIn(ASCEND_QUANTIZATION_METHOD, log_output)
        self.assertIn("/fake/quant_model", log_output)

    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
           return_value=ASCEND_QUANTIZATION_METHOD)
    def test_user_mismatch_logs_warning(self, mock_detect):
        """When user specifies a different method than auto-detected,
        a WARNING should be emitted and user's choice should be respected."""
        vllm_config = self._make_vllm_config(
            model_path="/fake/quant_model",
            quantization=COMPRESSED_TENSORS_METHOD)
        with self.assertLogs("vllm_ascend.quantization.utils",
                             level=logging.WARNING) as cm:
            maybe_auto_detect_quantization(vllm_config)
        # User's choice is respected
        self.assertEqual(vllm_config.model_config.quantization,
                         COMPRESSED_TENSORS_METHOD)
        log_output = "\n".join(cm.output)
        self.assertIn("Auto-detected quantization method", log_output)
        self.assertIn(ASCEND_QUANTIZATION_METHOD, log_output)
        self.assertIn(COMPRESSED_TENSORS_METHOD, log_output)

    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
           return_value=None)
    def test_no_detection_emits_no_log(self, mock_detect):
        """When no quantization is detected, no log should be emitted."""
        vllm_config = self._make_vllm_config(quantization=None)
        logger_name = "vllm_ascend.quantization.utils"
        with self.assertRaises(AssertionError):
            # assertLogs raises AssertionError when no logs are emitted
            with self.assertLogs(logger_name, level=logging.DEBUG):
                maybe_auto_detect_quantization(vllm_config)
        self.assertIsNone(vllm_config.model_config.quantization)


@@ -116,13 +116,14 @@ class TestNPUPlatform(TestBase):
        self.assertIsNone(self.platform.inference_mode())
        mock_inference_mode.assert_called_once()

    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch("vllm_ascend.utils.update_aclgraph_sizes")
    @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
    @patch("os.environ", {})
    @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
    def test_check_and_update_config_basic_config_update(
-        self, mock_init_recompute, mock_soc_version, mock_update_acl, mock_init_ascend
+        self, mock_init_recompute, mock_soc_version, mock_update_acl, mock_init_ascend, mock_auto_detect
    ):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
        vllm_config = TestNPUPlatform.mock_vllm_config()

@@ -146,11 +147,12 @@ class TestNPUPlatform(TestBase):

        mock_init_ascend.assert_called_once_with(vllm_config)

    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
    @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
    def test_check_and_update_config_no_model_config_warning(
-        self, mock_init_recompute, mock_init_ascend, mock_soc_version
+        self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
    ):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
        vllm_config = TestNPUPlatform.mock_vllm_config()

@@ -172,10 +174,11 @@ class TestNPUPlatform(TestBase):

        self.assertTrue("Model config is missing" in cm.output[0])

    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
    @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
-    def test_check_and_update_config_enforce_eager_mode(self, mock_init_recompute, mock_init_ascend, mock_soc_version):
+    def test_check_and_update_config_enforce_eager_mode(self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
        vllm_config = TestNPUPlatform.mock_vllm_config()
        vllm_config.model_config.enforce_eager = True

@@ -206,11 +209,12 @@ class TestNPUPlatform(TestBase):

            CUDAGraphMode.NONE,
        )

    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
    @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
    def test_check_and_update_config_unsupported_compilation_level(
-        self, mock_init_recompute, mock_init_ascend, mock_soc_version
+        self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
    ):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
        vllm_config = TestNPUPlatform.mock_vllm_config()

@@ -244,9 +248,10 @@ class TestNPUPlatform(TestBase):

        )

    @pytest.mark.skip("Revert me when vllm support setting cudagraph_mode on oot platform")
    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
    @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
    @patch("vllm_ascend.ascend_config.init_ascend_config")
-    def test_check_and_update_config_unsupported_cudagraph_mode(self, mock_init_ascend, mock_soc_version):
+    def test_check_and_update_config_unsupported_cudagraph_mode(self, mock_init_ascend, mock_soc_version, mock_auto_detect):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
        vllm_config = TestNPUPlatform.mock_vllm_config()
        vllm_config.model_config.enforce_eager = False

@@ -268,11 +273,12 @@ class TestNPUPlatform(TestBase):

            CUDAGraphMode.NONE,
        )

    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
    @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
    def test_check_and_update_config_cache_config_block_size(
-        self, mock_init_recompute, mock_init_ascend, mock_soc_version
+        self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
    ):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
        vllm_config = TestNPUPlatform.mock_vllm_config()

@@ -292,11 +298,12 @@ class TestNPUPlatform(TestBase):

        self.assertEqual(vllm_config.cache_config.block_size, 128)

    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
    @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
    def test_check_and_update_config_v1_worker_class_selection(
-        self, mock_init_recompute, mock_init_ascend, mock_soc_version
+        self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
    ):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
        vllm_config = TestNPUPlatform.mock_vllm_config()

@@ -327,10 +334,11 @@ class TestNPUPlatform(TestBase):

            "vllm_ascend.xlite.xlite_worker.XliteWorker",
        )

    @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType._310P)
    @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
-    def test_check_and_update_config_310p_no_custom_ops(self, mock_init_recompute, mock_soc_version, mock_init_ascend):
+    def test_check_and_update_config_310p_no_custom_ops(self, mock_init_recompute, mock_soc_version, mock_init_ascend, mock_auto_detect):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
        vllm_config = TestNPUPlatform.mock_vllm_config()
        vllm_config.compilation_config.custom_ops = []


@@ -171,6 +171,11 @@ class NPUPlatform(Platform):
    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        from vllm_ascend.quantization.utils import maybe_auto_detect_quantization

        if vllm_config.model_config is not None:
            maybe_auto_detect_quantization(vllm_config)

        # initialize ascend config from vllm additional_config
        cls._fix_incompatible_config(vllm_config)
        ascend_config = init_ascend_config(vllm_config)


@@ -21,6 +21,9 @@ This module provides the AscendModelSlimConfig class for parsing quantization
configs generated by the ModelSlim tool, along with model-specific mappings.
"""

import glob
import json
import os
from collections.abc import Mapping
from types import MappingProxyType
from typing import Any, Optional

@@ -39,6 +42,9 @@ from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD

from .methods import get_scheme_class

# The config filename that ModelSlim generates after quantizing a model.
MODELSLIM_CONFIG_FILENAME = "quant_model_description.json"

logger = init_logger(__name__)

# key: model_type

@@ -310,9 +316,9 @@ class AscendModelSlimConfig(QuantizationConfig):

    quantized using the ModelSlim tool.
    """

-    def __init__(self, quant_config: dict[str, Any]):
+    def __init__(self, quant_config: dict[str, Any] | None = None):
        super().__init__()
-        self.quant_description = quant_config
+        self.quant_description = quant_config if quant_config is not None else {}
        # TODO(whx): remove this adaptation after adding "shared_head"
        # to prefix of DeepSeekShareHead in vLLM.
        extra_quant_dict = {}

@@ -342,7 +348,12 @@ class AscendModelSlimConfig(QuantizationConfig):

    @classmethod
    def get_config_filenames(cls) -> list[str]:
-        return ["quant_model_description.json"]
+        # Return empty list so that vllm's get_quant_config() skips the
+        # file-based lookup (which raises an unfriendly "Cannot find the
+        # config file for ascend" error when the model is not quantized).
+        # Instead, the config file is loaded in maybe_update_config(),
+        # which can provide a user-friendly error message.
+        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AscendModelSlimConfig":

@@ -456,5 +467,98 @@ class AscendModelSlimConfig(QuantizationConfig):

        assert is_skipped is not None
        return is_skipped

    def maybe_update_config(self, model_name: str) -> None:
        """Load the ModelSlim quantization config from model directory.

        This method is called by vllm after get_quant_config() returns
        successfully. Since we return an empty list from get_config_filenames()
        to bypass vllm's built-in file lookup, we do the actual config loading
        here and provide user-friendly error messages when the config is missing.

        Args:
            model_name: Path to the model directory or model name.
        """
        # If quant_description is already populated (e.g. from from_config()),
        # there is nothing to do.
        if self.quant_description:
            return

        # Try to find and load the ModelSlim config file
        if os.path.isdir(model_name):
            config_path = os.path.join(model_name, MODELSLIM_CONFIG_FILENAME)
            if os.path.isfile(config_path):
                with open(config_path) as f:
                    self.quant_description = json.load(f)
                self._apply_extra_quant_adaptations()
                return
            # Check if there are any json files at all to help diagnose
            json_files = glob.glob(os.path.join(model_name, "*.json"))
            json_names = [os.path.basename(f) for f in json_files]
        else:
            json_names = []

        # Config file not found - raise a friendly error message
        raise ValueError(
            "\n"
            + "=" * 80
            + "\n"
            + "ERROR: ModelSlim Quantization Config Not Found\n"
            + "=" * 80
            + "\n"
            + "\n"
            + f"You have enabled '--quantization {ASCEND_QUANTIZATION_METHOD}' "
            + "(ModelSlim quantization),\n"
            + f"but the model at '{model_name}' does not contain the required\n"
            + f"quantization config file ('{MODELSLIM_CONFIG_FILENAME}').\n"
            + "\n"
            + "This usually means the model weights are NOT quantized by "
            + "ModelSlim.\n"
            + "\n"
            + "Please choose one of the following solutions:\n"
            + "\n"
            + "  Solution 1: Remove the quantization option "
            + "(for float/unquantized models)\n"
            + "  "
            + "-" * 58
            + "\n"
            + f"  Remove '--quantization {ASCEND_QUANTIZATION_METHOD}' from "
            + "your command if you want to\n"
            + "  run the model with the original (float) weights.\n"
            + "\n"
            + "  Example:\n"
            + f"    vllm serve {model_name}\n"
            + "\n"
            + "  Solution 2: Quantize your model weights with ModelSlim first\n"
            + "  "
            + "-" * 58
            + "\n"
            + "  Use the ModelSlim tool to quantize your model weights "
            + "before deployment.\n"
            + "  After quantization, the model directory should contain "
            + f"'{MODELSLIM_CONFIG_FILENAME}'.\n"
            + "  For more information, please refer to:\n"
            + "  https://gitee.com/ascend/msit/tree/master/msmodelslim\n"
            + "\n"
            + (f"  (Found JSON files in model directory: {json_names})\n" if json_names else "")
            + "=" * 80
        )

    def _apply_extra_quant_adaptations(self) -> None:
        """Apply extra adaptations to the quant_description dict.

        This handles known key transformations such as shared_head and
        weight_packed mappings.
        """
        extra_quant_dict = {}
        for k in self.quant_description:
            if "shared_head" in k:
                new_k = k.replace(".shared_head.", ".")
                extra_quant_dict[new_k] = self.quant_description[k]
            if "weight_packed" in k:
                new_k = k.replace("weight_packed", "weight")
                extra_quant_dict[new_k] = self.quant_description[k]
        self.quant_description.update(extra_quant_dict)

    def get_scaled_act_names(self) -> list[str]:
        return []


@@ -0,0 +1,147 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import json
import os

from vllm.logger import init_logger

from vllm_ascend.quantization.modelslim_config import MODELSLIM_CONFIG_FILENAME
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD

logger = init_logger(__name__)


def detect_quantization_method(model_path: str) -> str | None:
    """Auto-detect the quantization method from model directory files.

    This function performs a lightweight check (JSON files and file existence
    only — no .safetensors or .bin inspection) to determine which quantization
    method was used to produce the weights in *model_path*.

    Detection priority:
        1. **ModelSlim (Ascend)**: ``quant_model_description.json`` exists
           in the model directory.
        2. **LLM-Compressor (compressed-tensors)**: ``config.json`` contains
           a ``quantization_config`` section with
           ``"quant_method": "compressed-tensors"``.
        3. **None**: neither condition is met; the caller should fall back to
           the default (float) behaviour.

    Args:
        model_path: Path to the local model directory.

    Returns:
        ``"ascend"`` for ModelSlim models,
        ``"compressed-tensors"`` for LLM-Compressor models,
        or ``None`` if no quantization signature is found.
    """
    if not os.path.isdir(model_path):
        return None

    # Case 1: ModelSlim — look for quant_model_description.json
    modelslim_config_path = os.path.join(model_path, MODELSLIM_CONFIG_FILENAME)
    if os.path.isfile(modelslim_config_path):
        return ASCEND_QUANTIZATION_METHOD

    # Case 2: LLM-Compressor — look for compressed-tensors in config.json
    config_json_path = os.path.join(model_path, "config.json")
    if os.path.isfile(config_json_path):
        try:
            with open(config_json_path) as f:
                config = json.load(f)
            quant_cfg = config.get("quantization_config")
            if isinstance(quant_cfg, dict):
                quant_method = quant_cfg.get("quant_method", "")
                if quant_method == COMPRESSED_TENSORS_METHOD:
                    return COMPRESSED_TENSORS_METHOD
        except (json.JSONDecodeError, OSError):
            # Malformed or unreadable config.json — skip silently.
            pass

    # Case 3: No quantization signature found.
    return None


def maybe_auto_detect_quantization(vllm_config) -> None:
    """Auto-detect and apply the quantization method on *vllm_config*.

    This should be called during engine initialisation (from
    ``NPUPlatform.check_and_update_config``) **after** ``VllmConfig`` has been
    created but **before** heavy weights are loaded.

    Because ``check_and_update_config`` runs *after*
    ``VllmConfig.__post_init__`` has already evaluated
    ``_get_quantization_config`` (which returned ``None`` when
    ``model_config.quantization`` was not set), we must:

        1. Set ``model_config.quantization`` to the detected value.
        2. Recreate ``vllm_config.quant_config`` so that the quantization
           pipeline (``get_quant_config`` → ``QuantizationConfig`` →
           ``get_quant_method`` for every layer) is properly initialised.

    Rules:
        * If the user explicitly set ``--quantization``, that value is
          respected. A warning is emitted when the detected method differs.
        * If no ``--quantization`` was given, the detected method (if any) is
          applied automatically.

    Args:
        vllm_config: A ``vllm.config.VllmConfig`` instance (mutable).
    """
    model_config = vllm_config.model_config
    model_path = model_config.model
    user_quant = model_config.quantization

    detected = detect_quantization_method(model_path)
    if detected is None:
        # No quantization signature found — nothing to do.
        return

    if user_quant is not None:
        # User explicitly specified a quantization method.
        if user_quant != detected:
            logger.warning(
                "Auto-detected quantization method '%s' from model "
                "files at '%s', but user explicitly specified "
                "'--quantization %s'. Respecting the user-specified "
                "value. If you encounter errors during model loading, "
                "consider using '--quantization %s' instead.",
                detected,
                model_path,
                user_quant,
                detected,
            )
        return

    # No user-specified quantization — apply auto-detected value.
    model_config.quantization = detected
    logger.info(
        "Auto-detected quantization method '%s' from model files "
        "at '%s'. To override, pass '--quantization <method>' explicitly.",
        detected,
        model_path,
    )

    # Recreate quant_config on VllmConfig. The original __post_init__
    # already ran _get_quantization_config(), but at that point
    # model_config.quantization was None so it returned None. Now that
    # we've set it, we need to build the actual QuantizationConfig so the
    # downstream model-loading code can use it.
    from vllm.config import VllmConfig as _VllmConfig

    vllm_config.quant_config = _VllmConfig._get_quantization_config(
        model_config, vllm_config.load_config)