[3/N][refactor] refactoer quantization (#2504)
### What this PR does / why we need it? Move torchair related qunatization section into torchair dir to make the code clear. Next step we'll remove all torchair related code outside of torchair quantization. ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? vLLM version: main vLLM main:ab9f2cfd19- vLLM version: v0.10.1.1 - vLLM main:959783fb99Signed-off-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
176
tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py
Normal file
176
tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import copy
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
|
||||
TorchairAscendW4A8DynamicFusedMoEMethod,
|
||||
TorchairAscendW4A8DynamicLinearMethod)
|
||||
|
||||
|
||||
class TestAscendW4A8DynamicLinearMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.method = TorchairAscendW4A8DynamicLinearMethod()
|
||||
self.method.group_size = 8
|
||||
|
||||
def test_get_weight(self):
|
||||
weight = self.method.get_weight(8, 32, torch.bfloat16)
|
||||
self.assertEqual(weight["weight"].dtype, torch.int8)
|
||||
self.assertEqual(weight["weight"].shape, (32, 8))
|
||||
|
||||
def test_get_pergroup_param(self):
|
||||
params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
|
||||
self.assertEqual(params["weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(params["weight_scale"].shape, (32, 1))
|
||||
self.assertEqual(params["weight_offset"].dtype, torch.bfloat16)
|
||||
self.assertEqual(params["weight_offset"].shape, (32, 1))
|
||||
self.assertEqual(params["weight_scale_second"].dtype, torch.bfloat16)
|
||||
self.assertEqual(params["weight_scale_second"].shape, (32, 1))
|
||||
self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
|
||||
self.assertEqual(params["weight_offset_second"].shape, (32, 1))
|
||||
|
||||
|
||||
class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
experts = 8
|
||||
input_size = 16
|
||||
output_size = 56
|
||||
group_size = 2
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config'
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group')
|
||||
@patch("vllm_ascend.ascend_config.get_ascend_config")
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group'
|
||||
)
|
||||
@patch('torch.distributed.get_rank', return_value=0)
|
||||
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
|
||||
mock_get_ep_group, get_current_vllm_config):
|
||||
mock_ascend_config = Mock()
|
||||
mock_ascend_config.torchair_graph_config = Mock(enabled=False)
|
||||
mock_get_ascend_config.return_value = mock_ascend_config
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.quant_config = Mock(quant_description={
|
||||
"group_size": self.group_size,
|
||||
"version": "0.0.0"
|
||||
})
|
||||
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
|
||||
get_current_vllm_config.return_value = mock_vllm_config
|
||||
self.quant_method = TorchairAscendW4A8DynamicFusedMoEMethod()
|
||||
|
||||
def test_get_weight(self):
|
||||
# old quant version w4a8 weight
|
||||
param_dict = self.quant_method.get_weight(self.experts,
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
|
||||
self.assertEqual(param_dict["w13_weight"].shape,
|
||||
(self.experts, 2 * self.input_size, self.output_size))
|
||||
# new quant version weight
|
||||
self.quant_method.new_quant_version = True
|
||||
param_dict = self.quant_method.get_weight(self.experts,
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
|
||||
self.assertEqual(param_dict["w13_weight"].shape,
|
||||
(self.experts, self.input_size, self.output_size))
|
||||
|
||||
def test_get_dynamic_quant_param(self):
|
||||
# old quant version weight
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.experts, self.input_size, self.output_size, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].shape,
|
||||
(self.experts, 2 * self.input_size, 1))
|
||||
self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
|
||||
torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale_second"].shape,
|
||||
(self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size))
|
||||
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w2_weight_scale"].shape,
|
||||
(self.experts, self.output_size, 1))
|
||||
self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
|
||||
torch.bfloat16)
|
||||
self.assertEqual(param_dict["w2_weight_scale_second"].shape,
|
||||
(self.experts, self.output_size,
|
||||
self.input_size // self.group_size))
|
||||
# new quant version weight
|
||||
self.quant_method.new_quant_version = True
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.experts, self.input_size, self.output_size, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32)
|
||||
self.assertEqual(
|
||||
param_dict["w2_scale_bias"].shape,
|
||||
(self.experts, self.output_size, 16 // self.quant_method.tp_size))
|
||||
|
||||
@patch('torch_npu.npu_quantize')
|
||||
@patch('torch.Tensor.npu')
|
||||
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
|
||||
# old quant version weight
|
||||
layer = torch.nn.Module()
|
||||
layer.w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, 2 * self.input_size, self.output_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.output_size, self.input_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size),
|
||||
dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size, 1), dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size,
|
||||
self.input_size // self.group_size),
|
||||
dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
new_layer = copy.deepcopy(layer)
|
||||
|
||||
mock_npu.return_value = torch.Tensor()
|
||||
mock_npu_quantize.return_value = torch.Tensor()
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
self.assertTrue(hasattr(layer, "w13_scale_bias"))
|
||||
self.assertEqual(layer.w13_scale_bias.data.shape,
|
||||
(self.experts, 2 * self.input_size))
|
||||
self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
|
||||
self.assertTrue(hasattr(layer, "w2_scale_bias"))
|
||||
self.assertEqual(layer.w2_scale_bias.data.shape,
|
||||
(self.experts, self.output_size))
|
||||
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
|
||||
# new quant version weight
|
||||
self.quant_method.new_quant_version = True
|
||||
new_layer.w13_weight.data = torch.zeros(
|
||||
(self.experts, self.input_size, self.output_size),
|
||||
dtype=torch.int8)
|
||||
new_layer.w2_weight.data = torch.zeros(
|
||||
(self.experts, self.output_size // 2, self.input_size),
|
||||
dtype=torch.int8)
|
||||
w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
|
||||
dtype=torch.float32)
|
||||
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
|
||||
requires_grad=False)
|
||||
w2_scale_bias = torch.zeros(
|
||||
(self.experts, self.output_size, 16 // self.quant_method.tp_size),
|
||||
dtype=torch.float32)
|
||||
new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
|
||||
requires_grad=False)
|
||||
self.quant_method.process_weights_after_loading(new_layer)
|
||||
self.assertEqual(new_layer.w13_scale_bias.data.shape,
|
||||
(self.experts, 2 * self.input_size))
|
||||
self.assertEqual(new_layer.w2_scale_bias.data.shape,
|
||||
(self.experts, self.output_size))
|
||||
75
tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py
Normal file
75
tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
|
||||
torchair_fused_experts_with_all2all
|
||||
|
||||
|
||||
class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.hidden_size = 128
|
||||
self.num_tokens = 128
|
||||
self.placeholder = torch.randn(self.num_tokens,
|
||||
self.hidden_size,
|
||||
dtype=torch.bfloat16)
|
||||
|
||||
@patch("torch.distributed.all_to_all_single")
|
||||
@patch("torch_npu.npu_moe_re_routing")
|
||||
@patch("torch_npu.npu_grouped_matmul")
|
||||
@patch("torch_npu.npu_swiglu")
|
||||
@patch("torch_npu.npu_dynamic_quant")
|
||||
@patch("torch_npu.npu_moe_finalize_routing")
|
||||
@patch("torch_npu.npu_moe_init_routing")
|
||||
def test_torchair_fused_experts_with_all2all(
|
||||
self, mock_moe_init_routing, mock_moe_finalize_routing,
|
||||
mock_dynamic_quant, mock_swiglu, mock_grouped_matmul,
|
||||
mock_moe_re_routing, mock_all_to_all_single):
|
||||
|
||||
expert_map = MagicMock()
|
||||
ep_group = MagicMock()
|
||||
placeholder_int8 = torch.randint(0,
|
||||
100,
|
||||
(self.num_tokens, self.hidden_size),
|
||||
dtype=torch.int8)
|
||||
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
|
||||
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
||||
input)
|
||||
mock_moe_init_routing.return_value = (
|
||||
placeholder_int8,
|
||||
placeholder_ones,
|
||||
placeholder_ones,
|
||||
)
|
||||
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
|
||||
torch.randint(0,
|
||||
100,
|
||||
(self.num_tokens, ),
|
||||
dtype=torch.int32),
|
||||
self.placeholder)
|
||||
mock_grouped_matmul.return_value = self.placeholder
|
||||
mock_swiglu.return_value = self.placeholder
|
||||
mock_dynamic_quant.return_value = (
|
||||
placeholder_int8,
|
||||
torch.randn(self.num_tokens),
|
||||
)
|
||||
mock_moe_finalize_routing.return_value = self.placeholder
|
||||
|
||||
result = torchair_fused_experts_with_all2all(
|
||||
hidden_states=self.placeholder,
|
||||
w1=self.placeholder,
|
||||
w1_scale=self.placeholder,
|
||||
w2=self.placeholder,
|
||||
w2_scale=self.placeholder,
|
||||
topk_weights=self.placeholder,
|
||||
topk_ids=self.placeholder,
|
||||
top_k=8,
|
||||
expert_map=expert_map,
|
||||
ep_group=ep_group,
|
||||
log2phy=None,
|
||||
global_redundant_expert_num=256,
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
self.assertEqual(result.shape, (128, 128))
|
||||
Reference in New Issue
Block a user