[1/N][Refactor][Quantization] remove redundant quantizer class (#2680)
### What this PR does / why we need it?
AscendQuantizer/LLMQuantizer class is used to select quant method based
on quant config and some other arguments,
but it is more simple and clean replacing these classes with map. So i
remove them.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
ut and e2e test
- vLLM version: v0.10.1.1
- vLLM main:
6997a25ac6
Signed-off-by: 22dimensions <waitingwind@foxmail.com>
This commit is contained in:
@@ -156,33 +156,22 @@ class TestAscendKVCacheMethod(TestBase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Setup common test fixtures
|
# Setup common test fixtures
|
||||||
self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
|
self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
|
||||||
self.mock_quant_config.quant_description = {"some_config": "value"}
|
self.mock_quant_config.quant_description = {"kv_quant_type": "C8"}
|
||||||
self.prefix = "attention_layer"
|
self.prefix = "layer.attn"
|
||||||
|
|
||||||
# Mock the quantizer and quant_method
|
# Mock quant_method
|
||||||
self.mock_quantizer = MagicMock()
|
|
||||||
self.mock_quant_method = MagicMock()
|
self.mock_quant_method = MagicMock()
|
||||||
|
self.patcher = patch(
|
||||||
# Patch the AscendQuantizer
|
'vllm_ascend.quantization.quant_config.get_quant_method')
|
||||||
self.quantizer_patcher = patch(
|
self.mock_get_quant_method = self.patcher.start()
|
||||||
'vllm_ascend.quantization.quant_config.AscendQuantizer.get_quantizer',
|
self.mock_get_quant_method.return_value = self.mock_quant_method
|
||||||
return_value=self.mock_quantizer)
|
|
||||||
self.mock_get_quantizer = self.quantizer_patcher.start()
|
|
||||||
|
|
||||||
self.mock_quantizer.build_attention_method.return_value = self.mock_quant_method
|
|
||||||
|
|
||||||
# Create instance
|
# Create instance
|
||||||
self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
|
self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
|
||||||
self.prefix)
|
self.prefix)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.quantizer_patcher.stop()
|
self.patcher.stop()
|
||||||
|
|
||||||
def test_init(self):
|
|
||||||
"""Test initialization with proper quantizer setup."""
|
|
||||||
self.mock_get_quantizer.assert_called_once_with(
|
|
||||||
self.mock_quant_config.quant_description, self.prefix)
|
|
||||||
self.mock_quantizer.build_attention_method.assert_called_once()
|
|
||||||
|
|
||||||
def test_create_weights(self):
|
def test_create_weights(self):
|
||||||
"""Test create_weights delegates to quant_method."""
|
"""Test create_weights delegates to quant_method."""
|
||||||
|
|||||||
@@ -1,145 +0,0 @@
|
|||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
|
||||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
|
||||||
from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer,
|
|
||||||
W4A8DYNAMICQuantizer,
|
|
||||||
W8A8Quantizer)
|
|
||||||
|
|
||||||
SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"}
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetQuantizer(TestBase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
# Setup common test fixtures
|
|
||||||
self.supported_types = {
|
|
||||||
'INT8': MagicMock(_instance=None),
|
|
||||||
'FP16': MagicMock(_instance=None),
|
|
||||||
'C8': MagicMock(_instance=None)
|
|
||||||
}
|
|
||||||
self.original_supported_types = SUPPORT_ASCEND_QUANTIZER_TYPE.copy()
|
|
||||||
SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types)
|
|
||||||
self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
|
|
||||||
self.mock_quant_config.quant_description = {"some_config": "value"}
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
# Restore original supported types
|
|
||||||
SUPPORT_ASCEND_QUANTIZER_TYPE.clear()
|
|
||||||
SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.original_supported_types)
|
|
||||||
|
|
||||||
def test_get_quantizer_fa(self):
|
|
||||||
"""Test successful quantizer retrieval for different cases."""
|
|
||||||
# Setup
|
|
||||||
quant_description = {'fa_quant_type': 'C8'}
|
|
||||||
prefix = '.attn'
|
|
||||||
expected_type = 'C8'
|
|
||||||
with patch.dict(
|
|
||||||
'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
|
|
||||||
SUPPORT_ASCEND_QUANTIZER_TYPE):
|
|
||||||
|
|
||||||
result = VLLMAscendQuantizer.get_quantizer(
|
|
||||||
quant_description,
|
|
||||||
prefix,
|
|
||||||
packed_modules_mapping={"some": "mapping"})
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
self.assertIsNotNone(result)
|
|
||||||
self.assertEqual(result,
|
|
||||||
self.supported_types[expected_type]._instance)
|
|
||||||
self.supported_types[expected_type].assert_called_once_with(
|
|
||||||
quant_description)
|
|
||||||
|
|
||||||
def test_get_quantizer_kv(self):
|
|
||||||
"""Test successful quantizer retrieval for different cases."""
|
|
||||||
# Setup
|
|
||||||
quant_description = {'kv_quant_type': 'C8'}
|
|
||||||
prefix = '.attn'
|
|
||||||
expected_type = 'C8'
|
|
||||||
with patch.dict(
|
|
||||||
'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
|
|
||||||
SUPPORT_ASCEND_QUANTIZER_TYPE):
|
|
||||||
|
|
||||||
result = VLLMAscendQuantizer.get_quantizer(
|
|
||||||
quant_description,
|
|
||||||
prefix,
|
|
||||||
packed_modules_mapping={"some": "mapping"})
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
self.assertIsNotNone(result)
|
|
||||||
self.assertEqual(result,
|
|
||||||
self.supported_types[expected_type]._instance)
|
|
||||||
self.supported_types[expected_type].assert_called_once_with(
|
|
||||||
quant_description)
|
|
||||||
|
|
||||||
def test_get_quantizer_linear(self):
|
|
||||||
"""Test successful quantizer retrieval for different cases."""
|
|
||||||
# Setup
|
|
||||||
quant_description = {'linear_type': 'INT8'}
|
|
||||||
prefix = 'nothing'
|
|
||||||
expected_type = 'INT8'
|
|
||||||
with patch('vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type',
|
|
||||||
return_value=expected_type), \
|
|
||||||
patch.dict('vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', SUPPORT_ASCEND_QUANTIZER_TYPE):
|
|
||||||
|
|
||||||
result = VLLMAscendQuantizer.get_quantizer(
|
|
||||||
quant_description,
|
|
||||||
prefix,
|
|
||||||
packed_modules_mapping={"some": "mapping"})
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
self.assertIsNotNone(result)
|
|
||||||
self.assertEqual(result,
|
|
||||||
self.supported_types[expected_type]._instance)
|
|
||||||
self.supported_types[expected_type].assert_called_once_with(
|
|
||||||
quant_description)
|
|
||||||
|
|
||||||
|
|
||||||
class TestW8A8Quantizer(TestBase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.quantizer = W8A8Quantizer(quant_description={})
|
|
||||||
|
|
||||||
def test_build_linear_method(self):
|
|
||||||
with patch('vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod',
|
|
||||||
return_value=MagicMock()) as mock_linear:
|
|
||||||
result = self.quantizer.build_linear_method()
|
|
||||||
mock_linear.assert_called_once_with()
|
|
||||||
self.assertIsInstance(result, MagicMock)
|
|
||||||
|
|
||||||
def test_build_moe_method(self):
|
|
||||||
with patch(
|
|
||||||
'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod',
|
|
||||||
return_value=MagicMock()) as mock_linear:
|
|
||||||
result = self.quantizer.build_moe_method()
|
|
||||||
mock_linear.assert_called_once_with()
|
|
||||||
self.assertIsInstance(result, MagicMock)
|
|
||||||
|
|
||||||
def test_build_attention_method(self):
|
|
||||||
with patch('vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod',
|
|
||||||
return_value=MagicMock()) as mock_linear:
|
|
||||||
result = self.quantizer.build_attention_method()
|
|
||||||
mock_linear.assert_called_once_with()
|
|
||||||
self.assertIsInstance(result, MagicMock)
|
|
||||||
|
|
||||||
|
|
||||||
class TestW4A8DYNAMICQuantizer(TestBase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.quantizer = W4A8DYNAMICQuantizer(quant_description={})
|
|
||||||
|
|
||||||
def test_build_linear_method(self):
|
|
||||||
with patch(
|
|
||||||
'vllm_ascend.quantization.quantizer.AscendW4A8DynamicLinearMethod',
|
|
||||||
return_value=MagicMock()) as mock_linear:
|
|
||||||
result = self.quantizer.build_linear_method()
|
|
||||||
mock_linear.assert_called_once_with()
|
|
||||||
self.assertIsInstance(result, MagicMock)
|
|
||||||
|
|
||||||
def test_build_moe_method(self):
|
|
||||||
with patch(
|
|
||||||
'vllm_ascend.quantization.quantizer.AscendW4A8DynamicFusedMoEMethod',
|
|
||||||
return_value=MagicMock()) as mock_fused_moe:
|
|
||||||
result = self.quantizer.build_moe_method()
|
|
||||||
mock_fused_moe.assert_called_once_with()
|
|
||||||
self.assertIsInstance(result, MagicMock)
|
|
||||||
62
tests/ut/quantization/test_utils.py
Normal file
62
tests/ut/quantization/test_utils.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import types
|
||||||
|
|
||||||
|
from tests.ut.base import TestBase
|
||||||
|
from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP,
|
||||||
|
get_quant_method)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetQuantMethod(TestBase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.original_quantization_method_map = ASCEND_QUANTIZATION_METHOD_MAP.copy(
|
||||||
|
)
|
||||||
|
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
|
||||||
|
for layer_type in layer_map.keys():
|
||||||
|
ASCEND_QUANTIZATION_METHOD_MAP[quant_type][
|
||||||
|
layer_type] = types.new_class(f"{quant_type}_{layer_type}")
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
# Restore original map
|
||||||
|
ASCEND_QUANTIZATION_METHOD_MAP.clear()
|
||||||
|
ASCEND_QUANTIZATION_METHOD_MAP.update(
|
||||||
|
self.original_quantization_method_map)
|
||||||
|
|
||||||
|
def test_linear_quant_methods(self):
|
||||||
|
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
|
||||||
|
if "linear" in layer_map.keys():
|
||||||
|
prefix = "linear_layer"
|
||||||
|
cls = layer_map["linear"]
|
||||||
|
method = get_quant_method({"linear_layer.weight": quant_type},
|
||||||
|
prefix, "linear")
|
||||||
|
self.assertIsInstance(method, cls)
|
||||||
|
|
||||||
|
def test_moe_quant_methods(self):
|
||||||
|
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
|
||||||
|
if "moe" in layer_map.keys():
|
||||||
|
prefix = "layer"
|
||||||
|
cls = layer_map["moe"]
|
||||||
|
method = get_quant_method({"layer.weight": quant_type}, prefix,
|
||||||
|
"moe")
|
||||||
|
self.assertIsInstance(method, cls)
|
||||||
|
|
||||||
|
def test_with_fa_quant_type(self):
|
||||||
|
quant_description = {"fa_quant_type": "C8"}
|
||||||
|
method = get_quant_method(quant_description, ".attn", "attention")
|
||||||
|
self.assertIsInstance(
|
||||||
|
method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])
|
||||||
|
|
||||||
|
def test_with_kv_quant_type(self):
|
||||||
|
quant_description = {"kv_quant_type": "C8"}
|
||||||
|
method = get_quant_method(quant_description, ".attn", "attention")
|
||||||
|
self.assertIsInstance(
|
||||||
|
method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])
|
||||||
|
|
||||||
|
def test_invalid_layer_type(self):
|
||||||
|
quant_description = {"linear_layer.weight": "W8A8"}
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
get_quant_method(quant_description, "linear_layer", "unsupported")
|
||||||
|
|
||||||
|
def test_invalid_quant_type(self):
|
||||||
|
quant_description = {"linear_layer.weight": "UNKNOWN"}
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
get_quant_method(quant_description, "linear_layer", "linear")
|
||||||
@@ -24,7 +24,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
|||||||
|
|
||||||
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
|
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
|
||||||
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
|
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
|
||||||
from vllm_ascend.quantization.quantizer import W8A8Quantizer
|
|
||||||
from vllm_ascend.torchair.ops.torchair_fused_moe import (
|
from vllm_ascend.torchair.ops.torchair_fused_moe import (
|
||||||
TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
|
TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
|
||||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
|
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
|
||||||
@@ -236,12 +235,9 @@ class TestTorchairAscendFusedMoe:
|
|||||||
mock_quant_method = MockFusedMoEMethod()
|
mock_quant_method = MockFusedMoEMethod()
|
||||||
mock_quant_config.get_quant_method.return_value = mock_quant_method
|
mock_quant_config.get_quant_method.return_value = mock_quant_method
|
||||||
mock_quant_config.is_layer_skipped_ascend.return_value = False
|
mock_quant_config.is_layer_skipped_ascend.return_value = False
|
||||||
with patch(
|
with patch("vllm_ascend.quantization.quant_config.get_quant_method"):
|
||||||
'vllm_ascend.quantization.quantizer.AscendQuantizer.get_quantizer',
|
|
||||||
return_value=W8A8Quantizer):
|
|
||||||
moe = TorchairAscendFusedMoE(**default_moe_config,
|
moe = TorchairAscendFusedMoE(**default_moe_config,
|
||||||
quant_config=mock_quant_config)
|
quant_config=mock_quant_config)
|
||||||
|
|
||||||
assert moe.quant_method is not None
|
assert moe.quant_method is not None
|
||||||
assert isinstance(moe.quant_method, AscendFusedMoEMethod)
|
assert isinstance(moe.quant_method, AscendFusedMoEMethod)
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.quantization.quantizer import SUPPORT_ASCEND_QUANTIZER_TYPE
|
|
||||||
from vllm_ascend.torchair import utils
|
from vllm_ascend.torchair import utils
|
||||||
|
|
||||||
|
|
||||||
@@ -135,15 +134,3 @@ class TestTorchairUtils(TestBase):
|
|||||||
|
|
||||||
utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
|
utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
|
||||||
mock_npu_cast.assert_not_called()
|
mock_npu_cast.assert_not_called()
|
||||||
|
|
||||||
def test_torchair_quant_method_register(self):
|
|
||||||
|
|
||||||
TorchairW8A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
|
|
||||||
"W8A8_DYNAMIC"]
|
|
||||||
TorchairW4A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
|
|
||||||
"W4A8_DYNAMIC"]
|
|
||||||
utils.torchair_quant_method_register()
|
|
||||||
self.assertNotEqual(TorchairW8A8DYNAMICQuantizer,
|
|
||||||
SUPPORT_ASCEND_QUANTIZER_TYPE["W8A8_DYNAMIC"])
|
|
||||||
self.assertNotEqual(TorchairW4A8DYNAMICQuantizer,
|
|
||||||
SUPPORT_ASCEND_QUANTIZER_TYPE["W4A8_DYNAMIC"])
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ from vllm.model_executor.utils import set_weight_attrs
|
|||||||
from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
|
from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
|
||||||
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
|
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
|
||||||
|
|
||||||
from .quantizer import AscendQuantizer
|
from .utils import get_quant_method
|
||||||
|
|
||||||
|
|
||||||
@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
|
@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
|
||||||
@@ -150,18 +150,15 @@ class AscendQuantConfig(QuantizationConfig):
|
|||||||
class AscendLinearMethod(LinearMethodBase):
|
class AscendLinearMethod(LinearMethodBase):
|
||||||
"""Linear method for Ascend quantization.
|
"""Linear method for Ascend quantization.
|
||||||
|
|
||||||
This class calls AscendQuantizer to search a specific quantization
|
|
||||||
implementations supported on ascend hardware for linear methods.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
quant_config: The Ascend quantization config.
|
quant_config: The Ascend quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
|
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
|
||||||
packed_modules_mapping: Dict[str, Any]) -> None:
|
packed_modules_mapping: Dict[str, Any]) -> None:
|
||||||
self.quantizer = AscendQuantizer.get_quantizer(
|
self.quant_method = get_quant_method(quant_config.quant_description,
|
||||||
quant_config.quant_description, prefix, packed_modules_mapping)
|
prefix, "linear",
|
||||||
self.quant_method = self.quantizer.build_linear_method()
|
packed_modules_mapping)
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@@ -231,17 +228,13 @@ class AscendLinearMethod(LinearMethodBase):
|
|||||||
class AscendKVCacheMethod(BaseKVCacheMethod):
|
class AscendKVCacheMethod(BaseKVCacheMethod):
|
||||||
"""KVCache method for Ascend quantization.
|
"""KVCache method for Ascend quantization.
|
||||||
|
|
||||||
This class calls AscendQuantizer to search a specific quantization
|
|
||||||
implementations supported on ascend hardware for kvcache methods.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
quant_config: The Ascend quantization config.
|
quant_config: The Ascend quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
|
def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
|
||||||
self.quantizer = AscendQuantizer.get_quantizer(
|
self.quant_method = get_quant_method(quant_config.quant_description,
|
||||||
quant_config.quant_description, prefix)
|
prefix, "attention")
|
||||||
self.quant_method = self.quantizer.build_attention_method()
|
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module) -> None:
|
def create_weights(self, layer: torch.nn.Module) -> None:
|
||||||
# Different from linear method, there are no weight processing/slicing
|
# Different from linear method, there are no weight processing/slicing
|
||||||
@@ -263,18 +256,15 @@ class AscendKVCacheMethod(BaseKVCacheMethod):
|
|||||||
class AscendFusedMoEMethod(FusedMoEMethodBase):
|
class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||||
"""FusedMoE method for Ascend quantization.
|
"""FusedMoE method for Ascend quantization.
|
||||||
|
|
||||||
This class calls AscendQuantizer to search a specific quantization
|
|
||||||
implementations supported on ascend hardware for kvcache methods.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
quant_config: The Ascend quantization config.
|
quant_config: The Ascend quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
|
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
|
||||||
packed_modules_mapping: Dict[str, Any]):
|
packed_modules_mapping: Dict[str, Any]):
|
||||||
self.quantizer = AscendQuantizer.get_quantizer(
|
self.quant_method = get_quant_method(quant_config.quant_description,
|
||||||
quant_config.quant_description, prefix, packed_modules_mapping)
|
prefix, "moe",
|
||||||
self.quant_method = self.quantizer.build_moe_method()
|
packed_modules_mapping)
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@@ -344,14 +334,13 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
class AscendEmbeddingMethod(AscendLinearMethod):
|
class AscendEmbeddingMethod(AscendLinearMethod):
|
||||||
"""Embedding method for Ascend quantization.
|
"""Embedding method for Ascend quantization.
|
||||||
This class calls AscendQuantizer to search a specific quantization
|
|
||||||
implementations supported on ascend hardware for Embedding methods.
|
|
||||||
Args:
|
Args:
|
||||||
quant_config: The Ascend quantization config.
|
quant_config: The Ascend quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
|
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
|
||||||
packed_modules_mapping: Dict[str, Any]) -> None:
|
packed_modules_mapping: Dict[str, Any]) -> None:
|
||||||
self.quantizer = AscendQuantizer.get_quantizer(
|
self.quant_method = get_quant_method(quant_config.quant_description,
|
||||||
quant_config.quant_description, prefix, packed_modules_mapping)
|
prefix, "linear",
|
||||||
self.quant_method = self.quantizer.build_linear_method()
|
packed_modules_mapping)
|
||||||
|
|||||||
@@ -1,311 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
||||||
# This file is a part of the vllm-ascend project.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
import sys
|
|
||||||
import types
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from vllm.logger import logger
|
|
||||||
|
|
||||||
from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init,
|
|
||||||
wrapper_vocab_parallel_embedding_init)
|
|
||||||
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
|
|
||||||
AscendW4A8DynamicLinearMethod)
|
|
||||||
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
|
||||||
AscendW8A8LinearMethod)
|
|
||||||
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
|
||||||
AscendW8A8DynamicLinearMethod)
|
|
||||||
|
|
||||||
CUSTOMIZED_QUANTIZER_TYPE: List[str] = []
|
|
||||||
|
|
||||||
|
|
||||||
class AscendQuantizer:
|
|
||||||
"""An interface to different quantization implementations for ascend hardwares."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_quantizer(cls,
|
|
||||||
quant_config: Dict[str, Any],
|
|
||||||
prefix: str,
|
|
||||||
packed_modules_mapping: Optional[Dict[str,
|
|
||||||
Any]] = dict()):
|
|
||||||
# TODO: Need a param to choose quantization algorithms.
|
|
||||||
quantization_algorithm = ''
|
|
||||||
|
|
||||||
if quantization_algorithm in CUSTOMIZED_QUANTIZER_TYPE:
|
|
||||||
return
|
|
||||||
|
|
||||||
return VLLMAscendQuantizer.get_quantizer(quant_config, prefix,
|
|
||||||
packed_modules_mapping)
|
|
||||||
|
|
||||||
def build_linear_method(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def build_moe_method(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def build_attention_method(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class VLLMAscendQuantizer:
|
|
||||||
_instance: Optional[object] = None
|
|
||||||
patched = False
|
|
||||||
|
|
||||||
def __init__(self, quant_description):
|
|
||||||
if VLLMAscendQuantizer.patched:
|
|
||||||
return
|
|
||||||
for name in quant_description.keys():
|
|
||||||
if "norm.bias" in name:
|
|
||||||
VLLMAscendQuantizer.apply_patch(
|
|
||||||
"vllm.model_executor.layers.layernorm.RMSNorm", "__init__",
|
|
||||||
[wrapper_rmsnorm_init])
|
|
||||||
VLLMAscendQuantizer.apply_patch(
|
|
||||||
"vllm_ascend.ops.layernorm.AscendRMSNorm", "forward_oot",
|
|
||||||
[wrapper_rmsnorm_forward_oot])
|
|
||||||
VLLMAscendQuantizer.apply_patch(
|
|
||||||
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding",
|
|
||||||
"__init__", [wrapper_vocab_parallel_embedding_init])
|
|
||||||
break
|
|
||||||
VLLMAscendQuantizer.patched = True
|
|
||||||
logger.info("Using the vLLM Ascend Quantizer version now!")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def apply_patch(target_module, target_function, wrappers):
|
|
||||||
|
|
||||||
original_module, original_function = VLLMAscendQuantizer.parse_path(
|
|
||||||
target_module, target_function, False)
|
|
||||||
|
|
||||||
original_function_id = id(original_function)
|
|
||||||
|
|
||||||
candidate = original_function
|
|
||||||
for wrapper in wrappers:
|
|
||||||
candidate = wrapper(candidate)
|
|
||||||
if target_function is not None:
|
|
||||||
setattr(original_module, target_function, candidate)
|
|
||||||
|
|
||||||
for _, value in sys.modules.copy().items():
|
|
||||||
if target_function is None:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
attr = getattr(value, target_function, None)
|
|
||||||
if attr is not None and id(attr) == original_function_id:
|
|
||||||
setattr(value, target_function, candidate)
|
|
||||||
except ImportError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_path(module_path, function_name, create_dummy):
|
|
||||||
"""
|
|
||||||
Parse module path and resolve/create modules as needed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module_path: Dot-separated module path
|
|
||||||
function_name: Target function name (None for module only)
|
|
||||||
create_dummy: Create dummy modules/functions when missing
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (resolved module, target function/none)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ModuleNotFoundError: If module path is invalid and create_dummy=False
|
|
||||||
AttributeError: If function is missing and create_dummy=False
|
|
||||||
"""
|
|
||||||
from importlib.machinery import ModuleSpec
|
|
||||||
|
|
||||||
def create_dummy_module(full_path, parent=None):
|
|
||||||
"""Create and register a placeholder module"""
|
|
||||||
dummy = types.ModuleType(full_path)
|
|
||||||
dummy.__file__ = "vllm_ascend.dummy_module.py"
|
|
||||||
dummy.__spec__ = ModuleSpec(full_path, None)
|
|
||||||
sys.modules[full_path] = dummy
|
|
||||||
if parent:
|
|
||||||
setattr(parent, full_path.split(".")[-1], dummy)
|
|
||||||
return dummy
|
|
||||||
|
|
||||||
def create_placeholder_function(func_name):
|
|
||||||
"""Create dummy function that raises when called"""
|
|
||||||
|
|
||||||
def placeholder(*args, **kwargs):
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Function {func_name} is a placeholder")
|
|
||||||
|
|
||||||
placeholder.__name__ = func_name
|
|
||||||
return placeholder
|
|
||||||
|
|
||||||
modules = module_path.split(".")
|
|
||||||
current_module = None
|
|
||||||
processed_path = []
|
|
||||||
|
|
||||||
for idx, part in enumerate(modules):
|
|
||||||
current_path = ".".join(modules[:idx + 1])
|
|
||||||
parent_path = ".".join(modules[:idx]) if idx > 0 else None
|
|
||||||
|
|
||||||
try:
|
|
||||||
current_module = importlib.import_module(current_path)
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
# Handle missing module
|
|
||||||
parent = importlib.import_module(
|
|
||||||
parent_path) if parent_path else None
|
|
||||||
if parent and hasattr(parent, part):
|
|
||||||
# Use existing attribute from parent
|
|
||||||
current_module = getattr(parent, part)
|
|
||||||
# Check for early function resolution
|
|
||||||
if function_name and hasattr(current_module,
|
|
||||||
function_name):
|
|
||||||
return current_module, getattr(current_module,
|
|
||||||
function_name)
|
|
||||||
if function_name and create_dummy:
|
|
||||||
ph_func = create_placeholder_function(function_name)
|
|
||||||
setattr(current_module, function_name, ph_func)
|
|
||||||
return current_module, ph_func
|
|
||||||
if function_name:
|
|
||||||
raise AttributeError(
|
|
||||||
f"Function {function_name} missing in {current_path}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if not create_dummy:
|
|
||||||
raise
|
|
||||||
# Create and register dummy module
|
|
||||||
current_module = create_dummy_module(
|
|
||||||
current_path,
|
|
||||||
parent=importlib.import_module(parent_path)
|
|
||||||
if parent_path else None)
|
|
||||||
|
|
||||||
processed_path.append(part)
|
|
||||||
|
|
||||||
# Final function handling
|
|
||||||
final_module = sys.modules[module_path]
|
|
||||||
if function_name is not None:
|
|
||||||
if not hasattr(final_module, function_name):
|
|
||||||
if create_dummy:
|
|
||||||
ph_func = create_placeholder_function(function_name)
|
|
||||||
setattr(final_module, function_name, ph_func)
|
|
||||||
else:
|
|
||||||
setattr(final_module, function_name, None)
|
|
||||||
return final_module, getattr(final_module, function_name)
|
|
||||||
|
|
||||||
return final_module, None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_linear_method():
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Linear method is not implemented for the current quant type.")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_moe_method():
|
|
||||||
raise NotImplementedError(
|
|
||||||
"MoE method is not implemented for the current quant type.")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_attention_method():
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Attention method is not implemented for the current quant type.")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str,
|
|
||||||
packed_modules_mapping: Dict[str, Any]):
|
|
||||||
proj_name = prefix.split(".")[-1]
|
|
||||||
if proj_name in packed_modules_mapping:
|
|
||||||
quant_type = None
|
|
||||||
shard_prefixes = [
|
|
||||||
prefix.replace(proj_name, shard_proj_name)
|
|
||||||
for shard_proj_name in packed_modules_mapping[proj_name]
|
|
||||||
]
|
|
||||||
for shard_prefix in shard_prefixes:
|
|
||||||
shard_quant_type = quant_description[shard_prefix + '.weight']
|
|
||||||
|
|
||||||
if quant_type is None:
|
|
||||||
quant_type = shard_quant_type
|
|
||||||
elif shard_quant_type != quant_type:
|
|
||||||
raise ValueError(
|
|
||||||
f"Not all shards of {prefix} are quantized with same quant type."
|
|
||||||
f"Shard {proj_name} uses {shard_quant_type}, but another shard"
|
|
||||||
f"use {quant_type}. Please check quantization config.")
|
|
||||||
else:
|
|
||||||
quant_type = quant_description[prefix + '.weight']
|
|
||||||
return quant_type
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_quantizer(cls,
|
|
||||||
quant_description: Dict[str, Any],
|
|
||||||
prefix: str,
|
|
||||||
packed_modules_mapping: Optional[Dict[str, Any]] = None):
|
|
||||||
if packed_modules_mapping is None:
|
|
||||||
packed_modules_mapping = dict()
|
|
||||||
# Attention
|
|
||||||
if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
|
|
||||||
quant_type = quant_description['fa_quant_type']
|
|
||||||
# Use KVCache int8
|
|
||||||
elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
|
|
||||||
quant_type = quant_description['kv_quant_type']
|
|
||||||
# Linear
|
|
||||||
else:
|
|
||||||
quant_type = cls.get_linear_quant_type(quant_description, prefix,
|
|
||||||
packed_modules_mapping)
|
|
||||||
if quant_type in SUPPORT_ASCEND_QUANTIZER_TYPE.keys():
|
|
||||||
cls = SUPPORT_ASCEND_QUANTIZER_TYPE[quant_type]
|
|
||||||
if not cls._instance:
|
|
||||||
cls._instance = cls(quant_description)
|
|
||||||
return cls._instance
|
|
||||||
raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \
|
|
||||||
f"{list(SUPPORT_ASCEND_QUANTIZER_TYPE.keys())}")
|
|
||||||
|
|
||||||
|
|
||||||
class W4A8DYNAMICQuantizer(VLLMAscendQuantizer):
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_linear_method():
|
|
||||||
return AscendW4A8DynamicLinearMethod()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_moe_method():
|
|
||||||
return AscendW4A8DynamicFusedMoEMethod()
|
|
||||||
|
|
||||||
|
|
||||||
class W8A8Quantizer(VLLMAscendQuantizer):
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_linear_method():
|
|
||||||
return AscendW8A8LinearMethod()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_moe_method():
|
|
||||||
return AscendW8A8FusedMoEMethod()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_attention_method():
|
|
||||||
return AscendC8KVCacheMethod()
|
|
||||||
|
|
||||||
|
|
||||||
class W8A8DYNAMICQuantizer(VLLMAscendQuantizer):
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_linear_method():
|
|
||||||
return AscendW8A8DynamicLinearMethod()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_moe_method():
|
|
||||||
return AscendW8A8DynamicFusedMoEMethod()
|
|
||||||
|
|
||||||
|
|
||||||
SUPPORT_ASCEND_QUANTIZER_TYPE = {
|
|
||||||
"W4A8_DYNAMIC": W4A8DYNAMICQuantizer,
|
|
||||||
"W8A8": W8A8Quantizer,
|
|
||||||
"W8A8_DYNAMIC": W8A8DYNAMICQuantizer,
|
|
||||||
"C8": W8A8Quantizer,
|
|
||||||
}
|
|
||||||
222
vllm_ascend/quantization/utils.py
Normal file
222
vllm_ascend/quantization/utils.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
from typing import Any, Dict, Optional, Type
|
||||||
|
|
||||||
|
from vllm.logger import logger
|
||||||
|
|
||||||
|
from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init,
|
||||||
|
wrapper_vocab_parallel_embedding_init)
|
||||||
|
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
|
||||||
|
AscendW4A8DynamicLinearMethod)
|
||||||
|
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
||||||
|
AscendW8A8LinearMethod)
|
||||||
|
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
||||||
|
AscendW8A8DynamicLinearMethod)
|
||||||
|
|
||||||
|
patched = False
|
||||||
|
|
||||||
|
ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
|
||||||
|
"W4A8_DYNAMIC": {
|
||||||
|
"linear": AscendW4A8DynamicLinearMethod,
|
||||||
|
"moe": AscendW4A8DynamicFusedMoEMethod,
|
||||||
|
},
|
||||||
|
"W8A8": {
|
||||||
|
"linear": AscendW8A8LinearMethod,
|
||||||
|
"moe": AscendW8A8FusedMoEMethod,
|
||||||
|
"attention": AscendC8KVCacheMethod,
|
||||||
|
},
|
||||||
|
"W8A8_DYNAMIC": {
|
||||||
|
"linear": AscendW8A8DynamicLinearMethod,
|
||||||
|
"moe": AscendW8A8DynamicFusedMoEMethod,
|
||||||
|
},
|
||||||
|
"C8": {
|
||||||
|
"attention": AscendC8KVCacheMethod,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str,
|
||||||
|
packed_modules_mapping: Dict[str, Any]):
|
||||||
|
proj_name = prefix.split(".")[-1]
|
||||||
|
if proj_name in packed_modules_mapping:
|
||||||
|
quant_type = None
|
||||||
|
shard_prefixes = [
|
||||||
|
prefix.replace(proj_name, shard_proj_name)
|
||||||
|
for shard_proj_name in packed_modules_mapping[proj_name]
|
||||||
|
]
|
||||||
|
for shard_prefix in shard_prefixes:
|
||||||
|
shard_quant_type = quant_description[shard_prefix + '.weight']
|
||||||
|
|
||||||
|
if quant_type is None:
|
||||||
|
quant_type = shard_quant_type
|
||||||
|
elif shard_quant_type != quant_type:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not all shards of {prefix} are quantized with same quant type."
|
||||||
|
f"Shard {proj_name} uses {shard_quant_type}, but another shard"
|
||||||
|
f"use {quant_type}. Please check quantization config.")
|
||||||
|
else:
|
||||||
|
quant_type = quant_description[prefix + '.weight']
|
||||||
|
return quant_type
|
||||||
|
|
||||||
|
|
||||||
|
def get_quant_method(quant_description: Dict[str, Any],
|
||||||
|
prefix: str,
|
||||||
|
layer_type: str,
|
||||||
|
packed_modules_mapping: Optional[Dict[str, Any]] = None):
|
||||||
|
apply_quantization_patch(quant_description)
|
||||||
|
if packed_modules_mapping is None:
|
||||||
|
packed_modules_mapping = dict()
|
||||||
|
# Attention
|
||||||
|
if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
|
||||||
|
quant_type = quant_description['fa_quant_type']
|
||||||
|
# Use KVCache int8
|
||||||
|
elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
|
||||||
|
quant_type = quant_description['kv_quant_type']
|
||||||
|
# Linear
|
||||||
|
else:
|
||||||
|
quant_type = get_linear_quant_type(quant_description, prefix,
|
||||||
|
packed_modules_mapping)
|
||||||
|
if quant_type in ASCEND_QUANTIZATION_METHOD_MAP.keys():
|
||||||
|
method_map = ASCEND_QUANTIZATION_METHOD_MAP[quant_type]
|
||||||
|
if layer_type in method_map.keys():
|
||||||
|
method_cls = method_map[layer_type]
|
||||||
|
return method_cls()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}."
|
||||||
|
)
|
||||||
|
raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \
|
||||||
|
f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_quantization_patch(quant_description):
|
||||||
|
global patched
|
||||||
|
if patched:
|
||||||
|
return
|
||||||
|
for name in quant_description.keys():
|
||||||
|
if "norm.bias" in name:
|
||||||
|
apply_patch("vllm.model_executor.layers.layernorm.RMSNorm",
|
||||||
|
"__init__", [wrapper_rmsnorm_init])
|
||||||
|
apply_patch("vllm_ascend.ops.layernorm.AscendRMSNorm",
|
||||||
|
"forward_oot", [wrapper_rmsnorm_forward_oot])
|
||||||
|
apply_patch(
|
||||||
|
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding",
|
||||||
|
"__init__", [wrapper_vocab_parallel_embedding_init])
|
||||||
|
break
|
||||||
|
patched = True
|
||||||
|
logger.info("Using the vLLM Ascend Quantization now!")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_patch(target_module, target_function, wrappers):
|
||||||
|
|
||||||
|
original_module, original_function = parse_path(target_module,
|
||||||
|
target_function, False)
|
||||||
|
|
||||||
|
original_function_id = id(original_function)
|
||||||
|
|
||||||
|
candidate = original_function
|
||||||
|
for wrapper in wrappers:
|
||||||
|
candidate = wrapper(candidate)
|
||||||
|
if target_function is not None:
|
||||||
|
setattr(original_module, target_function, candidate)
|
||||||
|
|
||||||
|
for _, value in sys.modules.copy().items():
|
||||||
|
if target_function is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
attr = getattr(value, target_function, None)
|
||||||
|
if attr is not None and id(attr) == original_function_id:
|
||||||
|
setattr(value, target_function, candidate)
|
||||||
|
except ImportError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
def parse_path(module_path, function_name, create_dummy):
|
||||||
|
"""
|
||||||
|
Parse module path and resolve/create modules as needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_path: Dot-separated module path
|
||||||
|
function_name: Target function name (None for module only)
|
||||||
|
create_dummy: Create dummy modules/functions when missing
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (resolved module, target function/none)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ModuleNotFoundError: If module path is invalid and create_dummy=False
|
||||||
|
AttributeError: If function is missing and create_dummy=False
|
||||||
|
"""
|
||||||
|
from importlib.machinery import ModuleSpec
|
||||||
|
|
||||||
|
def create_dummy_module(full_path, parent=None):
|
||||||
|
"""Create and register a placeholder module"""
|
||||||
|
dummy = types.ModuleType(full_path)
|
||||||
|
dummy.__file__ = "vllm_ascend.dummy_module.py"
|
||||||
|
dummy.__spec__ = ModuleSpec(full_path, None)
|
||||||
|
sys.modules[full_path] = dummy
|
||||||
|
if parent:
|
||||||
|
setattr(parent, full_path.split(".")[-1], dummy)
|
||||||
|
return dummy
|
||||||
|
|
||||||
|
def create_placeholder_function(func_name):
|
||||||
|
"""Create dummy function that raises when called"""
|
||||||
|
|
||||||
|
def placeholder(*args, **kwargs):
|
||||||
|
raise NotImplementedError(f"Function {func_name} is a placeholder")
|
||||||
|
|
||||||
|
placeholder.__name__ = func_name
|
||||||
|
return placeholder
|
||||||
|
|
||||||
|
modules = module_path.split(".")
|
||||||
|
current_module = None
|
||||||
|
processed_path = []
|
||||||
|
|
||||||
|
for idx, part in enumerate(modules):
|
||||||
|
current_path = ".".join(modules[:idx + 1])
|
||||||
|
parent_path = ".".join(modules[:idx]) if idx > 0 else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
current_module = importlib.import_module(current_path)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# Handle missing module
|
||||||
|
parent = importlib.import_module(
|
||||||
|
parent_path) if parent_path else None
|
||||||
|
if parent and hasattr(parent, part):
|
||||||
|
# Use existing attribute from parent
|
||||||
|
current_module = getattr(parent, part)
|
||||||
|
# Check for early function resolution
|
||||||
|
if function_name and hasattr(current_module, function_name):
|
||||||
|
return current_module, getattr(current_module,
|
||||||
|
function_name)
|
||||||
|
if function_name and create_dummy:
|
||||||
|
ph_func = create_placeholder_function(function_name)
|
||||||
|
setattr(current_module, function_name, ph_func)
|
||||||
|
return current_module, ph_func
|
||||||
|
if function_name:
|
||||||
|
raise AttributeError(
|
||||||
|
f"Function {function_name} missing in {current_path}")
|
||||||
|
else:
|
||||||
|
if not create_dummy:
|
||||||
|
raise
|
||||||
|
# Create and register dummy module
|
||||||
|
current_module = create_dummy_module(
|
||||||
|
current_path,
|
||||||
|
parent=importlib.import_module(parent_path)
|
||||||
|
if parent_path else None)
|
||||||
|
|
||||||
|
processed_path.append(part)
|
||||||
|
|
||||||
|
# Final function handling
|
||||||
|
final_module = sys.modules[module_path]
|
||||||
|
if function_name is not None:
|
||||||
|
if not hasattr(final_module, function_name):
|
||||||
|
if create_dummy:
|
||||||
|
ph_func = create_placeholder_function(function_name)
|
||||||
|
setattr(final_module, function_name, ph_func)
|
||||||
|
else:
|
||||||
|
setattr(final_module, function_name, None)
|
||||||
|
return final_module, getattr(final_module, function_name)
|
||||||
|
|
||||||
|
return final_module, None
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
from vllm_ascend.quantization.quantizer import VLLMAscendQuantizer
|
|
||||||
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
|
|
||||||
TorchairAscendW4A8DynamicFusedMoEMethod,
|
|
||||||
TorchairAscendW4A8DynamicLinearMethod)
|
|
||||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
|
|
||||||
TorchairAscendW8A8DynamicFusedMoEMethod,
|
|
||||||
TorchairAscendW8A8DynamicLinearMethod)
|
|
||||||
|
|
||||||
|
|
||||||
class TorchairW8A8DYNAMICQuantizer(VLLMAscendQuantizer):
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_linear_method():
|
|
||||||
return TorchairAscendW8A8DynamicLinearMethod()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_moe_method():
|
|
||||||
return TorchairAscendW8A8DynamicFusedMoEMethod()
|
|
||||||
|
|
||||||
|
|
||||||
class TorchairW4A8DYNAMICQuantizer(VLLMAscendQuantizer):
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_linear_method():
|
|
||||||
return TorchairAscendW4A8DynamicLinearMethod()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_moe_method():
|
|
||||||
return TorchairAscendW4A8DynamicFusedMoEMethod()
|
|
||||||
@@ -180,15 +180,22 @@ def register_torchair_model():
|
|||||||
|
|
||||||
|
|
||||||
def torchair_quant_method_register():
|
def torchair_quant_method_register():
|
||||||
from vllm_ascend.quantization.quantizer import \
|
from vllm_ascend.quantization.utils import ASCEND_QUANTIZATION_METHOD_MAP
|
||||||
SUPPORT_ASCEND_QUANTIZER_TYPE
|
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
|
||||||
from vllm_ascend.torchair.quantization.torchair_quantizer import (
|
TorchairAscendW4A8DynamicFusedMoEMethod,
|
||||||
TorchairW4A8DYNAMICQuantizer, TorchairW8A8DYNAMICQuantizer)
|
TorchairAscendW4A8DynamicLinearMethod)
|
||||||
|
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
|
||||||
|
TorchairAscendW8A8DynamicFusedMoEMethod,
|
||||||
|
TorchairAscendW8A8DynamicLinearMethod)
|
||||||
|
|
||||||
SUPPORT_ASCEND_QUANTIZER_TYPE[
|
ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][
|
||||||
"W8A8_DYNAMIC"] = TorchairW8A8DYNAMICQuantizer
|
"linear"] = TorchairAscendW8A8DynamicLinearMethod
|
||||||
SUPPORT_ASCEND_QUANTIZER_TYPE[
|
ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][
|
||||||
"W4A8_DYNAMIC"] = TorchairW4A8DYNAMICQuantizer
|
"moe"] = TorchairAscendW8A8DynamicFusedMoEMethod
|
||||||
|
ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][
|
||||||
|
"linear"] = TorchairAscendW4A8DynamicLinearMethod
|
||||||
|
ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][
|
||||||
|
"moe"] = TorchairAscendW4A8DynamicFusedMoEMethod
|
||||||
|
|
||||||
|
|
||||||
def torchair_ops_patch():
|
def torchair_ops_patch():
|
||||||
|
|||||||
Reference in New Issue
Block a user