forked from EngineX-Ascend/enginex-ascend-910-vllm
init v0.11.0rc0
@@ -1,134 +0,0 @@
from unittest.mock import patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.func_wrapper import (wrapper_rmsnorm_forward_oot,
                                                   wrapper_rmsnorm_init)


class MockRMSNorm:

    def __init__(self, hidden_size: int, **extra_args):
        self.hidden_size = hidden_size
        self.weight = torch.ones(hidden_size)
        self.input_scale = 1.0
        self.input_offset = 0.0
        self.variance_epsilon = 1e-6
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)
        self.ignore_anti = extra_args.get('ignore_anti', True)


class TestFuncWrapper(TestBase):

    def test_wrapper_rmsnorm_init(self):

        @wrapper_rmsnorm_init
        def init(self, hidden_size: int, **extra_args) -> None:
            self.hidden_size = hidden_size

        hidden_size = 128
        extra_args = {'arg1': 'value1'}

        rms_norm = MockRMSNorm(hidden_size, **extra_args)
        init(rms_norm, hidden_size, **extra_args)

        self.assertTrue(hasattr(rms_norm, 'ignore_anti'))
        self.assertTrue(rms_norm.ignore_anti)

        self.assertTrue(hasattr(rms_norm, 'bias'))
        self.assertIsInstance(rms_norm.bias, torch.nn.Parameter)
        self.assertEqual(rms_norm.bias.shape, torch.Size([hidden_size]))
        self.assertFalse(rms_norm.bias.requires_grad)

    @patch('torch_npu._npu_quant_rms_norm')
    def test_wrapper_rmsnorm_forward_oot_with_residual(
            self, mock_npu_quant_rms_norm):
        hidden_size = 128
        x = torch.randn(hidden_size)
        residual = torch.randn(hidden_size)
        expected_out = torch.randn(hidden_size)

        mock_npu_quant_rms_norm.return_value = (expected_out, residual)

        @wrapper_rmsnorm_forward_oot
        def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
            return x, residual

        rms_norm = MockRMSNorm(hidden_size)
        rms_norm.ignore_anti = False

        output, res = forward_oot(rms_norm, x, residual)

        mock_npu_quant_rms_norm.assert_called_once()

        args, kwargs = mock_npu_quant_rms_norm.call_args
        self.assertTrue(torch.equal(args[1], rms_norm.weight))
        self.assertTrue(torch.equal(args[2], rms_norm.bias))
        self.assertEqual(args[3], rms_norm.input_scale)
        self.assertEqual(args[4], rms_norm.input_offset)
        self.assertEqual(args[5], rms_norm.variance_epsilon)
        self.assertTrue(torch.equal(res, residual))

    @patch('torch_npu._npu_quant_rms_norm')
    def test_wrapper_rmsnorm_forward_oot_without_residual(
            self, mock_npu_quant_rms_norm):
        hidden_size = 128
        x = torch.randn(hidden_size)
        expected_out = torch.randn(hidden_size)

        mock_npu_quant_rms_norm.return_value = expected_out

        @wrapper_rmsnorm_forward_oot
        def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
            return x

        rms_norm = MockRMSNorm(hidden_size)
        rms_norm.ignore_anti = False

        output = forward_oot(rms_norm, x)

        mock_npu_quant_rms_norm.assert_called_once()

        args, kwargs = mock_npu_quant_rms_norm.call_args
        self.assertTrue(torch.equal(args[0], x))
        self.assertTrue(torch.equal(args[1], rms_norm.weight))
        self.assertTrue(torch.equal(args[2], rms_norm.bias))
        self.assertEqual(args[3], rms_norm.input_scale)
        self.assertEqual(args[4], rms_norm.input_offset)
        self.assertEqual(args[5], rms_norm.variance_epsilon)

        self.assertTrue(torch.equal(output, expected_out))

    def test_wrapper_rmsnorm_forward_oot_ignore_anti_with_residual(self):
        hidden_size = 128
        x = torch.randn(hidden_size)
        residual = torch.randn(hidden_size)

        @wrapper_rmsnorm_forward_oot
        def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
            return x, residual

        rms_norm = MockRMSNorm(hidden_size)
        rms_norm.ignore_anti = True

        output, res = forward_oot(rms_norm, x, residual)

        self.assertTrue(torch.equal(output, x.add_(rms_norm.bias)))
        self.assertTrue(torch.equal(res, residual))

    def test_wrapper_rmsnorm_forward_oot_ignore_anti_no_residual(self):
        hidden_size = 128
        x = torch.randn(hidden_size)

        @wrapper_rmsnorm_forward_oot
        def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
            return x

        rms_norm = MockRMSNorm(hidden_size)
        rms_norm.ignore_anti = True

        output = forward_oot(rms_norm, x)

        self.assertTrue(torch.equal(output, x.add_(rms_norm.bias)))
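For context, the deleted tests above pin down the wrapper contract: wrapper_rmsnorm_init must attach a frozen zero bias parameter and an ignore_anti flag, and wrapper_rmsnorm_forward_oot must either fold the bias into the input (when ignore_anti is set) or route through torch_npu._npu_quant_rms_norm with weight, bias, scale, offset, and epsilon as positional arguments. A minimal sketch consistent with those assertions, not the actual vllm_ascend implementation, looks like:

# Hedged sketch reconstructed from the assertions above; the real
# vllm_ascend.quantization.func_wrapper code may differ in detail.
import torch


def wrapper_rmsnorm_init(func):

    def init(self, hidden_size: int, **extra_args) -> None:
        func(self, hidden_size, **extra_args)
        # The tests expect a frozen zero bias and an ignore_anti flag.
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)
        self.ignore_anti = extra_args.get('ignore_anti', True)

    return init


def wrapper_rmsnorm_forward_oot(func):

    def forward_oot(self, x, residual=None):
        if self.ignore_anti:
            # Anti-outlier handling skipped: fold the bias into x in place.
            return func(self, x.add_(self.bias), residual)
        # Quantized path: the tests check these positional args in order.
        import torch_npu  # assumption: only importable on Ascend NPU hosts
        return torch_npu._npu_quant_rms_norm(x, self.weight, self.bias,
                                             self.input_scale,
                                             self.input_offset,
                                             self.variance_epsilon)

    return forward_oot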
@@ -73,9 +73,12 @@ class TestAscendQuantConfig(TestBase):
     self.assertIsNone(result)

     def test_get_quant_method_for_linear(self):
+        mock_config = MagicMock()
+        mock_config.model_config.hf_config.model_type = None
         linear_layer = MagicMock(spec=LinearBase)
         # Test skipped layer
-        with patch.object(self.ascend_config,
+        with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
+                patch.object(self.ascend_config, \
                           'is_layer_skipped_ascend',
                           return_value=True):
             method = self.ascend_config.get_quant_method(linear_layer, ".attn")
@@ -83,6 +86,7 @@ class TestAscendQuantConfig(TestBase):

         # Test quantized layer
         with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
+            patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
             patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:

             method = self.ascend_config.get_quant_method(linear_layer, ".attn")
@@ -93,14 +97,18 @@ class TestAscendQuantConfig(TestBase):

     def test_get_quant_method_for_attention(self):
         attention_layer = MagicMock(spec=Attention)
-        with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
+        mock_config = MagicMock()
+        mock_config.model_config.hf_config.model_type = None
+        with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
+                patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \
                    return_value=MagicMock()) as mock_ascend_kvcache:
             # Test with fa_quant_type
             method = self.ascend_config.get_quant_method(
                 attention_layer, ".attn")
             self.assertIs(method, mock_ascend_kvcache.return_value)

-        with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
+        with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
+                patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \
                    return_value=MagicMock()) as mock_ascend_kvcache:
             # Test with kv_quant_type
             modified_config = {"kv_quant_type": "C8"}
@@ -113,9 +121,12 @@ class TestAscendQuantConfig(TestBase):
         fused_moe_layer = MagicMock(spec=FusedMoE)
-        fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
+        fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
+        mock_config = MagicMock()
+        mock_config.model_config.hf_config.model_type = None

         # Test skipped layer
         with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
+            patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
             patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
             method = self.ascend_config.get_quant_method(
                 fused_moe_layer, "moe_layer")
@@ -123,6 +134,7 @@ class TestAscendQuantConfig(TestBase):

         # Test quantized layer
         with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
+            patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
             patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
             method = self.ascend_config.get_quant_method(
                 fused_moe_layer, "moe_layer")
@@ -156,33 +168,22 @@ class TestAscendKVCacheMethod(TestBase):
     def setUp(self):
         # Setup common test fixtures
         self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
-        self.mock_quant_config.quant_description = {"some_config": "value"}
-        self.prefix = "attention_layer"
+        self.mock_quant_config.quant_description = {"kv_quant_type": "C8"}
+        self.prefix = "layer.attn"

-        # Mock the quantizer and quant_method
-        self.mock_quantizer = MagicMock()
+        # Mock quant_method
         self.mock_quant_method = MagicMock()

-        # Patch the AscendQuantizer
-        self.quantizer_patcher = patch(
-            'vllm_ascend.quantization.quant_config.AscendQuantizer.get_quantizer',
-            return_value=self.mock_quantizer)
-        self.mock_get_quantizer = self.quantizer_patcher.start()
-
-        self.mock_quantizer.build_attention_method.return_value = self.mock_quant_method
+        self.patcher = patch(
+            'vllm_ascend.quantization.quant_config.get_quant_method')
+        self.mock_get_quant_method = self.patcher.start()
+        self.mock_get_quant_method.return_value = self.mock_quant_method

         # Create instance
         self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
                                                    self.prefix)

     def tearDown(self):
-        self.quantizer_patcher.stop()
-
-    def test_init(self):
-        """Test initialization with proper quantizer setup."""
-        self.mock_get_quantizer.assert_called_once_with(
-            self.mock_quant_config.quant_description, self.prefix)
-        self.mock_quantizer.build_attention_method.assert_called_once()
+        self.patcher.stop()

     def test_create_weights(self):
         """Test create_weights delegates to quant_method."""
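The reworked fixture reflects a simpler construction path: AscendKVCacheMethod no longer goes through AscendQuantizer.get_quantizer(...).build_attention_method() but asks a single lookup helper for its implementation. A hedged sketch of the delegation the new setUp implies (the signatures are assumptions, not the verified vllm_ascend code):

# Sketch only: inferred from the patched call sites in the tests above.
from vllm_ascend.quantization.utils import get_quant_method


class AscendKVCacheMethod:

    def __init__(self, quant_config, prefix: str) -> None:
        # get_quant_method is patched in setUp; the tests assert that the
        # object it returns receives all later delegated calls.
        self.quant_method = get_quant_method(quant_config.quant_description,
                                             prefix, "attention")

    def create_weights(self, layer) -> None:
        # test_create_weights expects a straight pass-through.
        self.quant_method.create_weights(layer)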
@@ -1,145 +0,0 @@
from unittest.mock import MagicMock, patch

from tests.ut.base import TestBase
from vllm_ascend.quantization.quant_config import AscendQuantConfig
from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer,
                                                W4A8DYNAMICQuantizer,
                                                W8A8Quantizer)

SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"}


class TestGetQuantizer(TestBase):

    def setUp(self):
        # Setup common test fixtures
        self.supported_types = {
            'INT8': MagicMock(_instance=None),
            'FP16': MagicMock(_instance=None),
            'C8': MagicMock(_instance=None)
        }
        self.original_supported_types = SUPPORT_ASCEND_QUANTIZER_TYPE.copy()
        SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types)
        self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
        self.mock_quant_config.quant_description = {"some_config": "value"}

    def tearDown(self):
        # Restore original supported types
        SUPPORT_ASCEND_QUANTIZER_TYPE.clear()
        SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.original_supported_types)

    def test_get_quantizer_fa(self):
        """Test successful quantizer retrieval for different cases."""
        # Setup
        quant_description = {'fa_quant_type': 'C8'}
        prefix = '.attn'
        expected_type = 'C8'
        with patch.dict(
                'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
                SUPPORT_ASCEND_QUANTIZER_TYPE):

            result = VLLMAscendQuantizer.get_quantizer(
                quant_description,
                prefix,
                packed_modules_mapping={"some": "mapping"})

        # Verify
        self.assertIsNotNone(result)
        self.assertEqual(result,
                         self.supported_types[expected_type]._instance)
        self.supported_types[expected_type].assert_called_once_with(
            quant_description)

    def test_get_quantizer_kv(self):
        """Test successful quantizer retrieval for different cases."""
        # Setup
        quant_description = {'kv_quant_type': 'C8'}
        prefix = '.attn'
        expected_type = 'C8'
        with patch.dict(
                'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
                SUPPORT_ASCEND_QUANTIZER_TYPE):

            result = VLLMAscendQuantizer.get_quantizer(
                quant_description,
                prefix,
                packed_modules_mapping={"some": "mapping"})

        # Verify
        self.assertIsNotNone(result)
        self.assertEqual(result,
                         self.supported_types[expected_type]._instance)
        self.supported_types[expected_type].assert_called_once_with(
            quant_description)

    def test_get_quantizer_linear(self):
        """Test successful quantizer retrieval for different cases."""
        # Setup
        quant_description = {'linear_type': 'INT8'}
        prefix = 'nothing'
        expected_type = 'INT8'
        with patch('vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type',
                   return_value=expected_type), \
             patch.dict('vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', SUPPORT_ASCEND_QUANTIZER_TYPE):

            result = VLLMAscendQuantizer.get_quantizer(
                quant_description,
                prefix,
                packed_modules_mapping={"some": "mapping"})

        # Verify
        self.assertIsNotNone(result)
        self.assertEqual(result,
                         self.supported_types[expected_type]._instance)
        self.supported_types[expected_type].assert_called_once_with(
            quant_description)


class TestW8A8Quantizer(TestBase):

    def setUp(self):
        self.quantizer = W8A8Quantizer(quant_description={})

    def test_build_linear_method(self):
        with patch('vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod',
                   return_value=MagicMock()) as mock_linear:
            result = self.quantizer.build_linear_method()
            mock_linear.assert_called_once_with()
            self.assertIsInstance(result, MagicMock)

    def test_build_moe_method(self):
        with patch(
                'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod',
                return_value=MagicMock()) as mock_linear:
            result = self.quantizer.build_moe_method()
            mock_linear.assert_called_once_with()
            self.assertIsInstance(result, MagicMock)

    def test_build_attention_method(self):
        with patch('vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod',
                   return_value=MagicMock()) as mock_linear:
            result = self.quantizer.build_attention_method()
            mock_linear.assert_called_once_with()
            self.assertIsInstance(result, MagicMock)


class TestW4A8DYNAMICQuantizer(TestBase):

    def setUp(self):
        self.quantizer = W4A8DYNAMICQuantizer(quant_description={})

    def test_build_linear_method(self):
        with patch(
                'vllm_ascend.quantization.quantizer.AscendW4A8DynamicLinearMethod',
                return_value=MagicMock()) as mock_linear:
            result = self.quantizer.build_linear_method()
            mock_linear.assert_called_once_with()
            self.assertIsInstance(result, MagicMock)

    def test_build_moe_method(self):
        with patch(
                'vllm_ascend.quantization.quantizer.AscendW4A8DynamicFusedMoEMethod',
                return_value=MagicMock()) as mock_fused_moe:
            result = self.quantizer.build_moe_method()
            mock_fused_moe.assert_called_once_with()
            self.assertIsInstance(result, MagicMock)
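The deleted tests above encode the lookup rules get_quantizer used to implement: an .attn prefix selects fa_quant_type or kv_quant_type, anything else goes through get_linear_quant_type, and each quantizer class is instantiated once and cached on _instance. A minimal sketch consistent with those assertions (reconstructed here, not copied from vllm_ascend):

# Hedged reconstruction from the deleted assertions; details may differ.
SUPPORT_ASCEND_QUANTIZER_TYPE = {}  # e.g. {'C8': W8A8Quantizer, ...}


class VLLMAscendQuantizer:
    _instance = None

    def __init__(self, quant_description):
        self.quant_description = quant_description

    @staticmethod
    def get_linear_quant_type(quant_description, prefix,
                              packed_modules_mapping):
        raise NotImplementedError  # resolved per layer in the real code

    @classmethod
    def get_quantizer(cls, quant_description, prefix,
                      packed_modules_mapping=None):
        if '.attn' in prefix and 'fa_quant_type' in quant_description:
            quant_type = quant_description['fa_quant_type']
        elif '.attn' in prefix and 'kv_quant_type' in quant_description:
            quant_type = quant_description['kv_quant_type']
        else:
            quant_type = cls.get_linear_quant_type(quant_description, prefix,
                                                   packed_modules_mapping)
        quantizer_cls = SUPPORT_ASCEND_QUANTIZER_TYPE[quant_type]
        # The tests assert one construction per class, cached on _instance.
        if getattr(quantizer_cls, '_instance', None) is None:
            quantizer_cls._instance = quantizer_cls(quant_description)
        return quantizer_cls._instance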
62  tests/ut/quantization/test_utils.py  Normal file
@@ -0,0 +1,62 @@
import types

from tests.ut.base import TestBase
from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP,
                                            get_quant_method)


class TestGetQuantMethod(TestBase):

    def setUp(self):
        self.original_quantization_method_map = ASCEND_QUANTIZATION_METHOD_MAP.copy(
        )
        for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
            for layer_type in layer_map.keys():
                ASCEND_QUANTIZATION_METHOD_MAP[quant_type][
                    layer_type] = types.new_class(f"{quant_type}_{layer_type}")

    def tearDown(self):
        # Restore original map
        ASCEND_QUANTIZATION_METHOD_MAP.clear()
        ASCEND_QUANTIZATION_METHOD_MAP.update(
            self.original_quantization_method_map)

    def test_linear_quant_methods(self):
        for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
            if "linear" in layer_map.keys():
                prefix = "linear_layer"
                cls = layer_map["linear"]
                method = get_quant_method({"linear_layer.weight": quant_type},
                                          prefix, "linear")
                self.assertIsInstance(method, cls)

    def test_moe_quant_methods(self):
        for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
            if "moe" in layer_map.keys():
                prefix = "layer"
                cls = layer_map["moe"]
                method = get_quant_method({"layer.weight": quant_type}, prefix,
                                          "moe")
                self.assertIsInstance(method, cls)

    def test_with_fa_quant_type(self):
        quant_description = {"fa_quant_type": "C8"}
        method = get_quant_method(quant_description, ".attn", "attention")
        self.assertIsInstance(
            method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])

    def test_with_kv_quant_type(self):
        quant_description = {"kv_quant_type": "C8"}
        method = get_quant_method(quant_description, ".attn", "attention")
        self.assertIsInstance(
            method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])

    def test_invalid_layer_type(self):
        quant_description = {"linear_layer.weight": "W8A8"}
        with self.assertRaises(NotImplementedError):
            get_quant_method(quant_description, "linear_layer", "unsupported")

    def test_invalid_quant_type(self):
        quant_description = {"linear_layer.weight": "UNKNOWN"}
        with self.assertRaises(NotImplementedError):
            get_quant_method(quant_description, "linear_layer", "linear")
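These tests constrain get_quant_method tightly enough to sketch it: attention layers resolve through fa_quant_type/kv_quant_type, other layers through the "<prefix>.weight" entry, and unknown quantization or layer types raise NotImplementedError. A minimal sketch consistent with the assertions (an assumed shape, not the verified vllm_ascend implementation):

# Hedged sketch; in vllm_ascend.quantization.utils the map associates a
# quant type with {layer type -> method class}, as the tests manipulate it.
ASCEND_QUANTIZATION_METHOD_MAP = {"C8": {"attention": object}}  # placeholder


def get_quant_method(quant_description, prefix, layer_type):
    if layer_type == "attention" and ("fa_quant_type" in quant_description
                                      or "kv_quant_type" in quant_description):
        quant_type = quant_description.get(
            "fa_quant_type", quant_description.get("kv_quant_type"))
    else:
        quant_type = quant_description.get(prefix + ".weight")
    layer_map = ASCEND_QUANTIZATION_METHOD_MAP.get(quant_type)
    if layer_map is None or layer_type not in layer_map:
        # Unknown quant type or unsupported layer type for that quant type.
        raise NotImplementedError(
            f"quant type {quant_type!r} is not supported for "
            f"{layer_type!r} layers")
    return layer_map[layer_type]()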
@@ -1,4 +1,3 @@
-import copy
 from unittest.mock import Mock, patch

 import torch
@@ -11,8 +10,19 @@ from vllm_ascend.quantization.w4a8_dynamic import (
 class TestAscendW4A8DynamicLinearMethod(TestBase):

     def setUp(self):
-        self.method = AscendW4A8DynamicLinearMethod()
-        self.method.group_size = 8
+        with patch(
+                'vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config'
+        ) as mock_get_current_vllm_config:
+            mock_vllm_config = Mock()
+            mock_vllm_config.quant_config = Mock(
+                quant_description={"group_size": 256})
+            mock_vllm_config.scheduler_config = Mock(
+                max_num_batched_tokens=2048,
+                max_model_len=2048,
+                enable_chunked_prefill=False)
+            mock_get_current_vllm_config.return_value = mock_vllm_config
+            self.method = AscendW4A8DynamicLinearMethod()
+            self.method.group_size = 8

     def test_get_weight(self):
         weight = self.method.get_weight(8, 32, torch.bfloat16)
@@ -37,18 +47,27 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
     output_size = 56
     group_size = 2

+    @patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config')
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
     @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
     @patch('torch.distributed.get_rank', return_value=0)
     def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
-              get_current_vllm_config):
+              get_current_vllm_config, mock_get_ascend_config):
+        # Mock ascend config
+        mock_ascend_config = Mock()
+        mock_ascend_config.dynamic_eplb = False
+        mock_get_ascend_config.return_value = mock_ascend_config
+
         mock_vllm_config = Mock()
         mock_vllm_config.quant_config = Mock(quant_description={
             "group_size": self.group_size,
             "version": "0.0.0"
         })
+        mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
         mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048,
                                                  max_model_len=2048,
                                                  enable_chunked_prefill=False)
         get_current_vllm_config.return_value = mock_vllm_config
         self.quant_method = AscendW4A8DynamicFusedMoEMethod()
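One detail worth noting in the new signature: stacked @patch decorators inject their mocks bottom-up, so the decorator nearest the function supplies the first mock argument, which is why the newly added get_ascend_config patch (the outermost decorator) lands last in the parameter list. A self-contained illustration of that ordering, using standard-library targets purely as stand-ins:

import os
from unittest.mock import patch


@patch('os.listdir')  # outermost decorator -> injected last
@patch('os.getpid')
@patch('os.getcwd')   # innermost decorator -> injected first
def show_order(mock_getcwd, mock_getpid, mock_listdir):
    # While patched, the attribute on the module is the injected mock.
    print(mock_getcwd is os.getcwd)   # True
    print(mock_listdir is os.listdir)  # True


show_order()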
@@ -75,19 +94,19 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
         # old quant version weight
         param_dict = self.quant_method.get_dynamic_quant_param(
             self.experts, self.input_size, self.output_size, torch.bfloat16)
-        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32)
         self.assertEqual(param_dict["w13_weight_scale"].shape,
                          (self.experts, 2 * self.input_size, 1))
         self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
-                         torch.bfloat16)
+                         torch.float32)
         self.assertEqual(param_dict["w13_weight_scale_second"].shape,
                          (self.experts, 2 * self.input_size,
                           self.output_size // self.group_size))
-        self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
+        self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
         self.assertEqual(param_dict["w2_weight_scale"].shape,
                          (self.experts, self.output_size, 1))
         self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
-                         torch.bfloat16)
+                         torch.float32)
         self.assertEqual(param_dict["w2_weight_scale_second"].shape,
                          (self.experts, self.output_size,
                           self.input_size // self.group_size))
@@ -99,40 +118,87 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
         self.assertEqual(
             param_dict["w2_scale_bias"].shape,
             (self.experts, self.output_size, 16 // self.quant_method.tp_size))
+        # per-channel weight
+        self.quant_method.is_per_channel_weight = True
+        param_dict = self.quant_method.get_dynamic_quant_param(
+            self.experts, self.input_size, self.output_size, torch.bfloat16)
+        pergroup_param = [
+            "w13_weight_scale_second", "w13_weight_offset_second",
+            "w2_weight_scale_second", "w2_weight_offset_second"
+        ]
+        is_contains = any(key in param_dict for key in pergroup_param)
+        self.assertFalse(is_contains)
+
+    def build_layer(self,
+                    is_new_quant_version=True,
+                    is_per_channel_weight=False):
+        layer = torch.nn.Module()
+        if is_new_quant_version:
+            layer.w13_weight = torch.nn.Parameter(torch.zeros(
+                (self.experts, self.input_size, self.output_size),
+                dtype=torch.int8),
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(torch.zeros(
+                (self.experts, self.output_size // 2, self.input_size),
+                dtype=torch.int8),
+                                                 requires_grad=False)
+            w13_scale_bias = torch.zeros(
+                (self.experts, 2 * self.input_size, 1), dtype=torch.float32)
+            layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
+                                                      requires_grad=False)
+            w2_scale_bias = torch.zeros((self.experts, self.output_size,
+                                         16 // self.quant_method.tp_size),
+                                        dtype=torch.float32)
+            layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
+                                                     requires_grad=False)
+        else:
+            layer.w13_weight = torch.nn.Parameter(torch.zeros(
+                (self.experts, 2 * self.input_size, self.output_size),
+                dtype=torch.int8),
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(torch.zeros(
+                (self.experts, self.output_size, self.input_size),
+                dtype=torch.int8),
+                                                 requires_grad=False)
+            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
+                (self.experts, 2 * self.input_size, 1), dtype=torch.float32),
+                                                        requires_grad=False)
+            layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
+                (self.experts, self.output_size, 1), dtype=torch.float32),
+                                                       requires_grad=False)
+            if not is_per_channel_weight:
+                layer.w13_weight_scale_second = torch.nn.Parameter(
+                    torch.ones((self.experts, 2 * self.input_size,
+                                self.output_size // self.group_size),
+                               dtype=torch.float32),
+                    requires_grad=False)
+                layer.w13_weight_offset_second = torch.nn.Parameter(
+                    torch.empty_like(layer.w13_weight_scale_second.data),
+                    requires_grad=False)
+                layer.w2_weight_scale_second = torch.nn.Parameter(
+                    torch.ones((self.experts, self.output_size,
+                                self.input_size // self.group_size),
+                               dtype=torch.float32),
+                    requires_grad=False)
+                layer.w2_weight_offset_second = torch.nn.Parameter(
+                    torch.empty_like(layer.w2_weight_scale_second.data),
+                    requires_grad=False)
+        return layer

+    @patch('torch_npu.npu_format_cast')
     @patch('torch_npu.npu_quantize')
     @patch('torch.Tensor.npu')
-    def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
-        # old quant version weight
-        layer = torch.nn.Module()
-        layer.w13_weight = torch.nn.Parameter(torch.zeros(
-            (self.experts, 2 * self.input_size, self.output_size),
-            dtype=torch.int8),
-                                              requires_grad=False)
-        layer.w2_weight = torch.nn.Parameter(torch.zeros(
-            (self.experts, self.output_size, self.input_size),
-            dtype=torch.int8),
-                                             requires_grad=False)
-        layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
-            (self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
-                                                    requires_grad=False)
-        layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
-            (self.experts, 2 * self.input_size,
-             self.output_size // self.group_size),
-            dtype=torch.bfloat16),
-                                                           requires_grad=False)
-        layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
-            (self.experts, self.output_size, 1), dtype=torch.bfloat16),
-                                                   requires_grad=False)
-        layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
-            (self.experts, self.output_size,
-             self.input_size // self.group_size),
-            dtype=torch.bfloat16),
-                                                          requires_grad=False)
-        new_layer = copy.deepcopy(layer)
-
+    def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize,
+                                           mock_npu_format_cast):
         mock_npu.return_value = torch.Tensor()
         mock_npu_quantize.return_value = torch.Tensor()
+
+        def func_by_args(weight, num_format):
+            return weight
+
+        mock_npu_format_cast.side_effect = func_by_args
+        # old quant version weight
+        layer = self.build_layer(is_new_quant_version=False)
         self.quant_method.process_weights_after_loading(layer)
         self.assertTrue(hasattr(layer, "w13_scale_bias"))
         self.assertEqual(layer.w13_scale_bias.data.shape,
@@ -144,23 +210,17 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
         self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
         # new quant version weight
         self.quant_method.new_quant_version = True
-        new_layer.w13_weight.data = torch.zeros(
-            (self.experts, self.input_size, self.output_size),
-            dtype=torch.int8)
-        new_layer.w2_weight.data = torch.zeros(
-            (self.experts, self.output_size // 2, self.input_size),
-            dtype=torch.int8)
-        w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
-                                     dtype=torch.float32)
-        new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
-                                                      requires_grad=False)
-        w2_scale_bias = torch.zeros(
-            (self.experts, self.output_size, 16 // self.quant_method.tp_size),
-            dtype=torch.float32)
-        new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
-                                                     requires_grad=False)
+        new_layer = self.build_layer(is_new_quant_version=True)
         self.quant_method.process_weights_after_loading(new_layer)
         self.assertEqual(new_layer.w13_scale_bias.data.shape,
                          (self.experts, 2 * self.input_size))
         self.assertEqual(new_layer.w2_scale_bias.data.shape,
                          (self.experts, self.output_size))
+        self.assertFalse(hasattr(new_layer, "w13_weight_scale_second"))
+        # per-channel weight
+        self.quant_method.is_per_channel_weight = True
+        per_channel_layer = self.build_layer(is_new_quant_version=True,
+                                             is_per_channel_weight=True)
+        self.quant_method.process_weights_after_loading(per_channel_layer)
+        self.assertEqual(new_layer.w13_scale_bias.data.shape,
+                         (self.experts, 2 * self.input_size))
||||
@@ -5,8 +5,8 @@ import torch

 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk,
-                                                     select_experts)
+from vllm_ascend.ops.moe.experts_selector import (_native_grouped_topk,
+                                                  select_experts)
 from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
                                            AscendW8A8FusedMoEMethod,
                                            AscendW8A8LinearMethod,
@@ -784,7 +784,7 @@ class TestSelectExperts(TestBase):
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.dtype, torch.int32)

-    @patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk')
+    @patch('vllm_ascend.ops.moe.experts_selector._native_grouped_topk')
     def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
         """Test grouped topk with expert score correction bias"""
         mock_grouped_topk.return_value = torch.ones(self.num_tokens,
69  tests/ut/quantization/test_w8a8_dynamic.py  Normal file
@@ -0,0 +1,69 @@
from unittest.mock import Mock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a8_dynamic import \
    AscendW8A8DynamicFusedMoEMethod


class TestAscendW8A8FusedMoEMethod(TestBase):
    num_experts = 8
    hidden_size = 128
    intermediate_size = 128

    @patch("torch.distributed.get_rank")
    @patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group")
    @patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_config")
    @patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group")
    def setUp(self, mock_get_ep_group, mock_get_ascend_config,
              mock_get_mc2_group, mock_get_rank):
        with patch(
                'vllm_ascend.quantization.w8a8_dynamic.get_current_vllm_config'
        ) as mock_get_current_vllm_config:
            mock_vllm_config = Mock()
            mock_vllm_config.quant_config = Mock(
                quant_description={"group_size": 256})
            mock_vllm_config.scheduler_config = Mock(
                max_num_batched_tokens=2048,
                max_model_len=2048,
                enable_chunked_prefill=False)
            mock_get_current_vllm_config.return_value = mock_vllm_config
            mock_ep_group = Mock()
            mock_get_ep_group.return_value = mock_ep_group
            mock_ascend_config = Mock()

            # Build a Mock with concrete attributes to stand in for
            # ascend_scheduler_config
            mock_ascend_scheduler_config = Mock()
            mock_ascend_scheduler_config.enabled = False
            mock_ascend_scheduler_config.max_num_batched_tokens = 1024
            mock_ascend_scheduler_config.max_model_len = 2048
            mock_ascend_config.ascend_scheduler_config = mock_ascend_scheduler_config

            mock_ascend_config.torchair_graph_config = Mock(enabled=False)
            mock_ascend_config.enable_chunked_prefill = False
            mock_get_ascend_config.return_value = mock_ascend_config
            mock_mc2_group = Mock(device_group=0)
            mock_get_mc2_group.return_value = mock_mc2_group
            mock_rank = Mock()
            mock_get_rank.return_value = mock_rank

            self.quant_method = AscendW8A8DynamicFusedMoEMethod()

    def test_get_weight(self):
        param_dict = self.quant_method.get_weight(self.num_experts,
                                                  self.intermediate_size,
                                                  self.hidden_size,
                                                  torch.bfloat16)
        self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
        self.assertEqual(
            param_dict["w13_weight"].shape,
            (self.num_experts, 2 * self.intermediate_size, self.hidden_size))

    def test_get_dynamic_quant_param(self):
        param_dict = self.quant_method.get_dynamic_quant_param(
            self.num_experts, self.intermediate_size, self.hidden_size,
            torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_scale"].shape,
                         (self.num_experts, 2 * self.intermediate_size, 1))
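The comment in setUp above about giving ascend_scheduler_config concrete attributes touches a real Mock pitfall: a bare Mock() returns a fresh child mock for any attribute access, so code that does arithmetic or comparisons on, say, max_num_batched_tokens would receive a mock instead of an int. Pinning real values, either via constructor kwargs or plain assignment, avoids that; a small illustration:

from unittest.mock import Mock

cfg = Mock()
print(type(cfg.max_num_batched_tokens))  # <class 'unittest.mock.Mock'>

cfg = Mock(max_num_batched_tokens=1024)  # kwargs pin real values
cfg.max_model_len = 2048                 # so does plain assignment
print(cfg.max_num_batched_tokens + cfg.max_model_len)  # 3072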