diff --git a/docs/source/user_guide/feature_guide/quantization.md b/docs/source/user_guide/feature_guide/quantization.md index 9e5f56c..5300ad5 100644 --- a/docs/source/user_guide/feature_guide/quantization.md +++ b/docs/source/user_guide/feature_guide/quantization.md @@ -108,18 +108,19 @@ Please convert DeepSeek series models using `br_release_MindStudio_8.1.RC2_TR5_2 ### 3. When converting deepseek series models with modelslim, what should you pay attention? -When using the weight generated by modelslim with the `--dynamic` parameter, if torchair graph mode is enabled, please modify the configuration file in the CANN package to prevent incorrect inference results. +When the mla portion of the weights uses `W8A8_DYNAMIC` quantization, if torchair graph mode is enabled, please modify the configuration file in the CANN package to prevent incorrect inference results. The operation steps are as follows: 1. Search in the CANN package directory used, for example: find /usr/local/Ascend/ -name fusion_config.json -2. Add `"AddRmsNormDynamicQuantFusionPass":"off",` to the fusion_config.json you find, the location is as follows: +2. 
Add `"AddRmsNormDynamicQuantFusionPass":"off",` and `"MultiAddRmsNormDynamicQuantFusionPass":"off",` to the fusion_config.json you find, the location is as follows: ```bash { "Switch":{ "GraphFusion":{ "AddRmsNormDynamicQuantFusionPass":"off", + "MultiAddRmsNormDynamicQuantFusionPass":"off", ``` diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index 72897e3..f3348d8 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -35,6 +35,11 @@ QWEN_DENSE_MODELS = [ "vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8" ] +DEEPSEEK_W4A8_MODELS = [ + "vllm-ascend/DeepSeek-V3-W4A8-Pruing", + "vllm-ascend/DeepSeek-V3.1-W4A8-puring" +] + def test_models_distributed_QwQ(): example_prompts = [ @@ -109,14 +114,15 @@ def test_models_distributed_Qwen3_W4A8DYNAMIC(): vllm_model.generate_greedy(example_prompts, max_tokens) +@pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS) @patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"}) -def test_models_distributed_DeepSeek_W4A8DYNAMIC(): +def test_models_distributed_DeepSeek_W4A8DYNAMIC(model): prompts = [ "Hello, my name is", ] max_tokens = 5 with VllmRunner( - snapshot_download("vllm-ascend/DeepSeek-V3-W4A8-Pruing"), + snapshot_download(model), dtype="auto", tensor_parallel_size=2, quantization="ascend", diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py index d12bbe1..a14702b 100644 --- a/tests/ut/quantization/test_w4a8_dynamic.py +++ b/tests/ut/quantization/test_w4a8_dynamic.py @@ -1,4 +1,3 @@ -import copy from unittest.mock import Mock, patch import torch @@ -95,19 +94,19 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): # old quant version weight param_dict = self.quant_method.get_dynamic_quant_param( self.experts, self.input_size, self.output_size, torch.bfloat16) - 
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32) self.assertEqual(param_dict["w13_weight_scale"].shape, (self.experts, 2 * self.input_size, 1)) self.assertEqual(param_dict["w13_weight_scale_second"].dtype, - torch.bfloat16) + torch.float32) self.assertEqual(param_dict["w13_weight_scale_second"].shape, (self.experts, 2 * self.input_size, self.output_size // self.group_size)) - self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32) self.assertEqual(param_dict["w2_weight_scale"].shape, (self.experts, self.output_size, 1)) self.assertEqual(param_dict["w2_weight_scale_second"].dtype, - torch.bfloat16) + torch.float32) self.assertEqual(param_dict["w2_weight_scale_second"].shape, (self.experts, self.output_size, self.input_size // self.group_size)) @@ -119,40 +118,87 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): self.assertEqual( param_dict["w2_scale_bias"].shape, (self.experts, self.output_size, 16 // self.quant_method.tp_size)) + # per-channel weight + self.quant_method.is_per_channel_weight = True + param_dict = self.quant_method.get_dynamic_quant_param( + self.experts, self.input_size, self.output_size, torch.bfloat16) + pergroup_param = [ + "w13_weight_scale_second", "w13_weight_offset_second", + "w2_weight_scale_second", "w2_weight_offset_second" + ] + is_contains = any(key in param_dict for key in pergroup_param) + self.assertFalse(is_contains) + def build_layer(self, + is_new_quant_version=True, + is_per_channel_weight=False): + layer = torch.nn.Module() + if is_new_quant_version: + layer.w13_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.input_size, self.output_size), + dtype=torch.int8), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.output_size // 2, self.input_size), + dtype=torch.int8), + 
requires_grad=False) + w13_scale_bias = torch.zeros( + (self.experts, 2 * self.input_size, 1), dtype=torch.float32) + layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias, + requires_grad=False) + w2_scale_bias = torch.zeros((self.experts, self.output_size, + 16 // self.quant_method.tp_size), + dtype=torch.float32) + layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias, + requires_grad=False) + else: + layer.w13_weight = torch.nn.Parameter(torch.zeros( + (self.experts, 2 * self.input_size, self.output_size), + dtype=torch.int8), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.output_size, self.input_size), + dtype=torch.int8), + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(torch.ones( + (self.experts, 2 * self.input_size, 1), dtype=torch.float32), + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(torch.ones( + (self.experts, self.output_size, 1), dtype=torch.float32), + requires_grad=False) + if not is_per_channel_weight: + layer.w13_weight_scale_second = torch.nn.Parameter( + torch.ones((self.experts, 2 * self.input_size, + self.output_size // self.group_size), + dtype=torch.float32), + requires_grad=False) + layer.w13_weight_offset_second = torch.nn.Parameter( + torch.empty_like(layer.w13_weight_scale_second.data), + requires_grad=False) + layer.w2_weight_scale_second = torch.nn.Parameter( + torch.ones((self.experts, self.output_size, + self.input_size // self.group_size), + dtype=torch.float32), + requires_grad=False) + layer.w2_weight_offset_second = torch.nn.Parameter( + torch.empty_like(layer.w2_weight_scale_second.data), + requires_grad=False) + return layer + + @patch('torch_npu.npu_format_cast') @patch('torch_npu.npu_quantize') @patch('torch.Tensor.npu') - def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize): - # old quant version weight - layer = torch.nn.Module() - layer.w13_weight = torch.nn.Parameter(torch.zeros( - (self.experts, 2 * 
self.input_size, self.output_size), - dtype=torch.int8), - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(torch.zeros( - (self.experts, self.output_size, self.input_size), - dtype=torch.int8), - requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(torch.ones( - (self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16), - requires_grad=False) - layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones( - (self.experts, 2 * self.input_size, - self.output_size // self.group_size), - dtype=torch.bfloat16), - requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter(torch.ones( - (self.experts, self.output_size, 1), dtype=torch.bfloat16), - requires_grad=False) - layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones( - (self.experts, self.output_size, - self.input_size // self.group_size), - dtype=torch.bfloat16), - requires_grad=False) - new_layer = copy.deepcopy(layer) - + def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize, + mock_npu_format_cast): mock_npu.return_value = torch.Tensor() mock_npu_quantize.return_value = torch.Tensor() + + def func_by_args(weight, num_format): + return weight + + mock_npu_format_cast.side_effect = func_by_args + # old quant version weight + layer = self.build_layer(is_new_quant_version=False) self.quant_method.process_weights_after_loading(layer) self.assertTrue(hasattr(layer, "w13_scale_bias")) self.assertEqual(layer.w13_scale_bias.data.shape, @@ -164,23 +210,17 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32) # new quant version weight self.quant_method.new_quant_version = True - new_layer.w13_weight.data = torch.zeros( - (self.experts, self.input_size, self.output_size), - dtype=torch.int8) - new_layer.w2_weight.data = torch.zeros( - (self.experts, self.output_size // 2, self.input_size), - dtype=torch.int8) - w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1), - 
dtype=torch.float32) - new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias, - requires_grad=False) - w2_scale_bias = torch.zeros( - (self.experts, self.output_size, 16 // self.quant_method.tp_size), - dtype=torch.float32) - new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias, - requires_grad=False) + new_layer = self.build_layer(is_new_quant_version=True) self.quant_method.process_weights_after_loading(new_layer) self.assertEqual(new_layer.w13_scale_bias.data.shape, (self.experts, 2 * self.input_size)) self.assertEqual(new_layer.w2_scale_bias.data.shape, (self.experts, self.output_size)) + self.assertFalse(hasattr(new_layer, "w13_weight_scale_second")) + # per-channel weight + self.quant_method.is_per_channel_weight = True + per_channel_layer = self.build_layer(is_new_quant_version=True, + is_per_channel_weight=True) + self.quant_method.process_weights_after_loading(per_channel_layer) + self.assertEqual(per_channel_layer.w13_scale_bias.data.shape, + (self.experts, 2 * self.input_size)) diff --git a/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py b/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py index cd94101..9fd3f29 100644 --- a/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py +++ b/tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py @@ -1,4 +1,3 @@ -import copy from unittest.mock import Mock, patch import torch @@ -85,19 +84,19 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): # old quant version weight param_dict = self.quant_method.get_dynamic_quant_param( self.experts, self.input_size, self.output_size, torch.bfloat16) - self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32) self.assertEqual(param_dict["w13_weight_scale"].shape, (self.experts, 2 * self.input_size, 1)) self.assertEqual(param_dict["w13_weight_scale_second"].dtype, - torch.bfloat16) + torch.float32)
self.assertEqual(param_dict["w13_weight_scale_second"].shape, (self.experts, 2 * self.input_size, self.output_size // self.group_size)) - self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32) self.assertEqual(param_dict["w2_weight_scale"].shape, (self.experts, self.output_size, 1)) self.assertEqual(param_dict["w2_weight_scale_second"].dtype, - torch.bfloat16) + torch.float32) self.assertEqual(param_dict["w2_weight_scale_second"].shape, (self.experts, self.output_size, self.input_size // self.group_size)) @@ -109,40 +108,80 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): self.assertEqual( param_dict["w2_scale_bias"].shape, (self.experts, self.output_size, 16 // self.quant_method.tp_size)) + # per-channel weight + self.quant_method.is_per_channel_weight = True + param_dict = self.quant_method.get_dynamic_quant_param( + self.experts, self.input_size, self.output_size, torch.bfloat16) + pergroup_param = [ + "w13_weight_scale_second", "w13_weight_offset_second", + "w2_weight_scale_second", "w2_weight_offset_second" + ] + is_contains = any(key in param_dict for key in pergroup_param) + self.assertFalse(is_contains) + + def build_layer(self, + is_new_quant_version=True, + is_per_channel_weight=False): + layer = torch.nn.Module() + if is_new_quant_version: + layer.w13_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.input_size, self.output_size), + dtype=torch.int8), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.output_size // 2, self.input_size), + dtype=torch.int8), + requires_grad=False) + w13_scale_bias = torch.zeros( + (self.experts, 2 * self.input_size, 1), dtype=torch.float32) + layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias, + requires_grad=False) + w2_scale_bias = torch.zeros((self.experts, self.output_size, + 16 // self.quant_method.tp_size), + dtype=torch.float32) + layer.w2_scale_bias = 
torch.nn.Parameter(w2_scale_bias, + requires_grad=False) + else: + layer.w13_weight = torch.nn.Parameter(torch.zeros( + (self.experts, 2 * self.input_size, self.output_size), + dtype=torch.int8), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(torch.zeros( + (self.experts, self.output_size, self.input_size), + dtype=torch.int8), + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(torch.ones( + (self.experts, 2 * self.input_size, 1), dtype=torch.float32), + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(torch.ones( + (self.experts, self.output_size, 1), dtype=torch.float32), + requires_grad=False) + if not is_per_channel_weight: + layer.w13_weight_scale_second = torch.nn.Parameter( + torch.ones((self.experts, 2 * self.input_size, + self.output_size // self.group_size), + dtype=torch.float32), + requires_grad=False) + layer.w13_weight_offset_second = torch.nn.Parameter( + torch.empty_like(layer.w13_weight_scale_second.data), + requires_grad=False) + layer.w2_weight_scale_second = torch.nn.Parameter( + torch.ones((self.experts, self.output_size, + self.input_size // self.group_size), + dtype=torch.float32), + requires_grad=False) + layer.w2_weight_offset_second = torch.nn.Parameter( + torch.empty_like(layer.w2_weight_scale_second.data), + requires_grad=False) + return layer @patch('torch_npu.npu_quantize') @patch('torch.Tensor.npu') def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize): - # old quant version weight - layer = torch.nn.Module() - layer.w13_weight = torch.nn.Parameter(torch.zeros( - (self.experts, 2 * self.input_size, self.output_size), - dtype=torch.int8), - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(torch.zeros( - (self.experts, self.output_size, self.input_size), - dtype=torch.int8), - requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(torch.ones( - (self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16), - requires_grad=False) - 
layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones( - (self.experts, 2 * self.input_size, - self.output_size // self.group_size), - dtype=torch.bfloat16), - requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter(torch.ones( - (self.experts, self.output_size, 1), dtype=torch.bfloat16), - requires_grad=False) - layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones( - (self.experts, self.output_size, - self.input_size // self.group_size), - dtype=torch.bfloat16), - requires_grad=False) - new_layer = copy.deepcopy(layer) - mock_npu.return_value = torch.Tensor() mock_npu_quantize.return_value = torch.Tensor() + # old quant version weight + layer = self.build_layer(is_new_quant_version=False) self.quant_method.process_weights_after_loading(layer) self.assertTrue(hasattr(layer, "w13_scale_bias")) self.assertEqual(layer.w13_scale_bias.data.shape, @@ -154,23 +193,17 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32) # new quant version weight self.quant_method.new_quant_version = True - new_layer.w13_weight.data = torch.zeros( - (self.experts, self.input_size, self.output_size), - dtype=torch.int8) - new_layer.w2_weight.data = torch.zeros( - (self.experts, self.output_size // 2, self.input_size), - dtype=torch.int8) - w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1), - dtype=torch.float32) - new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias, - requires_grad=False) - w2_scale_bias = torch.zeros( - (self.experts, self.output_size, 16 // self.quant_method.tp_size), - dtype=torch.float32) - new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias, - requires_grad=False) + new_layer = self.build_layer(is_new_quant_version=True) self.quant_method.process_weights_after_loading(new_layer) self.assertEqual(new_layer.w13_scale_bias.data.shape, (self.experts, 2 * self.input_size)) self.assertEqual(new_layer.w2_scale_bias.data.shape, (self.experts, 
self.output_size)) + self.assertFalse(hasattr(new_layer, "w13_weight_scale_second")) + # per-channel weight + self.quant_method.is_per_channel_weight = True + per_channel_layer = self.build_layer(is_new_quant_version=True, + is_per_channel_weight=True) + self.quant_method.process_weights_after_loading(per_channel_layer) + self.assertEqual(per_channel_layer.w13_scale_bias.data.shape, + (self.experts, 2 * self.input_size)) diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index 514bea7..b8bcc78 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -27,6 +27,7 @@ from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.moe.experts_selector import select_experts +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ class AscendW4A8DynamicLinearMethod: @@ -132,6 +133,8 @@ class AscendW4A8DynamicFusedMoEMethod: vllm_config = get_current_vllm_config() self.group_size = vllm_config.quant_config.quant_description.get( "group_size", 256) + # NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process + self.is_per_channel_weight = self.group_size == 0 quant_version = vllm_config.quant_config.quant_description.get( "version", "0") # NOTE: new quantize weights: 2 int4 pack into int8 @@ -182,44 +185,44 @@ class AscendW4A8DynamicFusedMoEMethod: num_experts, 2 * intermediate_size_per_partition, 1, - dtype=params_dtype) + dtype=torch.float32) param_dict["w13_weight_offset"] = torch.empty( num_experts, 2 * intermediate_size_per_partition, 1, - dtype=params_dtype) - - param_dict["w13_weight_scale_second"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=params_dtype) - - param_dict["w13_weight_offset_second"] = torch.empty( - num_experts, - 2 * 
intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=params_dtype) + dtype=torch.float32) param_dict["w2_weight_scale"] = torch.empty(num_experts, hidden_sizes, 1, - dtype=params_dtype) + dtype=torch.float32) param_dict["w2_weight_offset"] = torch.empty(num_experts, hidden_sizes, 1, - dtype=params_dtype) - param_dict["w2_weight_scale_second"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=params_dtype) - param_dict["w2_weight_offset_second"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=params_dtype) + dtype=torch.float32) + if not self.is_per_channel_weight: + param_dict["w13_weight_scale_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.float32) + param_dict["w13_weight_offset_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.float32) + + param_dict["w2_weight_scale_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32) + param_dict["w2_weight_offset_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32) if self.new_quant_version: param_dict["w13_scale_bias"] = torch.empty( @@ -288,8 +291,8 @@ class AscendW4A8DynamicFusedMoEMethod: hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale_second, - w2_scale=layer.w2_weight_scale_second, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, w1_scale_bias=layer.w13_scale_bias, w2_scale_bias=layer.w2_scale_bias, topk_weights=topk_weights, @@ -305,6 +308,14 @@ class AscendW4A8DynamicFusedMoEMethod: dynamic_eplb=self.dynamic_eplb) def process_scale(self, weight: torch.Tensor, scale, per_group_scale): + scale = 
scale.transpose(1, 2).contiguous() + if self.is_per_channel_weight: + scale_np = scale.cpu().numpy() + scale_np.dtype = np.uint32 + scale_uint64_tensor = torch.from_numpy(scale_np.astype( + np.int64)).npu() + return scale_uint64_tensor, None + per_group_scale = per_group_scale.transpose(1, 2).contiguous() group_num, k, n = weight.shape # the weight of the new version is reduced by half by pack n, so it needs to be restored if self.new_quant_version: @@ -347,13 +358,10 @@ class AscendW4A8DynamicFusedMoEMethod: def pack_to_int32(self, weight: torch.Tensor): if self.new_quant_version: - group_num, k, n = weight.shape - assert n % 4 == 0, "the last dim of weight needs to be divided by 4" - packed_n = n // 4 # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 - packed_weight = torch.from_numpy( - np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32)) - return packed_weight.reshape(group_num, k, packed_n).npu() + assert weight.shape[ + -1] % 4 == 0, "the last dim of weight needs to be divided by 4" + return weight.view(torch.int32).contiguous() else: return torch_npu.npu_quantize(weight.to(torch.float32), torch.tensor([1.]).npu(), None, @@ -365,23 +373,29 @@ class AscendW4A8DynamicFusedMoEMethod: 1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose( 1, 2).contiguous() - layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose( - 1, 2).contiguous() - layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose( - 1, 2).contiguous() - layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose( - 1, 2).contiguous() - layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose( - 1, 2).contiguous() - layer.w13_weight_scale_second.data, w13_bias = self.process_scale( + w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr( + layer, "w13_weight_scale_second") else None + w2_weight_scale_second = layer.w2_weight_scale_second.data if 
hasattr( + layer, "w2_weight_scale_second") else None + layer.w13_weight_scale.data, w13_bias = self.process_scale( layer.w13_weight, layer.w13_weight_scale.data, - layer.w13_weight_scale_second.data) - layer.w2_weight_scale_second.data, w2_bias = self.process_scale( + w13_weight_scale_second) + layer.w2_weight_scale.data, w2_bias = self.process_scale( layer.w2_weight, layer.w2_weight_scale.data, - layer.w2_weight_scale_second.data) + w2_weight_scale_second) + if hasattr(layer, "w13_weight_scale_second"): + # scale_second is no longer used, release this part of the memory + del layer.w13_weight_scale_second + del layer.w2_weight_scale_second + del layer.w13_weight_offset_second + del layer.w2_weight_offset_second self.update_bias(layer, w13_bias, w2_bias) + layer.w13_weight.data = torch_npu.npu_format_cast( + layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w2_weight.data = torch_npu.npu_format_cast( + layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data) layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data) diff --git a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py index f38e2d8..02deee8 100644 --- a/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w4a8_dynamic.py @@ -139,6 +139,8 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: vllm_config = get_current_vllm_config() self.group_size = vllm_config.quant_config.quant_description.get( "group_size", 256) + # NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process + self.is_per_channel_weight = self.group_size == 0 quant_version = vllm_config.quant_config.quant_description.get( "version", "0") # NOTE: new quantize weights: 2 int4 pack into int8 @@ -188,44 +190,45 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: num_experts, 2 * intermediate_size_per_partition, 1, - 
dtype=params_dtype) + dtype=torch.float32) param_dict["w13_weight_offset"] = torch.empty( num_experts, 2 * intermediate_size_per_partition, 1, - dtype=params_dtype) - - param_dict["w13_weight_scale_second"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=params_dtype) - - param_dict["w13_weight_offset_second"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=params_dtype) + dtype=torch.float32) param_dict["w2_weight_scale"] = torch.empty(num_experts, hidden_sizes, 1, - dtype=params_dtype) + dtype=torch.float32) param_dict["w2_weight_offset"] = torch.empty(num_experts, hidden_sizes, 1, - dtype=params_dtype) - param_dict["w2_weight_scale_second"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=params_dtype) - param_dict["w2_weight_offset_second"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=params_dtype) + dtype=torch.float32) + + if not self.is_per_channel_weight: + param_dict["w13_weight_scale_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.float32) + param_dict["w13_weight_offset_second"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.float32) + + param_dict["w2_weight_scale_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32) + param_dict["w2_weight_offset_second"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32) if self.new_quant_version: param_dict["w13_scale_bias"] = torch.empty( @@ -318,8 +321,8 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - 
w1_scale=layer.w13_weight_scale_second, - w2_scale=layer.w2_weight_scale_second, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, w1_scale_bias=layer.w13_scale_bias, w2_scale_bias=layer.w2_scale_bias, topk_weights=topk_weights, @@ -343,8 +346,8 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale_second, - w2_scale=layer.w2_weight_scale_second, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, w1_scale_bias=layer.w13_scale_bias, w2_scale_bias=layer.w2_scale_bias, topk_weights=topk_weights, @@ -357,6 +360,14 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: ) def process_scale(self, weight: torch.Tensor, scale, per_group_scale): + scale = scale.transpose(1, 2).contiguous() + if self.is_per_channel_weight: + scale_np = scale.cpu().numpy() + scale_np.dtype = np.uint32 + scale_uint64_tensor = torch.from_numpy(scale_np.astype( + np.int64)).npu() + return scale_uint64_tensor, None + per_group_scale = per_group_scale.transpose(1, 2).contiguous() group_num, k, n = weight.shape # the weight of the new version is reduced by half by pack n, so it needs to be restored if self.new_quant_version: @@ -399,13 +410,10 @@ class TorchairAscendW4A8DynamicFusedMoEMethod: def pack_to_int32(self, weight: torch.Tensor): if self.new_quant_version: - group_num, k, n = weight.shape - assert n % 4 == 0, "the last dim of weight needs to be divided by 4" - packed_n = n // 4 # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 - packed_weight = torch.from_numpy( - np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32)) - return packed_weight.reshape(group_num, k, packed_n).npu() + assert weight.shape[ + -1] % 4 == 0, "the last dim of weight needs to be divided by 4" + return weight.view(torch.int32).contiguous() else: return torch_npu.npu_quantize(weight.to(torch.float32), torch.tensor([1.]).npu(), None, @@ -417,21 +425,22 @@ 
class TorchairAscendW4A8DynamicFusedMoEMethod: 1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose( 1, 2).contiguous() - layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose( - 1, 2).contiguous() - layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose( - 1, 2).contiguous() - layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose( - 1, 2).contiguous() - layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose( - 1, 2).contiguous() - - layer.w13_weight_scale_second.data, w13_bias = self.process_scale( + w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr( + layer, "w13_weight_scale_second") else None + w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr( + layer, "w2_weight_scale_second") else None + layer.w13_weight_scale.data, w13_bias = self.process_scale( layer.w13_weight, layer.w13_weight_scale.data, - layer.w13_weight_scale_second.data) - layer.w2_weight_scale_second.data, w2_bias = self.process_scale( + w13_weight_scale_second) + layer.w2_weight_scale.data, w2_bias = self.process_scale( layer.w2_weight, layer.w2_weight_scale.data, - layer.w2_weight_scale_second.data) + w2_weight_scale_second) + if hasattr(layer, "w13_weight_scale_second"): + # scale_second is no longer used, release this part of the memory + del layer.w13_weight_scale_second + del layer.w2_weight_scale_second + del layer.w13_weight_offset_second + del layer.w2_weight_offset_second self.update_bias(layer, w13_bias, w2_bias)