Files
xc-llm-ascend/tests/ut/quantization/test_quant_utils.py
Li Wang 83a4065b4b [CI] Add pre-commit check for patch logger (#7446)
### What this PR does / why we need it?
See https://github.com/vllm-project/vllm-ascend/pull/7402; the pre-commit
hook will forbid init_logger(__name__) in vllm_ascend patch modules.

- vLLM version: v0.17.0
- vLLM main:
8a680463fa

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
2026-03-19 16:53:20 +08:00

191 lines
8.2 KiB
Python

import json
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
from tests.ut.base import TestBase
from vllm_ascend.quantization.modelslim_config import MODELSLIM_CONFIG_FILENAME
from vllm_ascend.quantization.utils import (
detect_quantization_method,
maybe_auto_detect_quantization,
)
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD
class TestDetectQuantizationMethod(TestBase):
    """Exercise ``detect_quantization_method`` on scratch model directories.

    Each scenario fills a temporary directory with the config files a
    checkpoint might contain and checks which quantization method, if any,
    is reported for that directory.
    """

    @staticmethod
    def _write_json(directory, filename, payload):
        """Serialize ``payload`` as JSON into ``directory/filename``."""
        Path(directory, filename).write_text(json.dumps(payload))

    @staticmethod
    def _quant_config(quant_method):
        """Build a ``config.json`` payload declaring ``quant_method``."""
        return {"quantization_config": {"quant_method": quant_method}}

    def test_returns_none_for_non_existent_path(self):
        """A path that does not exist produces no detection result."""
        self.assertIsNone(detect_quantization_method("/non/existent/path"))

    def test_detects_modelslim(self):
        """A ModelSlim config file alone yields the Ascend method."""
        with tempfile.TemporaryDirectory() as model_dir:
            self._write_json(model_dir, MODELSLIM_CONFIG_FILENAME,
                             {"layer.weight": "INT8"})
            detected = detect_quantization_method(model_dir)
        self.assertEqual(detected, ASCEND_QUANTIZATION_METHOD)

    def test_detects_compressed_tensors(self):
        """A ``config.json`` naming compressed-tensors is recognised."""
        with tempfile.TemporaryDirectory() as model_dir:
            self._write_json(model_dir, "config.json",
                             self._quant_config("compressed-tensors"))
            detected = detect_quantization_method(model_dir)
        self.assertEqual(detected, COMPRESSED_TENSORS_METHOD)

    def test_returns_none_for_no_quant(self):
        """An empty directory carries no quantization information."""
        with tempfile.TemporaryDirectory() as model_dir:
            detected = detect_quantization_method(model_dir)
        self.assertIsNone(detected)

    def test_returns_none_for_non_compressed_tensors_quant_method(self):
        """A ``quant_method`` of ``gptq`` is not reported as a detection."""
        with tempfile.TemporaryDirectory() as model_dir:
            self._write_json(model_dir, "config.json",
                             self._quant_config("gptq"))
            detected = detect_quantization_method(model_dir)
        self.assertIsNone(detected)

    def test_returns_none_for_config_without_quant_config(self):
        """A ``config.json`` lacking ``quantization_config`` gives None."""
        with tempfile.TemporaryDirectory() as model_dir:
            self._write_json(model_dir, "config.json",
                             {"model_type": "llama"})
            detected = detect_quantization_method(model_dir)
        self.assertIsNone(detected)

    def test_returns_none_for_malformed_config_json(self):
        """An unparsable ``config.json`` results in None."""
        with tempfile.TemporaryDirectory() as model_dir:
            Path(model_dir, "config.json").write_text("not valid json{{{")
            detected = detect_quantization_method(model_dir)
        self.assertIsNone(detected)

    def test_modelslim_takes_priority_over_compressed_tensors(self):
        """With both a ModelSlim config and a compressed-tensors
        ``config.json`` present, the ModelSlim result is expected."""
        with tempfile.TemporaryDirectory() as model_dir:
            self._write_json(model_dir, MODELSLIM_CONFIG_FILENAME,
                             {"layer.weight": "INT8"})
            self._write_json(model_dir, "config.json",
                             self._quant_config("compressed-tensors"))
            detected = detect_quantization_method(model_dir)
        self.assertEqual(detected, ASCEND_QUANTIZATION_METHOD)
class TestMaybeAutoDetectQuantization(TestBase):
    """Exercise ``maybe_auto_detect_quantization`` with detection mocked."""

    # Dotted patch targets shared by the decorators and context managers.
    _DETECT = "vllm_ascend.quantization.utils.detect_quantization_method"
    _LOGGER = "vllm_ascend.quantization.utils.logger"
    _GET_QUANT_CONFIG = "vllm.config.VllmConfig._get_quantization_config"

    def _make_vllm_config(self, model_path="/fake/model",
                          quantization=None, revision=None):
        """Return a MagicMock config whose ``model_config`` holds the
        given model path, quantization setting and revision."""
        fake_config = MagicMock()
        model_config = fake_config.model_config
        model_config.model = model_path
        model_config.quantization = quantization
        model_config.revision = revision
        return fake_config

    @patch(_DETECT, return_value=None)
    def test_no_detection_does_nothing(self, mock_detect):
        """With nothing detected, ``quantization`` remains None."""
        config = self._make_vllm_config()
        maybe_auto_detect_quantization(config)
        self.assertIsNone(config.model_config.quantization)

    @patch(_DETECT, return_value=ASCEND_QUANTIZATION_METHOD)
    def test_user_specified_same_method_no_change(self, mock_detect):
        """A user setting equal to the detected method is kept."""
        config = self._make_vllm_config(
            quantization=ASCEND_QUANTIZATION_METHOD)
        maybe_auto_detect_quantization(config)
        self.assertEqual(config.model_config.quantization,
                         ASCEND_QUANTIZATION_METHOD)

    @patch(_GET_QUANT_CONFIG, return_value=MagicMock())
    @patch(_DETECT, return_value=ASCEND_QUANTIZATION_METHOD)
    def test_auto_detect_sets_quantization_and_logs_info(
            self, mock_detect, mock_get_quant_config):
        """Without a user-supplied quantization, the detected method is
        stored on the config and exactly one INFO record is logged."""
        config = self._make_vllm_config(
            model_path="/fake/quant_model", quantization=None)
        with patch(self._LOGGER) as mock_logger:
            maybe_auto_detect_quantization(config)

        self.assertEqual(config.model_config.quantization,
                         ASCEND_QUANTIZATION_METHOD)
        mock_logger.info.assert_called_once()
        # Positional args: format string first, then its lazy %-arguments.
        logged = mock_logger.info.call_args.args
        self.assertIn("Auto-detected quantization method", logged[0])
        self.assertIn(ASCEND_QUANTIZATION_METHOD, logged)
        self.assertIn("/fake/quant_model", logged)

    @patch(_DETECT, return_value=ASCEND_QUANTIZATION_METHOD)
    def test_user_mismatch_logs_warning(self, mock_detect):
        """If the user's method differs from the detected one, the user's
        value stays in place and exactly one WARNING is logged."""
        config = self._make_vllm_config(
            model_path="/fake/quant_model",
            quantization=COMPRESSED_TENSORS_METHOD)
        with patch(self._LOGGER) as mock_logger:
            maybe_auto_detect_quantization(config)

        self.assertEqual(config.model_config.quantization,
                         COMPRESSED_TENSORS_METHOD)
        mock_logger.warning.assert_called_once()
        logged = mock_logger.warning.call_args.args
        self.assertIn("Auto-detected quantization method", logged[0])
        self.assertIn(ASCEND_QUANTIZATION_METHOD, logged)
        self.assertIn(COMPRESSED_TENSORS_METHOD, logged)

    @patch(_DETECT, return_value=None)
    def test_no_detection_emits_no_log(self, mock_detect):
        """With nothing detected, neither INFO nor WARNING is logged."""
        config = self._make_vllm_config(quantization=None)
        with patch(self._LOGGER) as mock_logger:
            maybe_auto_detect_quantization(config)

        mock_logger.info.assert_not_called()
        mock_logger.warning.assert_not_called()
        self.assertIsNone(config.model_config.quantization)

    @patch(_GET_QUANT_CONFIG, return_value=MagicMock())
    @patch(_DETECT, return_value=ASCEND_QUANTIZATION_METHOD)
    def test_passes_revision_to_detect(self, mock_detect, mock_get_quant):
        """The model revision is forwarded to detect_quantization_method."""
        config = self._make_vllm_config(
            model_path="org/model-name", revision="v1.0", quantization=None)
        maybe_auto_detect_quantization(config)
        mock_detect.assert_called_once_with("org/model-name",
                                            revision="v1.0")