[main][quantization] Support deepseek w4a8 per-channel quantization (#3011)

### What this PR does / why we need it?
1. Support DeepSeek w4a8 per-channel quantization.
2. Support converting weights to the NZ format in eager mode (see the illustrative sketch after this list).
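
For context, here is a minimal sketch of what these two changes mean in practice. This is illustrative only, not the actual vllm-ascend code: the helper names and the `ACL_FORMAT_FRACTAL_NZ = 29` constant are assumptions.

```python
import torch
import torch_npu

ACL_FORMAT_FRACTAL_NZ = 29  # Ascend NZ layout id (assumed value)


def dequant_per_channel(weight_int: torch.Tensor,
                        scale: torch.Tensor) -> torch.Tensor:
    # Per-channel w4a8: one fp32 scale per output channel, e.g. shape
    # (num_experts, out_channels, 1). There is no second per-group scale,
    # so no *_weight_scale_second / *_weight_offset_second params exist.
    return weight_int.to(torch.float32) * scale


def cast_weight_to_nz(weight: torch.Tensor) -> torch.Tensor:
    # Eager mode: cast an NPU-resident weight tensor into the NZ layout.
    return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)
```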
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
#### How to get weights using Modelslim

##### Installation steps

```shell
git clone https://gitcode.com/Ascend/msit.git
cd msit/msmodelslim
bash install.sh
```

##### Generate w4a8 per-channel weights

```shell
cd example/DeepSeek
```

Command reference: `msmodelslim/example/DeepSeek/README.md`
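
Once the weights are exported, a hypothetical offline-inference check might look like the following. The checkpoint path and parallel size are placeholders, and `quantization="ascend"` follows the vllm-ascend quantization guide:

```python
from vllm import LLM, SamplingParams

# Placeholder path to the w4a8 per-channel checkpoint exported above.
llm = LLM(model="/path/to/DeepSeek-w4a8-per-channel",
          quantization="ascend",  # Ascend quantization backend (assumed flag)
          tensor_parallel_size=16,
          trust_remote_code=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```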

- vLLM version: v0.10.2
- vLLM main: f225ea7dd9

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Author: Wang Kunpeng
Date: 2025-09-27 21:01:16 +08:00 (committed by GitHub)
Parent: e9359bd8fa
Commit: 859e861d92
6 changed files with 299 additions and 196 deletions


```diff
@@ -1,4 +1,3 @@
-import copy
 from unittest.mock import Mock, patch
 import torch
@@ -95,19 +94,19 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
         # old quant version weight
         param_dict = self.quant_method.get_dynamic_quant_param(
             self.experts, self.input_size, self.output_size, torch.bfloat16)
-        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
+        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32)
         self.assertEqual(param_dict["w13_weight_scale"].shape,
                          (self.experts, 2 * self.input_size, 1))
         self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
-                         torch.bfloat16)
+                         torch.float32)
         self.assertEqual(param_dict["w13_weight_scale_second"].shape,
                          (self.experts, 2 * self.input_size,
                           self.output_size // self.group_size))
-        self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
+        self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
         self.assertEqual(param_dict["w2_weight_scale"].shape,
                          (self.experts, self.output_size, 1))
         self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
-                         torch.bfloat16)
+                         torch.float32)
         self.assertEqual(param_dict["w2_weight_scale_second"].shape,
                          (self.experts, self.output_size,
                           self.input_size // self.group_size))
@@ -119,40 +118,87 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
         self.assertEqual(
             param_dict["w2_scale_bias"].shape,
             (self.experts, self.output_size, 16 // self.quant_method.tp_size))
+        # per-channel weight
+        self.quant_method.is_per_channel_weight = True
+        param_dict = self.quant_method.get_dynamic_quant_param(
+            self.experts, self.input_size, self.output_size, torch.bfloat16)
+        pergroup_param = [
+            "w13_weight_scale_second", "w13_weight_offset_second",
+            "w2_weight_scale_second", "w2_weight_offset_second"
+        ]
+        is_contains = any(key in param_dict for key in pergroup_param)
+        self.assertFalse(is_contains)
+
+    def build_layer(self,
+                    is_new_quant_version=True,
+                    is_per_channel_weight=False):
+        layer = torch.nn.Module()
+        if is_new_quant_version:
+            layer.w13_weight = torch.nn.Parameter(torch.zeros(
+                (self.experts, self.input_size, self.output_size),
+                dtype=torch.int8),
+                requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(torch.zeros(
+                (self.experts, self.output_size // 2, self.input_size),
+                dtype=torch.int8),
+                requires_grad=False)
+            w13_scale_bias = torch.zeros(
+                (self.experts, 2 * self.input_size, 1), dtype=torch.float32)
+            layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
+                                                      requires_grad=False)
+            w2_scale_bias = torch.zeros((self.experts, self.output_size,
+                                         16 // self.quant_method.tp_size),
+                                        dtype=torch.float32)
+            layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
+                                                     requires_grad=False)
+        else:
+            layer.w13_weight = torch.nn.Parameter(torch.zeros(
+                (self.experts, 2 * self.input_size, self.output_size),
+                dtype=torch.int8),
+                requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(torch.zeros(
+                (self.experts, self.output_size, self.input_size),
+                dtype=torch.int8),
+                requires_grad=False)
+            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
+                (self.experts, 2 * self.input_size, 1), dtype=torch.float32),
+                requires_grad=False)
+            layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
+                (self.experts, self.output_size, 1), dtype=torch.float32),
+                requires_grad=False)
+        if not is_per_channel_weight:
+            layer.w13_weight_scale_second = torch.nn.Parameter(
+                torch.ones((self.experts, 2 * self.input_size,
+                            self.output_size // self.group_size),
+                           dtype=torch.float32),
+                requires_grad=False)
+            layer.w13_weight_offset_second = torch.nn.Parameter(
+                torch.empty_like(layer.w13_weight_scale_second.data),
+                requires_grad=False)
+            layer.w2_weight_scale_second = torch.nn.Parameter(
+                torch.ones((self.experts, self.output_size,
+                            self.input_size // self.group_size),
+                           dtype=torch.float32),
+                requires_grad=False)
+            layer.w2_weight_offset_second = torch.nn.Parameter(
+                torch.empty_like(layer.w2_weight_scale_second.data),
+                requires_grad=False)
+        return layer
+
+    @patch('torch_npu.npu_format_cast')
     @patch('torch_npu.npu_quantize')
     @patch('torch.Tensor.npu')
-    def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
-        # old quant version weight
-        layer = torch.nn.Module()
-        layer.w13_weight = torch.nn.Parameter(torch.zeros(
-            (self.experts, 2 * self.input_size, self.output_size),
-            dtype=torch.int8),
-            requires_grad=False)
-        layer.w2_weight = torch.nn.Parameter(torch.zeros(
-            (self.experts, self.output_size, self.input_size),
-            dtype=torch.int8),
-            requires_grad=False)
-        layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
-            (self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
-            requires_grad=False)
-        layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
-            (self.experts, 2 * self.input_size,
-             self.output_size // self.group_size),
-            dtype=torch.bfloat16),
-            requires_grad=False)
-        layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
-            (self.experts, self.output_size, 1), dtype=torch.bfloat16),
-            requires_grad=False)
-        layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
-            (self.experts, self.output_size,
-             self.input_size // self.group_size),
-            dtype=torch.bfloat16),
-            requires_grad=False)
-        new_layer = copy.deepcopy(layer)
+    def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize,
+                                           mock_npu_format_cast):
+        mock_npu.return_value = torch.Tensor()
+        mock_npu_quantize.return_value = torch.Tensor()
+
+        def func_by_args(weight, num_format):
+            return weight
+
+        mock_npu_format_cast.side_effect = func_by_args
+        # old quant version weight
+        layer = self.build_layer(is_new_quant_version=False)
         self.quant_method.process_weights_after_loading(layer)
         self.assertTrue(hasattr(layer, "w13_scale_bias"))
         self.assertEqual(layer.w13_scale_bias.data.shape,
@@ -164,23 +210,17 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
         self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
         # new quant version weight
         self.quant_method.new_quant_version = True
-        new_layer.w13_weight.data = torch.zeros(
-            (self.experts, self.input_size, self.output_size),
-            dtype=torch.int8)
-        new_layer.w2_weight.data = torch.zeros(
-            (self.experts, self.output_size // 2, self.input_size),
-            dtype=torch.int8)
-        w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
-                                     dtype=torch.float32)
-        new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
-                                                      requires_grad=False)
-        w2_scale_bias = torch.zeros(
-            (self.experts, self.output_size, 16 // self.quant_method.tp_size),
-            dtype=torch.float32)
-        new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
-                                                     requires_grad=False)
+        new_layer = self.build_layer(is_new_quant_version=True)
         self.quant_method.process_weights_after_loading(new_layer)
         self.assertEqual(new_layer.w13_scale_bias.data.shape,
                          (self.experts, 2 * self.input_size))
         self.assertEqual(new_layer.w2_scale_bias.data.shape,
                          (self.experts, self.output_size))
+        self.assertFalse(hasattr(new_layer, "w13_weight_scale_second"))
+        # per-channel weight
+        self.quant_method.is_per_channel_weight = True
+        per_channel_layer = self.build_layer(is_new_quant_version=True,
+                                             is_per_channel_weight=True)
+        self.quant_method.process_weights_after_loading(per_channel_layer)
+        self.assertEqual(new_layer.w13_scale_bias.data.shape,
+                         (self.experts, 2 * self.input_size))
```