[Feature][Quant] Auto-detect quantization format from model files (#6645)
## Summary
- Add automatic quantization format detection, eliminating the need to
manually specify `--quantization` when serving quantized models.
- The detection inspects only lightweight JSON files
(`quant_model_description.json` and `config.json`) at engine
initialization time, with no `.safetensors` reads.
- User-explicit `--quantization` flags are always respected;
auto-detection only applies when the flag is omitted.
## Details
**Detection priority:**
1. `quant_model_description.json` exists → `quantization="ascend"`
(ModelSlim)
2. `config.json` contains `"quant_method": "compressed-tensors"` →
`quantization="compressed-tensors"` (LLM-Compressor)
3. Neither → default float behavior
**Technical approach:**
Hooked into `NPUPlatform.check_and_update_config()` to run detection
after `VllmConfig.__post_init__`. Since `quant_config` is already `None`
at that point, we explicitly recreate it via
`VllmConfig._get_quantization_config()` to trigger the full quantization
initialization pipeline.
## Files Changed
| File | Description |
|------|-------------|
| `vllm_ascend/quantization/utils.py` | Added
`detect_quantization_method()` and `maybe_auto_detect_quantization()` |
| `vllm_ascend/platform.py` | Integrated auto-detection in
`check_and_update_config()` |
| `vllm_ascend/quantization/modelslim_config.py` | Improved error
handling for weight loading |
- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
@@ -49,6 +49,43 @@ def test_qwen3_w8a8_quant():
|
|||||||
name_1="vllm_quant_w8a8_outputs",
|
name_1="vllm_quant_w8a8_outputs",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
def test_qwen3_w8a8_quant_auto_detect():
|
||||||
|
"""Test that ModelSlim quantization is auto-detected without --quantization.
|
||||||
|
|
||||||
|
Uses the same W8A8 model as test_qwen3_w8a8_quant but omits the
|
||||||
|
quantization parameter, verifying that the auto-detection in
|
||||||
|
maybe_auto_detect_quantization() picks up quant_model_description.json
|
||||||
|
and produces identical results.
|
||||||
|
"""
|
||||||
|
max_tokens = 5
|
||||||
|
example_prompts = [
|
||||||
|
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
|
||||||
|
]
|
||||||
|
vllm_target_outputs = [([
|
||||||
|
85, 4086, 44, 374, 264, 1550, 42747, 628, 323, 4938, 72816, 44378, 323,
|
||||||
|
13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 311, 387
|
||||||
|
], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be'
|
||||||
|
)]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
with VllmRunner(
|
||||||
|
"vllm-ascend/Qwen3-0.6B-W8A8",
|
||||||
|
max_model_len=8192,
|
||||||
|
gpu_memory_utilization=0.7,
|
||||||
|
cudagraph_capture_sizes=[1, 2, 4, 8],
|
||||||
|
) as vllm_model:
|
||||||
|
vllm_quant_auto_detect_outputs = vllm_model.generate_greedy(
|
||||||
|
example_prompts, max_tokens)
|
||||||
|
|
||||||
|
check_outputs_equal(
|
||||||
|
outputs_0_lst=vllm_target_outputs,
|
||||||
|
outputs_1_lst=vllm_quant_auto_detect_outputs,
|
||||||
|
name_0="vllm_target_outputs",
|
||||||
|
name_1="vllm_quant_auto_detect_outputs",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
def test_qwen3_dense_w8a16():
|
def test_qwen3_dense_w8a16():
|
||||||
max_tokens = 5
|
max_tokens = 5
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
@@ -6,7 +9,10 @@ from vllm.model_executor.layers.linear import LinearBase
|
|||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
|
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
|
||||||
from vllm_ascend.quantization.modelslim_config import AscendModelSlimConfig
|
from vllm_ascend.quantization.modelslim_config import (
|
||||||
|
MODELSLIM_CONFIG_FILENAME,
|
||||||
|
AscendModelSlimConfig,
|
||||||
|
)
|
||||||
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, vllm_version_is
|
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, vllm_version_is
|
||||||
|
|
||||||
if vllm_version_is("v0.15.0"):
|
if vllm_version_is("v0.15.0"):
|
||||||
@@ -54,7 +60,7 @@ class TestAscendModelSlimConfig(TestBase):
|
|||||||
|
|
||||||
def test_get_config_filenames(self):
|
def test_get_config_filenames(self):
|
||||||
filenames = AscendModelSlimConfig.get_config_filenames()
|
filenames = AscendModelSlimConfig.get_config_filenames()
|
||||||
self.assertEqual(filenames, ["quant_model_description.json"])
|
self.assertEqual(filenames, [])
|
||||||
|
|
||||||
def test_from_config(self):
|
def test_from_config(self):
|
||||||
config = AscendModelSlimConfig.from_config(self.sample_config)
|
config = AscendModelSlimConfig.from_config(self.sample_config)
|
||||||
@@ -162,5 +168,90 @@ class TestAscendModelSlimConfig(TestBase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
config.is_layer_skipped_ascend("fused_layer", fused_mapping)
|
config.is_layer_skipped_ascend("fused_layer", fused_mapping)
|
||||||
|
|
||||||
|
def test_init_with_none_config(self):
|
||||||
|
config = AscendModelSlimConfig(None)
|
||||||
|
self.assertEqual(config.quant_description, {})
|
||||||
|
|
||||||
|
def test_init_with_default_config(self):
|
||||||
|
config = AscendModelSlimConfig()
|
||||||
|
self.assertEqual(config.quant_description, {})
|
||||||
|
|
||||||
|
def test_maybe_update_config_already_populated(self):
|
||||||
|
# When quant_description is already populated, should be a no-op
|
||||||
|
self.assertTrue(len(self.ascend_config.quant_description) > 0)
|
||||||
|
self.ascend_config.maybe_update_config("/some/model/path")
|
||||||
|
# quant_description should remain unchanged
|
||||||
|
self.assertEqual(self.ascend_config.quant_description,
|
||||||
|
self.sample_config)
|
||||||
|
|
||||||
|
def test_maybe_update_config_loads_from_file(self):
|
||||||
|
config = AscendModelSlimConfig()
|
||||||
|
self.assertEqual(config.quant_description, {})
|
||||||
|
|
||||||
|
quant_data = {"layer1.weight": "INT8", "layer2.weight": "FLOAT"}
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config_path = os.path.join(tmpdir, MODELSLIM_CONFIG_FILENAME)
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(quant_data, f)
|
||||||
|
|
||||||
|
config.maybe_update_config(tmpdir)
|
||||||
|
|
||||||
|
self.assertEqual(config.quant_description, quant_data)
|
||||||
|
|
||||||
|
def test_maybe_update_config_raises_when_file_missing(self):
|
||||||
|
config = AscendModelSlimConfig()
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
with self.assertRaises(ValueError) as ctx:
|
||||||
|
config.maybe_update_config(tmpdir)
|
||||||
|
|
||||||
|
error_msg = str(ctx.exception)
|
||||||
|
self.assertIn("ModelSlim Quantization Config Not Found", error_msg)
|
||||||
|
self.assertIn(MODELSLIM_CONFIG_FILENAME, error_msg)
|
||||||
|
|
||||||
|
def test_maybe_update_config_raises_with_json_files_listed(self):
|
||||||
|
config = AscendModelSlimConfig()
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
# Create a dummy json file that is NOT the config file
|
||||||
|
dummy_path = os.path.join(tmpdir, "config.json")
|
||||||
|
with open(dummy_path, "w") as f:
|
||||||
|
json.dump({"dummy": True}, f)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError) as ctx:
|
||||||
|
config.maybe_update_config(tmpdir)
|
||||||
|
|
||||||
|
error_msg = str(ctx.exception)
|
||||||
|
self.assertIn("config.json", error_msg)
|
||||||
|
|
||||||
|
def test_maybe_update_config_non_directory_raises(self):
|
||||||
|
config = AscendModelSlimConfig()
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError) as ctx:
|
||||||
|
config.maybe_update_config("not_a_real_directory_path")
|
||||||
|
|
||||||
|
error_msg = str(ctx.exception)
|
||||||
|
self.assertIn("ModelSlim Quantization Config Not Found", error_msg)
|
||||||
|
|
||||||
|
def test_apply_extra_quant_adaptations_shared_head(self):
|
||||||
|
config = AscendModelSlimConfig()
|
||||||
|
config.quant_description = {
|
||||||
|
"model.layers.0.shared_head.weight": "INT8",
|
||||||
|
}
|
||||||
|
config._apply_extra_quant_adaptations()
|
||||||
|
self.assertIn("model.layers.0.weight", config.quant_description)
|
||||||
|
self.assertEqual(config.quant_description["model.layers.0.weight"],
|
||||||
|
"INT8")
|
||||||
|
|
||||||
|
def test_apply_extra_quant_adaptations_weight_packed(self):
|
||||||
|
config = AscendModelSlimConfig()
|
||||||
|
config.quant_description = {
|
||||||
|
"model.layers.0.weight_packed": "INT8",
|
||||||
|
}
|
||||||
|
config._apply_extra_quant_adaptations()
|
||||||
|
self.assertIn("model.layers.0.weight", config.quant_description)
|
||||||
|
self.assertEqual(config.quant_description["model.layers.0.weight"],
|
||||||
|
"INT8")
|
||||||
|
|
||||||
def test_get_scaled_act_names(self):
|
def test_get_scaled_act_names(self):
|
||||||
self.assertEqual(self.ascend_config.get_scaled_act_names(), [])
|
self.assertEqual(self.ascend_config.get_scaled_act_names(), [])
|
||||||
|
|||||||
182
tests/ut/quantization/test_quant_utils.py
Normal file
182
tests/ut/quantization/test_quant_utils.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from tests.ut.base import TestBase
|
||||||
|
from vllm_ascend.quantization.modelslim_config import MODELSLIM_CONFIG_FILENAME
|
||||||
|
from vllm_ascend.quantization.utils import (
|
||||||
|
detect_quantization_method,
|
||||||
|
maybe_auto_detect_quantization,
|
||||||
|
)
|
||||||
|
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD
|
||||||
|
|
||||||
|
|
||||||
|
class TestDetectQuantizationMethod(TestBase):
|
||||||
|
|
||||||
|
def test_returns_none_for_non_directory(self):
|
||||||
|
result = detect_quantization_method("/non/existent/path")
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
def test_detects_modelslim(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config_path = os.path.join(tmpdir, MODELSLIM_CONFIG_FILENAME)
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump({"layer.weight": "INT8"}, f)
|
||||||
|
|
||||||
|
result = detect_quantization_method(tmpdir)
|
||||||
|
self.assertEqual(result, ASCEND_QUANTIZATION_METHOD)
|
||||||
|
|
||||||
|
def test_detects_compressed_tensors(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config_path = os.path.join(tmpdir, "config.json")
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump({
|
||||||
|
"quantization_config": {
|
||||||
|
"quant_method": "compressed-tensors"
|
||||||
|
}
|
||||||
|
}, f)
|
||||||
|
|
||||||
|
result = detect_quantization_method(tmpdir)
|
||||||
|
self.assertEqual(result, COMPRESSED_TENSORS_METHOD)
|
||||||
|
|
||||||
|
def test_returns_none_for_no_quant(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
result = detect_quantization_method(tmpdir)
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
def test_returns_none_for_non_compressed_tensors_quant_method(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config_path = os.path.join(tmpdir, "config.json")
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump({
|
||||||
|
"quantization_config": {
|
||||||
|
"quant_method": "gptq"
|
||||||
|
}
|
||||||
|
}, f)
|
||||||
|
|
||||||
|
result = detect_quantization_method(tmpdir)
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
def test_returns_none_for_config_without_quant_config(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config_path = os.path.join(tmpdir, "config.json")
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump({"model_type": "llama"}, f)
|
||||||
|
|
||||||
|
result = detect_quantization_method(tmpdir)
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
def test_returns_none_for_malformed_config_json(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config_path = os.path.join(tmpdir, "config.json")
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
f.write("not valid json{{{")
|
||||||
|
|
||||||
|
result = detect_quantization_method(tmpdir)
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
def test_modelslim_takes_priority_over_compressed_tensors(self):
|
||||||
|
"""When both ModelSlim config and compressed-tensors config exist,
|
||||||
|
ModelSlim should take priority."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
# Create ModelSlim config
|
||||||
|
modelslim_path = os.path.join(tmpdir, MODELSLIM_CONFIG_FILENAME)
|
||||||
|
with open(modelslim_path, "w") as f:
|
||||||
|
json.dump({"layer.weight": "INT8"}, f)
|
||||||
|
|
||||||
|
# Create compressed-tensors config
|
||||||
|
config_path = os.path.join(tmpdir, "config.json")
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump({
|
||||||
|
"quantization_config": {
|
||||||
|
"quant_method": "compressed-tensors"
|
||||||
|
}
|
||||||
|
}, f)
|
||||||
|
|
||||||
|
result = detect_quantization_method(tmpdir)
|
||||||
|
self.assertEqual(result, ASCEND_QUANTIZATION_METHOD)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMaybeAutoDetectQuantization(TestBase):
|
||||||
|
|
||||||
|
def _make_vllm_config(self, model_path="/fake/model", quantization=None):
|
||||||
|
vllm_config = MagicMock()
|
||||||
|
vllm_config.model_config.model = model_path
|
||||||
|
vllm_config.model_config.quantization = quantization
|
||||||
|
return vllm_config
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.utils.detect_quantization_method",
|
||||||
|
return_value=None)
|
||||||
|
def test_no_detection_does_nothing(self, mock_detect):
|
||||||
|
vllm_config = self._make_vllm_config()
|
||||||
|
maybe_auto_detect_quantization(vllm_config)
|
||||||
|
# quantization should remain unchanged
|
||||||
|
self.assertIsNone(vllm_config.model_config.quantization)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.utils.detect_quantization_method",
|
||||||
|
return_value=ASCEND_QUANTIZATION_METHOD)
|
||||||
|
def test_user_specified_same_method_no_change(self, mock_detect):
|
||||||
|
vllm_config = self._make_vllm_config(
|
||||||
|
quantization=ASCEND_QUANTIZATION_METHOD)
|
||||||
|
maybe_auto_detect_quantization(vllm_config)
|
||||||
|
self.assertEqual(vllm_config.model_config.quantization,
|
||||||
|
ASCEND_QUANTIZATION_METHOD)
|
||||||
|
|
||||||
|
@patch("vllm.config.VllmConfig._get_quantization_config",
|
||||||
|
return_value=MagicMock())
|
||||||
|
@patch("vllm_ascend.quantization.utils.detect_quantization_method",
|
||||||
|
return_value=ASCEND_QUANTIZATION_METHOD)
|
||||||
|
def test_auto_detect_sets_quantization_and_logs_info(
|
||||||
|
self, mock_detect, mock_get_quant_config):
|
||||||
|
"""When no --quantization is specified but ModelSlim config is found,
|
||||||
|
the method should auto-set quantization and emit an INFO log."""
|
||||||
|
vllm_config = self._make_vllm_config(
|
||||||
|
model_path="/fake/quant_model", quantization=None)
|
||||||
|
|
||||||
|
with self.assertLogs("vllm_ascend.quantization.utils",
|
||||||
|
level=logging.INFO) as cm:
|
||||||
|
maybe_auto_detect_quantization(vllm_config)
|
||||||
|
|
||||||
|
self.assertEqual(vllm_config.model_config.quantization,
|
||||||
|
ASCEND_QUANTIZATION_METHOD)
|
||||||
|
log_output = "\n".join(cm.output)
|
||||||
|
self.assertIn("Auto-detected quantization method", log_output)
|
||||||
|
self.assertIn(ASCEND_QUANTIZATION_METHOD, log_output)
|
||||||
|
self.assertIn("/fake/quant_model", log_output)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.utils.detect_quantization_method",
|
||||||
|
return_value=ASCEND_QUANTIZATION_METHOD)
|
||||||
|
def test_user_mismatch_logs_warning(self, mock_detect):
|
||||||
|
"""When user specifies a different method than auto-detected,
|
||||||
|
a WARNING should be emitted and user's choice should be respected."""
|
||||||
|
vllm_config = self._make_vllm_config(
|
||||||
|
model_path="/fake/quant_model",
|
||||||
|
quantization=COMPRESSED_TENSORS_METHOD)
|
||||||
|
|
||||||
|
with self.assertLogs("vllm_ascend.quantization.utils",
|
||||||
|
level=logging.WARNING) as cm:
|
||||||
|
maybe_auto_detect_quantization(vllm_config)
|
||||||
|
|
||||||
|
# User's choice is respected
|
||||||
|
self.assertEqual(vllm_config.model_config.quantization,
|
||||||
|
COMPRESSED_TENSORS_METHOD)
|
||||||
|
log_output = "\n".join(cm.output)
|
||||||
|
self.assertIn("Auto-detected quantization method", log_output)
|
||||||
|
self.assertIn(ASCEND_QUANTIZATION_METHOD, log_output)
|
||||||
|
self.assertIn(COMPRESSED_TENSORS_METHOD, log_output)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.utils.detect_quantization_method",
|
||||||
|
return_value=None)
|
||||||
|
def test_no_detection_emits_no_log(self, mock_detect):
|
||||||
|
"""When no quantization is detected, no log should be emitted."""
|
||||||
|
vllm_config = self._make_vllm_config(quantization=None)
|
||||||
|
logger_name = "vllm_ascend.quantization.utils"
|
||||||
|
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
# assertLogs raises AssertionError when no logs are emitted
|
||||||
|
with self.assertLogs(logger_name, level=logging.DEBUG):
|
||||||
|
maybe_auto_detect_quantization(vllm_config)
|
||||||
|
|
||||||
|
self.assertIsNone(vllm_config.model_config.quantization)
|
||||||
@@ -116,13 +116,14 @@ class TestNPUPlatform(TestBase):
|
|||||||
self.assertIsNone(self.platform.inference_mode())
|
self.assertIsNone(self.platform.inference_mode())
|
||||||
mock_inference_mode.assert_called_once()
|
mock_inference_mode.assert_called_once()
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
@patch("vllm_ascend.utils.update_aclgraph_sizes")
|
@patch("vllm_ascend.utils.update_aclgraph_sizes")
|
||||||
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
||||||
@patch("os.environ", {})
|
@patch("os.environ", {})
|
||||||
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
||||||
def test_check_and_update_config_basic_config_update(
|
def test_check_and_update_config_basic_config_update(
|
||||||
self, mock_init_recompute, mock_soc_version, mock_update_acl, mock_init_ascend
|
self, mock_init_recompute, mock_soc_version, mock_update_acl, mock_init_ascend, mock_auto_detect
|
||||||
):
|
):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
@@ -146,11 +147,12 @@ class TestNPUPlatform(TestBase):
|
|||||||
|
|
||||||
mock_init_ascend.assert_called_once_with(vllm_config)
|
mock_init_ascend.assert_called_once_with(vllm_config)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
||||||
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
||||||
def test_check_and_update_config_no_model_config_warning(
|
def test_check_and_update_config_no_model_config_warning(
|
||||||
self, mock_init_recompute, mock_init_ascend, mock_soc_version
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
||||||
):
|
):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
@@ -172,10 +174,11 @@ class TestNPUPlatform(TestBase):
|
|||||||
|
|
||||||
self.assertTrue("Model config is missing" in cm.output[0])
|
self.assertTrue("Model config is missing" in cm.output[0])
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
||||||
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
||||||
def test_check_and_update_config_enforce_eager_mode(self, mock_init_recompute, mock_init_ascend, mock_soc_version):
|
def test_check_and_update_config_enforce_eager_mode(self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
vllm_config.model_config.enforce_eager = True
|
vllm_config.model_config.enforce_eager = True
|
||||||
@@ -206,11 +209,12 @@ class TestNPUPlatform(TestBase):
|
|||||||
CUDAGraphMode.NONE,
|
CUDAGraphMode.NONE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
||||||
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
||||||
def test_check_and_update_config_unsupported_compilation_level(
|
def test_check_and_update_config_unsupported_compilation_level(
|
||||||
self, mock_init_recompute, mock_init_ascend, mock_soc_version
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
||||||
):
|
):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
@@ -244,9 +248,10 @@ class TestNPUPlatform(TestBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skip("Revert me when vllm support setting cudagraph_mode on oot platform")
|
@pytest.mark.skip("Revert me when vllm support setting cudagraph_mode on oot platform")
|
||||||
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
||||||
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
def test_check_and_update_config_unsupported_cudagraph_mode(self, mock_init_ascend, mock_soc_version):
|
def test_check_and_update_config_unsupported_cudagraph_mode(self, mock_init_ascend, mock_soc_version, mock_auto_detect):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
vllm_config.model_config.enforce_eager = False
|
vllm_config.model_config.enforce_eager = False
|
||||||
@@ -268,11 +273,12 @@ class TestNPUPlatform(TestBase):
|
|||||||
CUDAGraphMode.NONE,
|
CUDAGraphMode.NONE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
||||||
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
||||||
def test_check_and_update_config_cache_config_block_size(
|
def test_check_and_update_config_cache_config_block_size(
|
||||||
self, mock_init_recompute, mock_init_ascend, mock_soc_version
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
||||||
):
|
):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
@@ -292,11 +298,12 @@ class TestNPUPlatform(TestBase):
|
|||||||
|
|
||||||
self.assertEqual(vllm_config.cache_config.block_size, 128)
|
self.assertEqual(vllm_config.cache_config.block_size, 128)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
||||||
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
||||||
def test_check_and_update_config_v1_worker_class_selection(
|
def test_check_and_update_config_v1_worker_class_selection(
|
||||||
self, mock_init_recompute, mock_init_ascend, mock_soc_version
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
||||||
):
|
):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
@@ -327,10 +334,11 @@ class TestNPUPlatform(TestBase):
|
|||||||
"vllm_ascend.xlite.xlite_worker.XliteWorker",
|
"vllm_ascend.xlite.xlite_worker.XliteWorker",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
||||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||||
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType._310P)
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType._310P)
|
||||||
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
||||||
def test_check_and_update_config_310p_no_custom_ops(self, mock_init_recompute, mock_soc_version, mock_init_ascend):
|
def test_check_and_update_config_310p_no_custom_ops(self, mock_init_recompute, mock_soc_version, mock_init_ascend, mock_auto_detect):
|
||||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
||||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||||
vllm_config.compilation_config.custom_ops = []
|
vllm_config.compilation_config.custom_ops = []
|
||||||
|
|||||||
@@ -171,6 +171,11 @@ class NPUPlatform(Platform):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||||
|
from vllm_ascend.quantization.utils import maybe_auto_detect_quantization
|
||||||
|
|
||||||
|
if vllm_config.model_config is not None:
|
||||||
|
maybe_auto_detect_quantization(vllm_config)
|
||||||
|
|
||||||
# initialize ascend config from vllm additional_config
|
# initialize ascend config from vllm additional_config
|
||||||
cls._fix_incompatible_config(vllm_config)
|
cls._fix_incompatible_config(vllm_config)
|
||||||
ascend_config = init_ascend_config(vllm_config)
|
ascend_config = init_ascend_config(vllm_config)
|
||||||
|
|||||||
@@ -21,6 +21,9 @@ This module provides the AscendModelSlimConfig class for parsing quantization
|
|||||||
configs generated by the ModelSlim tool, along with model-specific mappings.
|
configs generated by the ModelSlim tool, along with model-specific mappings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import os
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
@@ -39,6 +42,9 @@ from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
|
|||||||
|
|
||||||
from .methods import get_scheme_class
|
from .methods import get_scheme_class
|
||||||
|
|
||||||
|
# The config filename that ModelSlim generates after quantizing a model.
|
||||||
|
MODELSLIM_CONFIG_FILENAME = "quant_model_description.json"
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
# key: model_type
|
# key: model_type
|
||||||
@@ -310,9 +316,9 @@ class AscendModelSlimConfig(QuantizationConfig):
|
|||||||
quantized using the ModelSlim tool.
|
quantized using the ModelSlim tool.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: dict[str, Any]):
|
def __init__(self, quant_config: dict[str, Any] | None = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.quant_description = quant_config
|
self.quant_description = quant_config if quant_config is not None else {}
|
||||||
# TODO(whx): remove this adaptation after adding "shared_head"
|
# TODO(whx): remove this adaptation after adding "shared_head"
|
||||||
# to prefix of DeepSeekShareHead in vLLM.
|
# to prefix of DeepSeekShareHead in vLLM.
|
||||||
extra_quant_dict = {}
|
extra_quant_dict = {}
|
||||||
@@ -342,7 +348,12 @@ class AscendModelSlimConfig(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_filenames(cls) -> list[str]:
|
def get_config_filenames(cls) -> list[str]:
|
||||||
return ["quant_model_description.json"]
|
# Return empty list so that vllm's get_quant_config() skips the
|
||||||
|
# file-based lookup (which raises an unfriendly "Cannot find the
|
||||||
|
# config file for ascend" error when the model is not quantized).
|
||||||
|
# Instead, the config file is loaded in maybe_update_config(),
|
||||||
|
# which can provide a user-friendly error message.
|
||||||
|
return []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: dict[str, Any]) -> "AscendModelSlimConfig":
|
def from_config(cls, config: dict[str, Any]) -> "AscendModelSlimConfig":
|
||||||
@@ -456,5 +467,98 @@ class AscendModelSlimConfig(QuantizationConfig):
|
|||||||
assert is_skipped is not None
|
assert is_skipped is not None
|
||||||
return is_skipped
|
return is_skipped
|
||||||
|
|
||||||
|
def maybe_update_config(self, model_name: str) -> None:
|
||||||
|
"""Load the ModelSlim quantization config from model directory.
|
||||||
|
|
||||||
|
This method is called by vllm after get_quant_config() returns
|
||||||
|
successfully. Since we return an empty list from get_config_filenames()
|
||||||
|
to bypass vllm's built-in file lookup, we do the actual config loading
|
||||||
|
here and provide user-friendly error messages when the config is missing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Path to the model directory or model name.
|
||||||
|
"""
|
||||||
|
# If quant_description is already populated (e.g. from from_config()),
|
||||||
|
# there is nothing to do.
|
||||||
|
if self.quant_description:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Try to find and load the ModelSlim config file
|
||||||
|
if os.path.isdir(model_name):
|
||||||
|
config_path = os.path.join(model_name, MODELSLIM_CONFIG_FILENAME)
|
||||||
|
if os.path.isfile(config_path):
|
||||||
|
with open(config_path) as f:
|
||||||
|
self.quant_description = json.load(f)
|
||||||
|
self._apply_extra_quant_adaptations()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if there are any json files at all to help diagnose
|
||||||
|
json_files = glob.glob(os.path.join(model_name, "*.json"))
|
||||||
|
json_names = [os.path.basename(f) for f in json_files]
|
||||||
|
else:
|
||||||
|
json_names = []
|
||||||
|
|
||||||
|
# Config file not found - raise a friendly error message
|
||||||
|
raise ValueError(
|
||||||
|
"\n"
|
||||||
|
+ "=" * 80
|
||||||
|
+ "\n"
|
||||||
|
+ "ERROR: ModelSlim Quantization Config Not Found\n"
|
||||||
|
+ "=" * 80
|
||||||
|
+ "\n"
|
||||||
|
+ "\n"
|
||||||
|
+ f"You have enabled '--quantization {ASCEND_QUANTIZATION_METHOD}' "
|
||||||
|
+ "(ModelSlim quantization),\n"
|
||||||
|
+ f"but the model at '{model_name}' does not contain the required\n"
|
||||||
|
+ f"quantization config file ('{MODELSLIM_CONFIG_FILENAME}').\n"
|
||||||
|
+ "\n"
|
||||||
|
+ "This usually means the model weights are NOT quantized by "
|
||||||
|
+ "ModelSlim.\n"
|
||||||
|
+ "\n"
|
||||||
|
+ "Please choose one of the following solutions:\n"
|
||||||
|
+ "\n"
|
||||||
|
+ " Solution 1: Remove the quantization option "
|
||||||
|
+ "(for float/unquantized models)\n"
|
||||||
|
+ " "
|
||||||
|
+ "-" * 58
|
||||||
|
+ "\n"
|
||||||
|
+ f" Remove '--quantization {ASCEND_QUANTIZATION_METHOD}' from "
|
||||||
|
+ "your command if you want to\n"
|
||||||
|
+ " run the model with the original (float) weights.\n"
|
||||||
|
+ "\n"
|
||||||
|
+ " Example:\n"
|
||||||
|
+ f" vllm serve {model_name}\n"
|
||||||
|
+ "\n"
|
||||||
|
+ " Solution 2: Quantize your model weights with ModelSlim first\n"
|
||||||
|
+ " "
|
||||||
|
+ "-" * 58
|
||||||
|
+ "\n"
|
||||||
|
+ " Use the ModelSlim tool to quantize your model weights "
|
||||||
|
+ "before deployment.\n"
|
||||||
|
+ " After quantization, the model directory should contain "
|
||||||
|
+ f"'{MODELSLIM_CONFIG_FILENAME}'.\n"
|
||||||
|
+ " For more information, please refer to:\n"
|
||||||
|
+ " https://gitee.com/ascend/msit/tree/master/msmodelslim\n"
|
||||||
|
+ "\n"
|
||||||
|
+ (f" (Found JSON files in model directory: {json_names})\n" if json_names else "")
|
||||||
|
+ "=" * 80
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_extra_quant_adaptations(self) -> None:
|
||||||
|
"""Apply extra adaptations to the quant_description dict.
|
||||||
|
|
||||||
|
This handles known key transformations such as shared_head and
|
||||||
|
weight_packed mappings.
|
||||||
|
"""
|
||||||
|
extra_quant_dict = {}
|
||||||
|
for k in self.quant_description:
|
||||||
|
if "shared_head" in k:
|
||||||
|
new_k = k.replace(".shared_head.", ".")
|
||||||
|
extra_quant_dict[new_k] = self.quant_description[k]
|
||||||
|
if "weight_packed" in k:
|
||||||
|
new_k = k.replace("weight_packed", "weight")
|
||||||
|
extra_quant_dict[new_k] = self.quant_description[k]
|
||||||
|
self.quant_description.update(extra_quant_dict)
|
||||||
|
|
||||||
def get_scaled_act_names(self) -> list[str]:
|
def get_scaled_act_names(self) -> list[str]:
|
||||||
return []
|
return []
|
||||||
|
|||||||
147
vllm_ascend/quantization/utils.py
Normal file
147
vllm_ascend/quantization/utils.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
from vllm_ascend.quantization.modelslim_config import MODELSLIM_CONFIG_FILENAME
|
||||||
|
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_quantization_method(model_path: str) -> str | None:
|
||||||
|
"""Auto-detect the quantization method from model directory files.
|
||||||
|
|
||||||
|
This function performs a lightweight check (JSON files and file existence
|
||||||
|
only — no .safetensors or .bin inspection) to determine which quantization
|
||||||
|
method was used to produce the weights in *model_path*.
|
||||||
|
|
||||||
|
Detection priority:
|
||||||
|
1. **ModelSlim (Ascend)** – ``quant_model_description.json`` exists
|
||||||
|
in the model directory.
|
||||||
|
2. **LLM-Compressor (compressed-tensors)** – ``config.json`` contains
|
||||||
|
a ``quantization_config`` section with
|
||||||
|
``"quant_method": "compressed-tensors"``.
|
||||||
|
3. **None** – neither condition is met; the caller should fall back to
|
||||||
|
the default (float) behaviour.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to the local model directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``"ascend"`` for ModelSlim models,
|
||||||
|
``"compressed-tensors"`` for LLM-Compressor models,
|
||||||
|
or ``None`` if no quantization signature is found.
|
||||||
|
"""
|
||||||
|
if not os.path.isdir(model_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Case 1: ModelSlim — look for quant_model_description.json
|
||||||
|
modelslim_config_path = os.path.join(model_path, MODELSLIM_CONFIG_FILENAME)
|
||||||
|
if os.path.isfile(modelslim_config_path):
|
||||||
|
return ASCEND_QUANTIZATION_METHOD
|
||||||
|
|
||||||
|
# Case 2: LLM-Compressor — look for compressed-tensors in config.json
|
||||||
|
config_json_path = os.path.join(model_path, "config.json")
|
||||||
|
if os.path.isfile(config_json_path):
|
||||||
|
try:
|
||||||
|
with open(config_json_path) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
quant_cfg = config.get("quantization_config")
|
||||||
|
if isinstance(quant_cfg, dict):
|
||||||
|
quant_method = quant_cfg.get("quant_method", "")
|
||||||
|
if quant_method == COMPRESSED_TENSORS_METHOD:
|
||||||
|
return COMPRESSED_TENSORS_METHOD
|
||||||
|
except (json.JSONDecodeError, OSError):
|
||||||
|
# Malformed or unreadable config.json — skip silently.
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Case 3: No quantization signature found.
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_auto_detect_quantization(vllm_config) -> None:
|
||||||
|
"""Auto-detect and apply the quantization method on *vllm_config*.
|
||||||
|
|
||||||
|
This should be called during engine initialisation (from
|
||||||
|
``NPUPlatform.check_and_update_config``) **after** ``VllmConfig`` has been
|
||||||
|
created but **before** heavy weights are loaded.
|
||||||
|
|
||||||
|
Because ``check_and_update_config`` runs *after*
|
||||||
|
``VllmConfig.__post_init__`` has already evaluated
|
||||||
|
``_get_quantization_config`` (which returned ``None`` when
|
||||||
|
``model_config.quantization`` was not set), we must:
|
||||||
|
|
||||||
|
1. Set ``model_config.quantization`` to the detected value.
|
||||||
|
2. Recreate ``vllm_config.quant_config`` so that the quantization
|
||||||
|
pipeline (``get_quant_config`` → ``QuantizationConfig`` →
|
||||||
|
``get_quant_method`` for every layer) is properly initialised.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
* If the user explicitly set ``--quantization``, that value is
|
||||||
|
respected. A warning is emitted when the detected method differs.
|
||||||
|
* If no ``--quantization`` was given, the detected method (if any) is
|
||||||
|
applied automatically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vllm_config: A ``vllm.config.VllmConfig`` instance (mutable).
|
||||||
|
"""
|
||||||
|
model_config = vllm_config.model_config
|
||||||
|
model_path = model_config.model
|
||||||
|
user_quant = model_config.quantization
|
||||||
|
detected = detect_quantization_method(model_path)
|
||||||
|
|
||||||
|
if detected is None:
|
||||||
|
# No quantization signature found — nothing to do.
|
||||||
|
return
|
||||||
|
|
||||||
|
if user_quant is not None:
|
||||||
|
# User explicitly specified a quantization method.
|
||||||
|
if user_quant != detected:
|
||||||
|
logger.warning(
|
||||||
|
"Auto-detected quantization method '%s' from model "
|
||||||
|
"files at '%s', but user explicitly specified "
|
||||||
|
"'--quantization %s'. Respecting the user-specified "
|
||||||
|
"value. If you encounter errors during model loading, "
|
||||||
|
"consider using '--quantization %s' instead.",
|
||||||
|
detected,
|
||||||
|
model_path,
|
||||||
|
user_quant,
|
||||||
|
detected,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# No user-specified quantization — apply auto-detected value.
|
||||||
|
model_config.quantization = detected
|
||||||
|
logger.info(
|
||||||
|
"Auto-detected quantization method '%s' from model files "
|
||||||
|
"at '%s'. To override, pass '--quantization <method>' explicitly.",
|
||||||
|
detected,
|
||||||
|
model_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Recreate quant_config on VllmConfig. The original __post_init__
|
||||||
|
# already ran _get_quantization_config(), but at that point
|
||||||
|
# model_config.quantization was None so it returned None. Now that
|
||||||
|
# we've set it, we need to build the actual QuantizationConfig so the
|
||||||
|
# downstream model-loading code can use it.
|
||||||
|
from vllm.config import VllmConfig as _VllmConfig
|
||||||
|
|
||||||
|
vllm_config.quant_config = _VllmConfig._get_quantization_config(model_config, vllm_config.load_config)
|
||||||
Reference in New Issue
Block a user