diff --git a/examples/save_sharded_state_310.py b/examples/save_sharded_state_310.py
index fb7acabe..0787ae17 100644
--- a/examples/save_sharded_state_310.py
+++ b/examples/save_sharded_state_310.py
@@ -24,12 +24,15 @@ Sparse-Compress-Quantization state dict could also be saved via this script.
 
 Example usage:
 
-python save_sharded_state.py \
+python save_sharded_state_310.py \
     --model /path/to/load \
     --tensor-parallel-size 8 \
     --output /path/to/save \
     --enable-compress \
-    --compress-process-num 8
+    --compress-process-num 8 \
+    --enforce-eager \
+    --dtype float16 \
+    --quantization ascend
 
 Then, the model can be loaded with
 
@@ -140,29 +143,30 @@ def get_quant_description(json_file: str) -> dict:
     return quant_desc
 
 
-def update_quant_description(json_file: str) -> None:
+def update_quant_description(ori_json_file: str, target_json_file: str) -> None:
     """
     Update quantization types in JSON configuration file based on update mapping.
 
     Args:
-        json_file: Path to the JSON configuration file
+        ori_json_file: Path to the original JSON configuration file
+        target_json_file: Path to the JSON configuration file to be saved
 
     Raises:
         FileNotFoundError: If the JSON file does not exist
        RuntimeError: If JSON parsing fails or required keys are missing
     """
-    config_path = Path(json_file)
+    config_path = Path(ori_json_file)
     try:
         with config_path.open("r", encoding="utf-8") as file:
             json_data = json.load(file)
     except (FileNotFoundError, json.JSONDecodeError) as e:
-        raise RuntimeError(f"Failed to read configuration file {json_file}: {e}")
+        raise RuntimeError(f"Failed to read configuration file {ori_json_file}: {e}")
 
     original_quant_type = json_data.get("model_quant_type")
     if not original_quant_type or original_quant_type not in QUANTIZATION_UPDATE_MAP:
         raise RuntimeError(
             f"Cannot update quantization type. "
-            f"Original type '{original_quant_type}' not found or not supported for update in {json_file}."
+            f"Original type '{original_quant_type}' not found or not supported for update in {ori_json_file}."
         )
 
     updated_quant_type = QUANTIZATION_UPDATE_MAP[original_quant_type]
@@ -175,12 +179,12 @@ def update_quant_description(json_file: str) -> None:
             updated_config[key] = value
 
     try:
-        new_file_path = config_path.parent / "quant_model_description.json"
+        new_file_path = Path(target_json_file)
         with new_file_path.open("w", encoding="utf-8") as file:
             json.dump(updated_config, file, indent=2, ensure_ascii=False)
-        os.remove(json_file)
+        os.remove(ori_json_file)
     except OSError as e:
-        raise RuntimeError(f"Failed to write updated configuration to {json_file}: {e}")
+        raise RuntimeError(f"Failed to write updated configuration to {target_json_file}: {e}")
 
 
 def weight_compress_worker(file_path: str, quant_desc: dict, process_num: int) -> bool:
@@ -214,9 +218,6 @@ def weight_compress_worker(file_path: str, quant_desc: dict, process_num: int) -> bool:
             compressor.run()
             if p.exists():
                 os.remove(p)
-            ori_quant_desc_file = p.parent / "quant_model_description.json"
-            if ori_quant_desc_file.exists():
-                os.rename(str(ori_quant_desc_file), str(ori_quant_desc_file.parent / "ori_quant_model_description.json"))
             compressor.export_safetensors(str(p.parent), safetensors_name=p.name)
         return True
     except Exception as e:
@@ -248,6 +249,10 @@ def main(args):
     # 4. Compression Logic
     parameters_map_fpath = output_dir / "parameters_type_map.json"
     if args.enable_compress:
+        quant_desc_file = output_dir / "quant_model_description.json"
+        backup_quant_desc_file = output_dir / "ori_quant_model_description.json"
+        if quant_desc_file.exists():
+            os.rename(str(quant_desc_file), str(backup_quant_desc_file))
         quant_desc = get_quant_description(str(parameters_map_fpath))
         quant_type = quant_desc["model_quant_type"]
         if quant_type in SUPPORTED_COMPRESS_QUANT_TYPE:
@@ -269,7 +274,7 @@ def main(args):
             for p in tasks:
                 p.join()
 
-            update_quant_description(os.path.join(args.output, "ori_quant_model_description.json"))
+            update_quant_description(str(backup_quant_desc_file), str(quant_desc_file))
             print("Compression completed successfully.")
         else:
             print(f"Skipping compression: Unsupported type {quant_type}")
diff --git a/tests/ut/_310p/quantization/test_w8a8sc_310.py b/tests/ut/_310p/quantization/test_w8a8sc_310.py
new file mode 100644
index 00000000..0db4b978
--- /dev/null
+++ b/tests/ut/_310p/quantization/test_w8a8sc_310.py
@@ -0,0 +1,122 @@
+#
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend._310p.quantization.methods.w8a8sc import AscendW8A8SCLinearMethod310
+
+
+class TestAscendW8A8SCLinearMethod310(TestBase):
+
+    def setUp(self):
+        self.method = AscendW8A8SCLinearMethod310()
+
+    def test_get_weight_310(self):
+        weight = self.method.get_weight(10, 20)
+        self.assertEqual(weight["weight"].dtype, torch.int8)
+        self.assertEqual(weight["weight"].shape, (10 * 20, ))
+        self.assertEqual(weight["index"].dtype, torch.int8)
+        index_len = math.ceil(10 / 256) * math.ceil(20 / 128) * 8
+        self.assertEqual(weight["index"].shape, (index_len, ))
+        self.assertEqual(weight["info"].dtype, torch.int64)
+        self.assertEqual(weight["info"].shape, (5, ))
+
+    def test_get_pertensor_param_310(self):
+        params = self.method.get_pertensor_param(torch.float16)
+        self.assertEqual(params["input_scale"].dtype, torch.float16)
+        self.assertEqual(params["input_offset"].dtype, torch.int8)
+        self.assertEqual(params["input_scale"].shape, (1, ))
+        self.assertEqual(params["input_offset"].shape, (1, ))
+
+    def test_get_perchannel_param_310(self):
+        params = self.method.get_perchannel_param(10, torch.float16)
+
+        self.assertEqual(params["quant_bias"].dtype, torch.int32)
+        self.assertEqual(params["deq_scale"].dtype, torch.int64)
+        self.assertEqual(params["quant_bias"].shape, (10, ))
+        self.assertEqual(params["deq_scale"].shape, (10, ))
+
+    @pytest.mark.skip(
+        "Skip as npu_matmul_compress_dequant will be supported in PTA 26.0.0.")
+    @patch("torch.ops.vllm.quantize")
+    @patch("torch_npu.npu_matmul_compress_dequant")
+    def test_apply_with_x_not_int8_310(self, mock_matmul_compress_dequant,
+                                       mock_quantize):
+        layer = MagicMock()
+        layer.aclnn_input_scale = torch.randn(256)
+        layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale
+        layer.aclnn_input_offset = torch.randint(-128,
+                                                 127, (256, ),
+                                                 dtype=torch.int8)
+        layer.weight = torch.randint(-128,
+                                     127, (256 * 128, ),
+                                     dtype=torch.int8)
+        layer.index = torch.randint(-128, 127, (8, ), dtype=torch.int8)
+        layer.deq_scale = torch.randn(128)
+        layer.quant_bias = torch.randint(-128, 127, (256, ))
+        layer.params_dtype = torch.float16
+
+        x = torch.randn(32, 128)
+        expect_x_output = torch.randint(-128, 127, x.shape, dtype=torch.int8)
+        mock_quantize.return_value = expect_x_output
+
+        expected_y_output = torch.randn(32, 256)
+        mock_matmul_compress_dequant.return_value = expected_y_output
+
+        output = self.method.apply(layer, x, tp_rank=0)
+
+        mock_quantize.assert_called_with(x, layer.aclnn_input_scale,
+                                         layer.aclnn_input_scale_reciprocal,
+                                         layer.aclnn_input_offset)
+        mock_matmul_compress_dequant.assert_called_with(
+            expect_x_output, layer.weight, layer.index, layer.quant_bias,
+            layer.deq_scale)
+        self.assertTrue(torch.equal(output, expected_y_output))
+
+    @pytest.mark.skip(
+        "Skip as npu_matmul_compress_dequant will be supported in PTA 26.0.0.")
+    @patch("torch.ops.vllm.quantize")
+    @patch("torch_npu.npu_matmul_compress_dequant")
+    def test_apply_with_x_is_int8_310(self, mock_matmul_compress_dequant,
+                                      mock_quantize):
+        layer = MagicMock()
+        layer.aclnn_input_scale = torch.randn(256)
+        layer.aclnn_input_offset = torch.randint(-128,
+                                                 127, (256, ),
+                                                 dtype=torch.int8)
+        layer.weight = torch.randint(-128,
+                                     127, (256 * 128, ),
+                                     dtype=torch.int8)
+        layer.index = torch.randint(-128, 127, (8, ), dtype=torch.int8)
+        layer.deq_scale = torch.randn(128)
+        layer.quant_bias = torch.randint(-128, 127, (256, ))
+        layer.params_dtype = torch.float16
+
+        x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
+
+        expected_y_output = torch.randn(32, 256)
+        mock_matmul_compress_dequant.return_value = expected_y_output
+
+        output = self.method.apply(layer, x, tp_rank=0)
+
+        mock_quantize.assert_not_called()
+        mock_matmul_compress_dequant.assert_called_with(
+            x, layer.weight, layer.index, layer.quant_bias, layer.deq_scale)
+        self.assertTrue(torch.equal(output, expected_y_output))
diff --git a/vllm_ascend/_310p/quantization/methods/__init__.py b/vllm_ascend/_310p/quantization/methods/__init__.py
index 4399207a..9bce13ff 100644
--- a/vllm_ascend/_310p/quantization/methods/__init__.py
+++ b/vllm_ascend/_310p/quantization/methods/__init__.py
@@ -19,4 +19,5 @@ from . import (
     w8a8_dynamic,  # noqa: F401
     w8a8_static,  # noqa: F401
     w8a8s,  # noqa: F401
+    w8a8sc,  # noqa: F401
 )
diff --git a/vllm_ascend/_310p/quantization/methods/w8a8sc.py b/vllm_ascend/_310p/quantization/methods/w8a8sc.py
new file mode 100644
index 00000000..76de861b
--- /dev/null
+++ b/vllm_ascend/_310p/quantization/methods/w8a8sc.py
@@ -0,0 +1,116 @@
+#
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+import math
+from typing import Any
+
+import torch
+import torch_npu
+from vllm.distributed import get_tensor_model_parallel_rank
+
+from vllm_ascend.ops.linear import AscendRowParallelLinear
+from vllm_ascend.quantization.methods.base import AscendLinearScheme
+
+from .registry import register_scheme
+
+
+@register_scheme("W8A8SC", "linear")
+class AscendW8A8SCLinearMethod310(AscendLinearScheme):
+    """310P-only W8A8SC static linear scheme.
+
+    Notes:
+        - This scheme is discovered via the 310P local registry.
+    """
+
+    def get_weight(
+        self,
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype = torch.float16,
+    ) -> dict[str, Any]:
+        """
+        Get the weight tensors for the W8A8SC quantization scheme.
+
+        Args:
+            input_size: Size of the input dimension (k)
+            output_size: Size of the output dimension (n)
+            params_dtype: Data type for parameters, default is torch.float16
+
+        Returns:
+            A dictionary containing:
+            - "weight": The compressed weight tensor with shape [c], where c is greater
+              than 0 and not larger than k * n
+            - "index": Compression index generated together with the compressed weights,
+              with shape [x], where x = k_index * n_index * 8, k_index = ceil(k1 / tilingK),
+              n_index = ceil(n1 / tilingN), k1 = k / 32, n1 = n / 16
+            - "info": Compression information of length 5, containing the compression
+              block sizes tilingN and tilingK, the original shape of the pre-compression
+              x2 matrix, and an identifier for the compression block traversal direction
+        """
+        self.input_size = input_size
+        index_len = math.ceil(input_size / 256) * math.ceil(output_size / 128) * 8
+        return {
+            "weight": torch.empty(input_size * output_size, dtype=torch.int8),
+            "index": torch.empty(index_len, dtype=torch.int8),
+            "info": torch.empty(5, dtype=torch.int64),
+        }
+
+    def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
+        return {
+            "input_scale": torch.empty(1, dtype=params_dtype),
+            "input_offset": torch.empty(1, dtype=torch.int8),
+        }
+
+    def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
+        return {
+            "quant_bias": torch.empty(output_size, dtype=torch.int32),
+            "deq_scale": torch.empty(output_size, dtype=torch.int64),
+        }
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+        tp_rank: int | None = 0,
+    ) -> torch.Tensor:
+        if x.dtype != torch.int8:
+            x = torch.ops.vllm.quantize(
+                x,
+                layer.aclnn_input_scale,
+                layer.aclnn_input_scale_reciprocal,
+                layer.aclnn_input_offset,
+            )
+
+        return torch_npu.npu_matmul_compress_dequant(
+            x,
+            layer.weight,
+            layer.index,
+            layer.quant_bias,
+            layer.deq_scale,
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.aclnn_input_scale = layer.input_scale.data.repeat(self.input_size)
+        layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale.data
+        layer.aclnn_input_offset = layer.input_offset.data.repeat(self.input_size).to(layer.aclnn_input_scale.dtype)
+        layer.deq_scale.data = layer.deq_scale.data.unsqueeze(0).to(torch.uint64)
+        layer.quant_bias.data = layer.quant_bias.data.unsqueeze(0)
+        # Only apply bias on row_parallel_linear when tp_rank is 0.
+        # torch_npu.npu_matmul_compress_dequant's quant_bias cannot be None.
+        if isinstance(layer, AscendRowParallelLinear) and get_tensor_model_parallel_rank() != 0:
+            layer.quant_bias.data = torch.zeros_like(layer.quant_bias)
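
Reviewer note: a minimal sketch of the description-file bookkeeping that main() now performs when --enable-compress is set, assuming a hypothetical output directory /path/to/save. The original quant_model_description.json is backed up to ori_quant_model_description.json before the compression workers run, and update_quant_description() later rewrites the quant types from the backup into the original filename and deletes the backup. The actual call is left commented out because it lives in the example script rather than an importable module.

import os
from pathlib import Path

# Hypothetical output directory; the real script uses args.output.
output_dir = Path("/path/to/save")

quant_desc_file = output_dir / "quant_model_description.json"
backup_quant_desc_file = output_dir / "ori_quant_model_description.json"

# 1. Back up the description written during sharded saving so the compressor
#    can export its own files without clobbering it.
if quant_desc_file.exists():
    os.rename(str(quant_desc_file), str(backup_quant_desc_file))

# 2. ... run weight_compress_worker() over each sharded safetensors file ...

# 3. Write the updated description back under the original filename; this call
#    also removes the backup file on success.
# update_quant_description(str(backup_quant_desc_file), str(quant_desc_file))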
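
And a small worked example of the compression-index sizing used by AscendW8A8SCLinearMethod310.get_weight(); the layer sizes below are made up purely for illustration.

import math

# Hypothetical layer sizes: k = input_size, n = output_size.
k, n = 4096, 11008

# get_weight() allocates ceil(k / 256) * ceil(n / 128) * 8 int8 index entries.
index_len = math.ceil(k / 256) * math.ceil(n / 128) * 8
print(index_len)  # 16 * 86 * 8 = 11008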