### What this PR does / why we need it?
Reapply the auto-detect quantization format feature (originally in
#6645, reverted in #6873) and extend it to support remote model
identifiers (e.g., `org/model-name`).
Changes:
- Reapply auto-detection of the quantization method from model files
(`quant_model_description.json` for ModelSlim, `config.json` for
compressed-tensors); a simplified sketch of this detection logic follows
the list
- Add `get_model_file()` utility to handle file retrieval from both
local paths and remote repos (HuggingFace Hub / ModelScope)
- Update `detect_quantization_method()` to accept remote repo IDs with
optional `revision` parameter
- Update `maybe_update_config()` to work with remote model identifiers
- Add platform-level `auto_detect_quantization` support
- Add unit tests and e2e tests for both local and remote model ID
scenarios
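
For reference, a simplified sketch of the detection behavior (not the actual implementation): `sketch_detect_quantization_method` is a hypothetical stand-in, the returned strings stand in for the `ASCEND_QUANTIZATION_METHOD` / `COMPRESSED_TENSORS_METHOD` constants, and remote-repo handling via `get_model_file()` is omitted here.

```python
import json
import os
from typing import Optional


def sketch_detect_quantization_method(model: str,
                                      revision: Optional[str] = None) -> Optional[str]:
    """Simplified, local-path-only sketch of the detection behavior."""
    # The real helper resolves files through get_model_file(), which also
    # covers remote repo IDs (HuggingFace Hub / ModelScope); `revision` is
    # only meaningful in that remote case and is unused in this sketch.
    modelslim_file = os.path.join(model, "quant_model_description.json")
    if os.path.isfile(modelslim_file):
        # A ModelSlim description file takes priority over config.json.
        return "ascend"  # stands in for ASCEND_QUANTIZATION_METHOD

    config_file = os.path.join(model, "config.json")
    if os.path.isfile(config_file):
        try:
            with open(config_file) as f:
                config = json.load(f)
        except json.JSONDecodeError:
            return None
        quant_cfg = config.get("quantization_config") or {}
        if quant_cfg.get("quant_method") == "compressed-tensors":
            return "compressed-tensors"  # stands in for COMPRESSED_TENSORS_METHOD
    # Unknown or unquantized models fall through to vLLM's default handling.
    return None
```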
Closes #6836
### Does this PR introduce _any_ user-facing change?
Yes. When `--quantization` is not explicitly specified, vllm-ascend will
now automatically detect the quantization format from the model files
for both local directories and remote model IDs.
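
For illustration, a hedged usage sketch (the model paths and revision value are placeholders; `quantization="ascend"` is the value that previously had to be passed explicitly for ModelSlim-quantized weights):

```python
from vllm import LLM

# Previously required for ModelSlim-quantized weights:
#   LLM(model="/path/to/modelslim-quantized-model", quantization="ascend")
# With this PR the format is read from the model files, so the argument
# can be omitted for a local directory ...
llm = LLM(model="/path/to/modelslim-quantized-model")
# ... and equally for a remote model ID:
# llm = LLM(model="org/model-name", revision="v1.0")
```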
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
import json
import logging
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

from tests.ut.base import TestBase
from vllm_ascend.quantization.modelslim_config import MODELSLIM_CONFIG_FILENAME
from vllm_ascend.quantization.utils import (
    detect_quantization_method,
    maybe_auto_detect_quantization,
)
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD


class TestDetectQuantizationMethod(TestBase):

    def test_returns_none_for_non_existent_path(self):
        result = detect_quantization_method("/non/existent/path")
        self.assertIsNone(result)

    def test_detects_modelslim(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = os.path.join(tmpdir, MODELSLIM_CONFIG_FILENAME)
            with open(config_path, "w") as f:
                json.dump({"layer.weight": "INT8"}, f)

            result = detect_quantization_method(tmpdir)
            self.assertEqual(result, ASCEND_QUANTIZATION_METHOD)

    def test_detects_compressed_tensors(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = os.path.join(tmpdir, "config.json")
            with open(config_path, "w") as f:
                json.dump({
                    "quantization_config": {
                        "quant_method": "compressed-tensors"
                    }
                }, f)

            result = detect_quantization_method(tmpdir)
            self.assertEqual(result, COMPRESSED_TENSORS_METHOD)

    def test_returns_none_for_no_quant(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            result = detect_quantization_method(tmpdir)
            self.assertIsNone(result)

    def test_returns_none_for_non_compressed_tensors_quant_method(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = os.path.join(tmpdir, "config.json")
            with open(config_path, "w") as f:
                json.dump({
                    "quantization_config": {
                        "quant_method": "gptq"
                    }
                }, f)

            result = detect_quantization_method(tmpdir)
            self.assertIsNone(result)

    def test_returns_none_for_config_without_quant_config(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = os.path.join(tmpdir, "config.json")
            with open(config_path, "w") as f:
                json.dump({"model_type": "llama"}, f)

            result = detect_quantization_method(tmpdir)
            self.assertIsNone(result)

    def test_returns_none_for_malformed_config_json(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = os.path.join(tmpdir, "config.json")
            with open(config_path, "w") as f:
                f.write("not valid json{{{")

            result = detect_quantization_method(tmpdir)
            self.assertIsNone(result)

    def test_modelslim_takes_priority_over_compressed_tensors(self):
        """When both ModelSlim config and compressed-tensors config exist,
        ModelSlim should take priority."""
        with tempfile.TemporaryDirectory() as tmpdir:
            modelslim_path = os.path.join(tmpdir, MODELSLIM_CONFIG_FILENAME)
            with open(modelslim_path, "w") as f:
                json.dump({"layer.weight": "INT8"}, f)

            config_path = os.path.join(tmpdir, "config.json")
            with open(config_path, "w") as f:
                json.dump({
                    "quantization_config": {
                        "quant_method": "compressed-tensors"
                    }
                }, f)

            result = detect_quantization_method(tmpdir)
            self.assertEqual(result, ASCEND_QUANTIZATION_METHOD)


class TestMaybeAutoDetectQuantization(TestBase):

    def _make_vllm_config(self, model_path="/fake/model",
                          quantization=None, revision=None):
        vllm_config = MagicMock()
        vllm_config.model_config.model = model_path
        vllm_config.model_config.quantization = quantization
        vllm_config.model_config.revision = revision
        return vllm_config

    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
           return_value=None)
    def test_no_detection_does_nothing(self, mock_detect):
        vllm_config = self._make_vllm_config()
        maybe_auto_detect_quantization(vllm_config)
        self.assertIsNone(vllm_config.model_config.quantization)

    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
           return_value=ASCEND_QUANTIZATION_METHOD)
    def test_user_specified_same_method_no_change(self, mock_detect):
        vllm_config = self._make_vllm_config(
            quantization=ASCEND_QUANTIZATION_METHOD)
        maybe_auto_detect_quantization(vllm_config)
        self.assertEqual(vllm_config.model_config.quantization,
                         ASCEND_QUANTIZATION_METHOD)

    @patch("vllm.config.VllmConfig._get_quantization_config",
           return_value=MagicMock())
    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
           return_value=ASCEND_QUANTIZATION_METHOD)
    def test_auto_detect_sets_quantization_and_logs_info(
            self, mock_detect, mock_get_quant_config):
        """When no --quantization is specified but ModelSlim config is found,
        the method should auto-set quantization and emit an INFO log."""
        vllm_config = self._make_vllm_config(
            model_path="/fake/quant_model", quantization=None)

        with self.assertLogs("vllm_ascend.quantization.utils",
                             level=logging.INFO) as cm:
            maybe_auto_detect_quantization(vllm_config)

        self.assertEqual(vllm_config.model_config.quantization,
                         ASCEND_QUANTIZATION_METHOD)
        log_output = "\n".join(cm.output)
        self.assertIn("Auto-detected quantization method", log_output)
        self.assertIn(ASCEND_QUANTIZATION_METHOD, log_output)
        self.assertIn("/fake/quant_model", log_output)

    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
           return_value=ASCEND_QUANTIZATION_METHOD)
    def test_user_mismatch_logs_warning(self, mock_detect):
        """When user specifies a different method than auto-detected,
        a WARNING should be emitted and user's choice should be respected."""
        vllm_config = self._make_vllm_config(
            model_path="/fake/quant_model",
            quantization=COMPRESSED_TENSORS_METHOD)

        with self.assertLogs("vllm_ascend.quantization.utils",
                             level=logging.WARNING) as cm:
            maybe_auto_detect_quantization(vllm_config)

        self.assertEqual(vllm_config.model_config.quantization,
                         COMPRESSED_TENSORS_METHOD)
        log_output = "\n".join(cm.output)
        self.assertIn("Auto-detected quantization method", log_output)
        self.assertIn(ASCEND_QUANTIZATION_METHOD, log_output)
        self.assertIn(COMPRESSED_TENSORS_METHOD, log_output)

    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
           return_value=None)
    def test_no_detection_emits_no_log(self, mock_detect):
        """When no quantization is detected, no log should be emitted."""
        vllm_config = self._make_vllm_config(quantization=None)
        logger_name = "vllm_ascend.quantization.utils"

        with self.assertRaises(AssertionError):
            # assertLogs raises AssertionError when no logs are emitted
            with self.assertLogs(logger_name, level=logging.DEBUG):
                maybe_auto_detect_quantization(vllm_config)

        self.assertIsNone(vllm_config.model_config.quantization)

    @patch("vllm.config.VllmConfig._get_quantization_config",
           return_value=MagicMock())
    @patch("vllm_ascend.quantization.utils.detect_quantization_method",
           return_value=ASCEND_QUANTIZATION_METHOD)
    def test_passes_revision_to_detect(self, mock_detect, mock_get_quant):
        """Verify that model revision is forwarded to detect_quantization_method."""
        vllm_config = self._make_vllm_config(
            model_path="org/model-name", revision="v1.0", quantization=None)
        maybe_auto_detect_quantization(vllm_config)
        mock_detect.assert_called_once_with("org/model-name", revision="v1.0")