[Feat] [310p] Support w8a8sc quantization method (#7075)
### What this PR does / why we need it?
New Quantization Method: Introduced support for the W8A8SC static linear
quantization scheme specifically for 310P hardware, enabling more
efficient model compression.
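Refactored save_sharded_state_310.py to avoid a multi-process issue: the original quant_model_description.json is now backed up once in main() before compression, instead of being renamed inside each compression worker.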
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
W8A8SC quant E2E test.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: pu-zhe <zpuaa@outlook.com>
This commit is contained in:
@@ -24,12 +24,15 @@ Sparse-Compress-Quantization state dict could also be saved via this script.
|
|||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
|
|
||||||
python save_sharded_state.py \
|
python save_sharded_state_310.py \
|
||||||
--model /path/to/load \
|
--model /path/to/load \
|
||||||
--tensor-parallel-size 8 \
|
--tensor-parallel-size 8 \
|
||||||
--output /path/to/save \
|
--output /path/to/save \
|
||||||
--enable-compress \
|
--enable-compress \
|
||||||
--compress-process-num 8
|
--compress-process-num 8 \
|
||||||
|
--enforce-eager \
|
||||||
|
--dtype float16 \
|
||||||
|
--quantization ascend
|
||||||
|
|
||||||
Then, the model can be loaded with
|
Then, the model can be loaded with
|
||||||
|
|
||||||
@@ -140,29 +143,30 @@ def get_quant_description(json_file: str) -> dict:
|
|||||||
return quant_desc
|
return quant_desc
|
||||||
|
|
||||||
|
|
||||||
def update_quant_description(json_file: str) -> None:
|
def update_quant_description(ori_json_file: str, target_json_file: str) -> None:
|
||||||
"""
|
"""
|
||||||
Update quantization types in JSON configuration file based on update mapping.
|
Update quantization types in JSON configuration file based on update mapping.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
json_file: Path to the JSON configuration file
|
ori_json_file: Path to the JSON configuration file
|
||||||
|
target_json_file: Path to the JSON configuration file to be saved
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: If the JSON file does not exist
|
FileNotFoundError: If the JSON file does not exist
|
||||||
RuntimeError: If JSON parsing fails or required keys are missing
|
RuntimeError: If JSON parsing fails or required keys are missing
|
||||||
"""
|
"""
|
||||||
config_path = Path(json_file)
|
config_path = Path(ori_json_file)
|
||||||
try:
|
try:
|
||||||
with config_path.open("r", encoding="utf-8") as file:
|
with config_path.open("r", encoding="utf-8") as file:
|
||||||
json_data = json.load(file)
|
json_data = json.load(file)
|
||||||
except (FileNotFoundError, json.JSONDecodeError) as e:
|
except (FileNotFoundError, json.JSONDecodeError) as e:
|
||||||
raise RuntimeError(f"Failed to read configuration file {json_file}: {e}")
|
raise RuntimeError(f"Failed to read configuration file {ori_json_file}: {e}")
|
||||||
|
|
||||||
original_quant_type = json_data.get("model_quant_type")
|
original_quant_type = json_data.get("model_quant_type")
|
||||||
if not original_quant_type or original_quant_type not in QUANTIZATION_UPDATE_MAP:
|
if not original_quant_type or original_quant_type not in QUANTIZATION_UPDATE_MAP:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot update quantization type. "
|
f"Cannot update quantization type. "
|
||||||
f"Original type '{original_quant_type}' not found or not supported for update in {json_file}."
|
f"Original type '{original_quant_type}' not found or not supported for update in {ori_json_file}."
|
||||||
)
|
)
|
||||||
updated_quant_type = QUANTIZATION_UPDATE_MAP[original_quant_type]
|
updated_quant_type = QUANTIZATION_UPDATE_MAP[original_quant_type]
|
||||||
|
|
||||||
@@ -175,12 +179,12 @@ def update_quant_description(json_file: str) -> None:
|
|||||||
updated_config[key] = value
|
updated_config[key] = value
|
||||||
|
|
||||||
try:
|
try:
|
||||||
new_file_path = config_path.parent / "quant_model_description.json"
|
new_file_path = Path(target_json_file)
|
||||||
with new_file_path.open("w", encoding="utf-8") as file:
|
with new_file_path.open("w", encoding="utf-8") as file:
|
||||||
json.dump(updated_config, file, indent=2, ensure_ascii=False)
|
json.dump(updated_config, file, indent=2, ensure_ascii=False)
|
||||||
os.remove(json_file)
|
os.remove(ori_json_file)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
raise RuntimeError(f"Failed to write updated configuration to {json_file}: {e}")
|
raise RuntimeError(f"Failed to write updated configuration to {target_json_file}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def weight_compress_worker(file_path: str, quant_desc: dict, process_num: int) -> bool:
|
def weight_compress_worker(file_path: str, quant_desc: dict, process_num: int) -> bool:
|
||||||
@@ -214,9 +218,6 @@ def weight_compress_worker(file_path: str, quant_desc: dict, process_num: int) -
|
|||||||
compressor.run()
|
compressor.run()
|
||||||
if p.exists():
|
if p.exists():
|
||||||
os.remove(p)
|
os.remove(p)
|
||||||
ori_quant_desc_file = p.parent / "quant_model_description.json"
|
|
||||||
if ori_quant_desc_file.exists():
|
|
||||||
os.rename(str(ori_quant_desc_file), str(ori_quant_desc_file.parent / "ori_quant_model_description.json"))
|
|
||||||
compressor.export_safetensors(str(p.parent), safetensors_name=p.name)
|
compressor.export_safetensors(str(p.parent), safetensors_name=p.name)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -248,6 +249,10 @@ def main(args):
|
|||||||
# 4. Compression Logic
|
# 4. Compression Logic
|
||||||
parameters_map_fpath = output_dir / "parameters_type_map.json"
|
parameters_map_fpath = output_dir / "parameters_type_map.json"
|
||||||
if args.enable_compress:
|
if args.enable_compress:
|
||||||
|
quant_desc_file = output_dir / "quant_model_description.json"
|
||||||
|
backup_quant_desc_file = output_dir / "ori_quant_model_description.json"
|
||||||
|
if quant_desc_file.exists():
|
||||||
|
os.rename(str(quant_desc_file), str(backup_quant_desc_file))
|
||||||
quant_desc = get_quant_description(str(parameters_map_fpath))
|
quant_desc = get_quant_description(str(parameters_map_fpath))
|
||||||
quant_type = quant_desc["model_quant_type"]
|
quant_type = quant_desc["model_quant_type"]
|
||||||
if quant_type in SUPPORTED_COMPRESS_QUANT_TYPE:
|
if quant_type in SUPPORTED_COMPRESS_QUANT_TYPE:
|
||||||
@@ -269,7 +274,7 @@ def main(args):
|
|||||||
for p in tasks:
|
for p in tasks:
|
||||||
p.join()
|
p.join()
|
||||||
|
|
||||||
update_quant_description(os.path.join(args.output, "ori_quant_model_description.json"))
|
update_quant_description(str(backup_quant_desc_file), str(quant_desc_file))
|
||||||
print("Compression completed successfully.")
|
print("Compression completed successfully.")
|
||||||
else:
|
else:
|
||||||
print(f"Skipping compression: Unsupported type {quant_type}")
|
print(f"Skipping compression: Unsupported type {quant_type}")
|
||||||
|
|||||||
122
tests/ut/_310p/quantization/test_w8a8sc_310.py
Normal file
122
tests/ut/_310p/quantization/test_w8a8sc_310.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.ut.base import TestBase
|
||||||
|
from vllm_ascend._310p.quantization.methods.w8a8sc import AscendW8A8SCLinearMethod310
|
||||||
|
|
||||||
|
|
||||||
|
class TestAscendW8A8SCLinearMethod310(TestBase):
    """Unit tests for ``AscendW8A8SCLinearMethod310``.

    The parameter-factory tests only check dtypes and shapes of the
    returned tensors. The two ``apply`` tests are currently skipped and
    patch both the quantize op and the NPU matmul kernel, so they verify
    call wiring only, not numerical results.
    """

    def setUp(self):
        # Each test gets its own scheme instance.
        self.method = AscendW8A8SCLinearMethod310()

    def test_get_weight_310(self):
        """``get_weight`` returns flat weight/index/info tensors."""
        tensors = self.method.get_weight(10, 20)

        self.assertEqual(tensors["weight"].dtype, torch.int8)
        self.assertEqual(tensors["weight"].shape, (10 * 20, ))

        self.assertEqual(tensors["index"].dtype, torch.int8)
        # Same formula the scheme uses for the index length.
        expected_index_len = math.ceil(10 / 256) * math.ceil(20 / 128) * 8
        self.assertEqual(tensors["index"].shape, (expected_index_len, ))

        self.assertEqual(tensors["info"].dtype, torch.int64)
        self.assertEqual(tensors["info"].shape, (5, ))

    def test_get_pertensor_param_310(self):
        """Per-tensor params are single-element scale and offset tensors."""
        pertensor = self.method.get_pertensor_param(torch.float16)

        self.assertEqual(pertensor["input_scale"].dtype, torch.float16)
        self.assertEqual(pertensor["input_offset"].dtype, torch.int8)
        self.assertEqual(pertensor["input_scale"].shape, (1, ))
        self.assertEqual(pertensor["input_offset"].shape, (1, ))

    def test_get_perchannel_param_310(self):
        """Per-channel params have one entry per output channel."""
        perchannel = self.method.get_perchannel_param(10, torch.float16)

        self.assertEqual(perchannel["quant_bias"].dtype, torch.int32)
        self.assertEqual(perchannel["deq_scale"].dtype, torch.int64)
        self.assertEqual(perchannel["quant_bias"].shape, (10, ))
        self.assertEqual(perchannel["deq_scale"].shape, (10, ))

    @pytest.mark.skip(
        "Skip as npu_matmul_compress_dequant will be supported in PTA 26.0.0.")
    @patch("torch.ops.vllm.quantize")
    @patch("torch_npu.npu_matmul_compress_dequant")
    def test_apply_with_x_not_int8_310(self, mock_matmul_compress_dequant,
                                       mock_quantize):
        """A non-int8 input is quantized before the compressed matmul."""
        layer = MagicMock()
        layer.aclnn_input_scale = torch.randn(256)
        layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale
        layer.aclnn_input_offset = torch.randint(-128, 127, (256, ), dtype=torch.int8)
        layer.weight = torch.randint(-128, 127, (256 * 128, ), dtype=torch.int8)
        layer.index = torch.randint(-128, 127, (8, ), dtype=torch.int8)
        layer.deq_scale = torch.randn(128)
        layer.quant_bias = torch.randint(-128, 127, (256, ))
        layer.params_dtype = torch.float16

        x = torch.randn(32, 128)

        # Canned result of the (patched) quantize op.
        expect_x_output = torch.randint(-128, 127, x.shape, dtype=torch.int8)
        mock_quantize.return_value = expect_x_output

        # Canned result of the (patched) matmul kernel.
        expected_y_output = torch.randn(32, 256)
        mock_matmul_compress_dequant.return_value = expected_y_output

        output = self.method.apply(layer, x, tp_rank=0)

        mock_quantize.assert_called_with(
            x,
            layer.aclnn_input_scale,
            layer.aclnn_input_scale_reciprocal,
            layer.aclnn_input_offset,
        )
        mock_matmul_compress_dequant.assert_called_with(
            expect_x_output,
            layer.weight,
            layer.index,
            layer.quant_bias,
            layer.deq_scale,
        )
        self.assertTrue(torch.equal(output, expected_y_output))

    @pytest.mark.skip(
        "Skip as npu_matmul_compress_dequant will be supported in PTA 26.0.0.")
    @patch("torch.ops.vllm.quantize")
    @patch("torch_npu.npu_matmul_compress_dequant")
    def test_apply_with_x_is_int8_310(self, mock_matmul_compress_dequant,
                                      mock_quantize):
        """An int8 input goes straight to the compressed matmul."""
        layer = MagicMock()
        layer.aclnn_input_scale = torch.randn(256)
        layer.aclnn_input_offset = torch.randint(-128, 127, (256, ), dtype=torch.int8)
        layer.weight = torch.randint(-128, 127, (256 * 128, ), dtype=torch.int8)
        layer.index = torch.randint(-128, 127, (8, ), dtype=torch.int8)
        layer.deq_scale = torch.randn(128)
        layer.quant_bias = torch.randint(-128, 127, (256, ))
        layer.params_dtype = torch.float16

        x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)

        # Canned result of the (patched) matmul kernel.
        expected_y_output = torch.randn(32, 256)
        mock_matmul_compress_dequant.return_value = expected_y_output

        output = self.method.apply(layer, x, tp_rank=0)

        # The input is already int8, so the quantize op must be bypassed.
        mock_quantize.assert_not_called()
        mock_matmul_compress_dequant.assert_called_with(
            x,
            layer.weight,
            layer.index,
            layer.quant_bias,
            layer.deq_scale,
        )
        self.assertTrue(torch.equal(output, expected_y_output))
|
||||||
@@ -19,4 +19,5 @@ from . import (
|
|||||||
w8a8_dynamic, # noqa: F401
|
w8a8_dynamic, # noqa: F401
|
||||||
w8a8_static, # noqa: F401
|
w8a8_static, # noqa: F401
|
||||||
w8a8s, # noqa: F401
|
w8a8s, # noqa: F401
|
||||||
|
w8a8sc, # noqa: F401
|
||||||
)
|
)
|
||||||
|
|||||||
116
vllm_ascend/_310p/quantization/methods/w8a8sc.py
Normal file
116
vllm_ascend/_310p/quantization/methods/w8a8sc.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_npu
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_rank
|
||||||
|
|
||||||
|
from vllm_ascend.ops.linear import AscendRowParallelLinear
|
||||||
|
from vllm_ascend.quantization.methods.base import AscendLinearScheme
|
||||||
|
|
||||||
|
from .registry import register_scheme
|
||||||
|
|
||||||
|
|
||||||
|
@register_scheme("W8A8SC", "linear")
class AscendW8A8SCLinearMethod310(AscendLinearScheme):
    """310P-only W8A8SC static linear scheme.

    Notes:
        - This scheme is discovered via 310P local registry.
        - ``get_weight`` stores ``input_size`` on the instance and
          ``process_weights_after_loading`` reads it back, so
          ``get_weight`` has to run first on the same instance.
    """

    def get_weight(
        self,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.float16,
    ) -> dict[str, Any]:
        """
        Allocate the placeholder weight tensors for the W8A8SC scheme.

        Args:
            input_size: Size of the input dimension (k)
            output_size: Size of the output dimension (n)
            params_dtype: Data type for parameters, default is torch.float16.
                Not used by this method; every tensor has a fixed dtype.

        Returns:
            A dictionary containing:
            - "weight": The compressed weight tensor with shape [c], where c is
              greater than 0 and not larger than k * n. It is allocated here
              with the upper bound k * n.
            - "index": Compression index generated simultaneously with the
              compressed weights, with shape [x], where
              x = k_index * n_index * 8, k_index = ceil(k1 / tilingK),
              n_index = ceil(n1 / tilingN), k1 = k / 32, n1 = n / 16
            - "info": Compression information with length 5, containing the
              compression block information tilingN, tilingK, the original
              shape of the pre-compression x2 matrix, and an identifier for
              the compression block traversal direction
        """
        # Kept for process_weights_after_loading, which expands the per-tensor
        # scale/offset to this length.
        self.input_size = input_size

        # NOTE(review): 256 and 128 presumably equal 32 * tilingK and
        # 16 * tilingN with both tilings set to 8 - confirm against the kernel.
        k_index = math.ceil(input_size / 256)
        n_index = math.ceil(output_size / 128)
        index_len = k_index * n_index * 8

        compressed_weight = torch.empty(input_size * output_size, dtype=torch.int8)
        compress_index = torch.empty(index_len, dtype=torch.int8)
        compress_info = torch.empty(5, dtype=torch.int64)
        return {
            "weight": compressed_weight,
            "index": compress_index,
            "info": compress_info,
        }

    def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
        """Allocate the single-element activation scale and offset tensors.

        Args:
            params_dtype: Data type used for "input_scale".

        Returns:
            A dictionary with "input_scale" (``params_dtype``) and
            "input_offset" (int8), each of shape [1].
        """
        input_scale = torch.empty(1, dtype=params_dtype)
        input_offset = torch.empty(1, dtype=torch.int8)
        return {
            "input_scale": input_scale,
            "input_offset": input_offset,
        }

    def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
        """Allocate the per-output-channel bias and dequantization scale.

        Args:
            output_size: Number of output channels.
            params_dtype: Not used by this method; both tensors have fixed
                dtypes.

        Returns:
            A dictionary with "quant_bias" (int32) and "deq_scale" (int64),
            each of shape [output_size].
        """
        quant_bias = torch.empty(output_size, dtype=torch.int32)
        deq_scale = torch.empty(output_size, dtype=torch.int64)
        return {
            "quant_bias": quant_bias,
            "deq_scale": deq_scale,
        }

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        tp_rank: int | None = 0,
    ) -> torch.Tensor:
        """Run the compressed, quantized matmul for ``layer`` on ``x``.

        Args:
            layer: Module carrying the tensors set up by weight loading and
                ``process_weights_after_loading``.
            x: Input activations; quantized first unless already int8.
            bias: Accepted for interface compatibility; not read here.
            tp_rank: Accepted for interface compatibility; not read here.

        Returns:
            The result of ``torch_npu.npu_matmul_compress_dequant``.
        """
        already_quantized = x.dtype == torch.int8
        if not already_quantized:
            x = torch.ops.vllm.quantize(
                x,
                layer.aclnn_input_scale,
                layer.aclnn_input_scale_reciprocal,
                layer.aclnn_input_offset,
            )

        output = torch_npu.npu_matmul_compress_dequant(
            x,
            layer.weight,
            layer.index,
            layer.quant_bias,
            layer.deq_scale,
        )
        return output

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare the loaded tensors on ``layer`` for ``apply``.

        Expands the per-tensor input scale and offset to ``self.input_size``
        elements, stores the reciprocal of the expanded scale, and adds a
        leading dimension to ``deq_scale`` (also converted to uint64) and to
        ``quant_bias``.
        """
        expanded_scale = layer.input_scale.data.repeat(self.input_size)
        layer.aclnn_input_scale = expanded_scale
        layer.aclnn_input_scale_reciprocal = 1.0 / expanded_scale.data

        expanded_offset = layer.input_offset.data.repeat(self.input_size)
        layer.aclnn_input_offset = expanded_offset.to(expanded_scale.dtype)

        layer.deq_scale.data = layer.deq_scale.data.unsqueeze(0).to(torch.uint64)
        layer.quant_bias.data = layer.quant_bias.data.unsqueeze(0)

        # Only apply bias on row_parallel_linear when tp_rank is 0.
        # torch_npu.npu_matmul_compress_dequant's quant_bias cannot be None,
        # so the other ranks get an all-zero bias instead of no bias.
        if isinstance(layer, AscendRowParallelLinear):
            if get_tensor_model_parallel_rank() != 0:
                layer.quant_bias.data = torch.zeros_like(layer.quant_bias)
|
||||||
Reference in New Issue
Block a user