[Feat][quantization] Support new version w4a8 dynamic quantization for Linear layers (#3311)

### What this PR does / why we need it?
**Problem Description:**

The existing implementation of the w4a8-dynamic linear method only
supports the old quantization format from msmodelslim. When attempting
to load models quantized with the new version, vLLM fails with errors
caused by mismatched tensor shapes and unprocessed quantization parameters.

Relevant issues:
- https://github.com/vllm-project/vllm-ascend/issues/3192
- https://github.com/vllm-project/vllm-ascend/issues/3152

**Proposed Changes:**
1. Add support for the new w4a8-dynamic format in
`AscendW4A8DynamicLinearMethod` and `TorchairAscendW4A8DynamicLinearMethod`
2. Add unit tests and e2e tests for both the new- and old-format
w4a8-dynamic models
<details>
<summary><b>details</b></summary>

1.  **Support for the new w4a8-dynamic format:**
* Detects the quantization format by reading the "version" field in
`quant_description`, preserving backward compatibility with old-format
checkpoints.
* Handles the new pre-packed weight format (two `int4` values per
`int8` byte), whose output dimension is halved on disk. `_packed_dim`
and `_packed_factor` tell vLLM's weight loader how to unpack it.
* Supports the new `scale_bias` parameter, setting its shape based on
the layer type, as required by msmodelslim. For API consistency and
future use, the `layer_type` parameter was also added to the other
quantization methods.
* Updates the weight-processing logic: new-format weights are already
packed and are reinterpreted with `.view(torch.int32)`, while old-format
weights are packed with `npu_convert_weight_to_int4pack` (a small CPU
sketch after these details illustrates the reinterpretation).

2.  **New unit and E2E tests:**
* Added unit tests that verify the logic for both the old and new
formats.
* Split the distributed E2E test to confirm that both old and new format
models work correctly.

</details>
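
As an illustration of the last bullet above, here is a minimal CPU-only
sketch (not the vLLM-Ascend implementation; the shapes and variable names
are made up) of why a pre-packed int8 weight can simply be reinterpreted
as int32:

```python
import torch

# Hypothetical new-format checkpoint tensor: every int8 byte already stores
# two int4 values, so a logical [out, in] int4 weight is saved as
# [out // 2, in] int8 (the "halved dimension" mentioned above).
logical_out, in_features = 8, 16
packed_int8 = torch.randint(-128, 127, (logical_out // 2, in_features),
                            dtype=torch.int8)

# Reinterpreting 4 consecutive int8 bytes as one int32 groups 8 int4 values
# per element, which is why the last dim must be divisible by 4.
assert packed_int8.shape[-1] % 4 == 0
packed_int32 = packed_int8.view(torch.int32)
print(packed_int8.shape, "->", packed_int32.shape)  # [4, 16] -> [4, 4]
```
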
Theoretically, these changes provide support for all common new-version
w4a8 (dynamic) models from msmodelslim.
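
Whether a checkpoint is treated as new-version is decided by its
quantization description. A minimal illustration follows; the field names
and values mirror the diff below, while the dict literal itself is only
hypothetical:

```python
# Hypothetical quant_description as emitted by msmodelslim; only the two
# keys that this PR reads are shown here.
quant_description = {"group_size": 256, "version": "1.0.0"}

group_size = quant_description.get("group_size", 256)
# New-format checkpoints carry version "1.0.0"; anything else (or a missing
# field) falls back to the old w4a8-dynamic path.
new_quant_version = quant_description.get("version", "0") == "1.0.0"
print(group_size, new_quant_version)  # 256 True
```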

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
I implemented relevant unit tests and e2e tests and verified the changes
with the following commands:
```bash
# unit tests
python -m pytest tests/ut/quantization/test_w4a8_dynamic.py tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py -v

# e2e tests
pytest tests/e2e/singlecard/test_quantization.py -v -s

pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_new_version -v -s
pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC_old_version -v -s
pytest tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC -v -s

```

I also tested Hunyuan-1.8B-Instruct quantized with the new w4a8-dynamic
format:
```bash
vllm serve ./models/Hunyuan-1.8B-Instruct-quantized --gpu-memory-utilization 0.96 --quantization ascend --max-model-len 9600 --seed 0 --max-num-batched-tokens 16384 
```

All tests mentioned passed locally.

**NOTE: I use a quantized model from my own repo in
test_offline_inference_distributed.py.** The model card, including the
quantization steps, is here:
[Anionex/Qwen3-1.7B-W4A8-V1](https://modelscope.cn/models/Anionex/Qwen3-1.7B-W4A8-V1/summary).
It should be replaced by a model in the vllm-ascend CI ModelScope repo.

Thanks for reading!


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Anionex <1005128408@qq.com>
Author: Anion
Date: 2025-10-21 20:18:39 +08:00 (committed by GitHub)
Parent: 11f9bccf6b
Commit: 5f8b1699ae
10 changed files with 433 additions and 75 deletions

```diff
@@ -36,18 +36,42 @@ class AscendW4A8DynamicLinearMethod:
     def __init__(self):
         self.transpose_weight = True
-        try:
-            self.group_size = get_current_vllm_config(
-            ).quant_config.quant_description.get("group_size", 256)
-        except AttributeError:
-            self.group_size = 256
+        vllm_config = get_current_vllm_config()
+        self.group_size = vllm_config.quant_config.quant_description.get(
+            "group_size", 256)
+        quant_version = vllm_config.quant_config.quant_description.get(
+            "version", "0")
+        self.new_quant_version = quant_version == "1.0.0"
+        from vllm.distributed import get_tensor_model_parallel_world_size
+        self.tp_size = get_tensor_model_parallel_world_size()
 
-    @staticmethod
-    def get_weight(input_size: int, output_size: int,
+    def get_weight(self, input_size: int, output_size: int,
                    params_dtype: torch.dtype) -> Dict[str, Any]:
-        params_dict = {
-            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
-        }
+        """Create weight parameters.
+
+        For new quantization version (double int4 pack into int8), the output dimension
+        is compressed by factor 2 (e.g., [2048, 3072] -> [1024, 3072]). The returned
+        dict includes "_packed_dim" and "_packed_factor" for vLLM's weight loader.
+        """
+        params_dict = {}
+        if self.new_quant_version:
+            # double int4 pack into int8: output dimension is compressed
+            pack_factor = 2
+            actual_output_size = output_size // pack_factor
+            params_dict["weight"] = torch.empty(actual_output_size,
+                                                input_size,
+                                                dtype=torch.int8)
+            # Add packing information for vLLM's weight_loader
+            params_dict["_packed_dim"] = 0
+            params_dict["_packed_factor"] = pack_factor
+        else:
+            params_dict["weight"] = torch.empty(output_size,
+                                                input_size,
+                                                dtype=torch.int8)
         return params_dict
 
     @staticmethod
@@ -59,8 +83,14 @@ class AscendW4A8DynamicLinearMethod:
                              params_dtype: torch.dtype) -> Dict[str, Any]:
         return {}
 
-    def get_pergroup_param(self, input_size: int, output_size: int,
-                           params_dtype: torch.dtype) -> Dict[str, Any]:
+    def get_pergroup_param(self,
+                           input_size: int,
+                           output_size: int,
+                           params_dtype: torch.dtype,
+                           layer_type: Optional[str] = None) -> Dict[str, Any]:
+        """
+        Create per-group quantization parameters.
+        """
         params_dict = {}
         params_dict["weight_scale"] = torch.empty(output_size,
                                                   1,
@@ -76,17 +106,52 @@ class AscendW4A8DynamicLinearMethod:
                                                          input_size //
                                                          self.group_size,
                                                          dtype=params_dtype)
+        # NOTE: In w4a8 quantization implementation,
+        # for down_proj and o_proj (layer_type == "row") scale_bias shape is [output_size, 16],
+        # others are [output_size, 1]
+        if self.new_quant_version:
+            scale_bias_dim = 16 if layer_type == "row" else 1
+            params_dict["scale_bias"] = torch.empty(output_size,
+                                                    scale_bias_dim,
+                                                    dtype=torch.float32)
         return params_dict
 
     @staticmethod
-    def process_scale_second(weight: torch.Tensor, scale: torch.Tensor,
-                             per_group_scale: torch.Tensor):
+    def process_scale_second(weight: torch.Tensor,
+                             scale: torch.Tensor,
+                             per_group_scale: torch.Tensor,
+                             is_new_quant: bool = False):
+        """
+        Process the scale for second-level quantization.
+
+        Args:
+            weight: weight tensor [k, n] (in new version, n is already compressed to n/2)
+            scale: first-level quantization scale [output_size]
+            per_group_scale: second-level per-group quantization scale [group_num, n_scale]
+            is_new_quant: whether it's the new quantization version (weight already compressed)
+        Returns:
+            (antiquant_scale, bias): dequantization scale and bias (bias=None for new version)
+        """
         k, n = weight.shape
-        group_num, n = per_group_scale.shape
-        weight_high = weight.to(torch.float32).reshape(
-            group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
-        weight_high = weight_high.reshape(k, n)
-        bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
+        group_num, n_scale = per_group_scale.shape
+        if is_new_quant:
+            # Restore logical dimension for compressed weight
+            n = n * 2
+            bias = None
+        if not is_new_quant:
+            weight_high = weight.to(torch.float32).reshape(
+                group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
+            weight_high = weight_high.reshape(k, n)
+            bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
+        # NOTE: scale_bias is not used currently
+        # because in msmodelslim w4a8 uses symmetric quantization
+        # TODO: support potential future asymmetric quantization
         antiquant_scale = (scale * per_group_scale).reshape(group_num, n)
         return antiquant_scale.npu(), bias
@@ -114,11 +179,34 @@ class AscendW4A8DynamicLinearMethod:
             layer.weight.data,
             layer.weight_scale.data,
             layer.weight_scale_second.data.transpose(0, 1).contiguous(),
+            is_new_quant=self.new_quant_version,
         )
-        param = torch.nn.Parameter(scale_bias, requires_grad=False)
-        layer.register_parameter("weight_scale_bias", param)
+        if self.new_quant_version:
+            # Process the loaded data based on layer type
+            if hasattr(layer, "scale_bias"):
+                if layer.scale_bias.data.shape[1] == 1:
+                    layer.scale_bias.data = layer.scale_bias.data.flatten()
+                else:
+                    layer.scale_bias.data = layer.scale_bias.data.contiguous()
+        else:
+            if scale_bias is not None:
+                param = torch.nn.Parameter(scale_bias, requires_grad=False)
+                layer.register_parameter("weight_scale_bias", param)
 
-        layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
-            layer.weight.data.to(torch.int32))
+        # Convert to NPU-specific int4pack format
+        if self.new_quant_version:
+            # weights on disk are already in packed int4 format
+            # pack 4 int8(int4*2) to int32
+            assert layer.weight.data.shape[-1] % 4 == 0, \
+                f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}"
+            layer.weight.data = layer.weight.data.view(
+                torch.int32).contiguous()
+        else:
+            # weights are not compressed
+            # need to be packed via npu_convert_weight_to_int4pack
+            layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
+                layer.weight.data.to(torch.int32))
 
 
 class AscendW4A8DynamicFusedMoEMethod:
```