[main][quantization] Support deepseek w4a8 per-channel quantization (#3011)

### What this PR does / why we need it?
1. Support DeepSeek W4A8 per-channel quantization.
2. Eager mode now supports converting weights to the NZ format.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
#### How to get weights using Modelslim

##### Installation steps

git clone https://gitcode.com/Ascend/msit.git
cd msit/msmodelslim
bash install.sh

##### Generate w4a8 per-channel weights

cd example/DeepSeek
Command reference: msmodelslim/example/DeepSeek/README.md

- vLLM version: v0.10.2
- vLLM main:
f225ea7dd9

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
Wang Kunpeng
2025-09-27 21:01:16 +08:00
committed by GitHub
parent e9359bd8fa
commit 859e861d92
6 changed files with 299 additions and 196 deletions

View File

@@ -108,18 +108,19 @@ Please convert DeepSeek series models using `br_release_MindStudio_8.1.RC2_TR5_2
### 3. When converting deepseek series models with modelslim, what should you pay attention? ### 3. When converting deepseek series models with modelslim, what should you pay attention?
When using the weight generated by modelslim with the `--dynamic` parameter, if torchair graph mode is enabled, please modify the configuration file in the CANN package to prevent incorrect inference results. When the mla portion of the weights used `W8A8_DYNAMIC` quantization, if torchair graph mode is enabled, please modify the configuration file in the CANN package to prevent incorrect inference results.
The operation steps are as follows: The operation steps are as follows:
1. Search in the CANN package directory used, for example: 1. Search in the CANN package directory used, for example:
find /usr/local/Ascend/ -name fusion_config.json find /usr/local/Ascend/ -name fusion_config.json
2. Add `"AddRmsNormDynamicQuantFusionPass":"off",` to the fusion_config.json you find, the location is as follows: 2. Add `"AddRmsNormDynamicQuantFusionPass":"off",` and `"MultiAddRmsNormDynamicQuantFusionPass":"off",` to the fusion_config.json you find, the location is as follows:
```bash ```bash
{ {
"Switch":{ "Switch":{
"GraphFusion":{ "GraphFusion":{
"AddRmsNormDynamicQuantFusionPass":"off", "AddRmsNormDynamicQuantFusionPass":"off",
"MultiAddRmsNormDynamicQuantFusionPass":"off",
``` ```

View File

@@ -35,6 +35,11 @@ QWEN_DENSE_MODELS = [
"vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8" "vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8"
] ]
DEEPSEEK_W4A8_MODELS = [
"vllm-ascend/DeepSeek-V3-W4A8-Pruing",
"vllm-ascend/DeepSeek-V3.1-W4A8-puring"
]
def test_models_distributed_QwQ(): def test_models_distributed_QwQ():
example_prompts = [ example_prompts = [
@@ -109,14 +114,15 @@ def test_models_distributed_Qwen3_W4A8DYNAMIC():
vllm_model.generate_greedy(example_prompts, max_tokens) vllm_model.generate_greedy(example_prompts, max_tokens)
@pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"}) @patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"})
def test_models_distributed_DeepSeek_W4A8DYNAMIC(): def test_models_distributed_DeepSeek_W4A8DYNAMIC(model):
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
] ]
max_tokens = 5 max_tokens = 5
with VllmRunner( with VllmRunner(
snapshot_download("vllm-ascend/DeepSeek-V3-W4A8-Pruing"), snapshot_download(model),
dtype="auto", dtype="auto",
tensor_parallel_size=2, tensor_parallel_size=2,
quantization="ascend", quantization="ascend",

View File

@@ -1,4 +1,3 @@
import copy
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import torch import torch
@@ -95,19 +94,19 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
# old quant version weight # old quant version weight
param_dict = self.quant_method.get_dynamic_quant_param( param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16) self.experts, self.input_size, self.output_size, torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16) self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32)
self.assertEqual(param_dict["w13_weight_scale"].shape, self.assertEqual(param_dict["w13_weight_scale"].shape,
(self.experts, 2 * self.input_size, 1)) (self.experts, 2 * self.input_size, 1))
self.assertEqual(param_dict["w13_weight_scale_second"].dtype, self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
torch.bfloat16) torch.float32)
self.assertEqual(param_dict["w13_weight_scale_second"].shape, self.assertEqual(param_dict["w13_weight_scale_second"].shape,
(self.experts, 2 * self.input_size, (self.experts, 2 * self.input_size,
self.output_size // self.group_size)) self.output_size // self.group_size))
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16) self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
self.assertEqual(param_dict["w2_weight_scale"].shape, self.assertEqual(param_dict["w2_weight_scale"].shape,
(self.experts, self.output_size, 1)) (self.experts, self.output_size, 1))
self.assertEqual(param_dict["w2_weight_scale_second"].dtype, self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
torch.bfloat16) torch.float32)
self.assertEqual(param_dict["w2_weight_scale_second"].shape, self.assertEqual(param_dict["w2_weight_scale_second"].shape,
(self.experts, self.output_size, (self.experts, self.output_size,
self.input_size // self.group_size)) self.input_size // self.group_size))
@@ -119,12 +118,40 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
self.assertEqual( self.assertEqual(
param_dict["w2_scale_bias"].shape, param_dict["w2_scale_bias"].shape,
(self.experts, self.output_size, 16 // self.quant_method.tp_size)) (self.experts, self.output_size, 16 // self.quant_method.tp_size))
# per-channel weight
self.quant_method.is_per_channel_weight = True
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
pergroup_param = [
"w13_weight_scale_second", "w13_weight_offset_second",
"w2_weight_scale_second", "w2_weight_offset_second"
]
is_contains = any(key in param_dict for key in pergroup_param)
self.assertFalse(is_contains)
@patch('torch_npu.npu_quantize') def build_layer(self,
@patch('torch.Tensor.npu') is_new_quant_version=True,
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize): is_per_channel_weight=False):
# old quant version weight
layer = torch.nn.Module() layer = torch.nn.Module()
if is_new_quant_version:
layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.input_size, self.output_size),
dtype=torch.int8),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.output_size // 2, self.input_size),
dtype=torch.int8),
requires_grad=False)
w13_scale_bias = torch.zeros(
(self.experts, 2 * self.input_size, 1), dtype=torch.float32)
layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)
w2_scale_bias = torch.zeros((self.experts, self.output_size,
16 // self.quant_method.tp_size),
dtype=torch.float32)
layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
requires_grad=False)
else:
layer.w13_weight = torch.nn.Parameter(torch.zeros( layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, 2 * self.input_size, self.output_size), (self.experts, 2 * self.input_size, self.output_size),
dtype=torch.int8), dtype=torch.int8),
@@ -134,25 +161,44 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
dtype=torch.int8), dtype=torch.int8),
requires_grad=False) requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(torch.ones( layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16), (self.experts, 2 * self.input_size, 1), dtype=torch.float32),
requires_grad=False)
layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
(self.experts, 2 * self.input_size,
self.output_size // self.group_size),
dtype=torch.bfloat16),
requires_grad=False) requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(torch.ones( layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, self.output_size, 1), dtype=torch.bfloat16), (self.experts, self.output_size, 1), dtype=torch.float32),
requires_grad=False) requires_grad=False)
layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones( if not is_per_channel_weight:
(self.experts, self.output_size, layer.w13_weight_scale_second = torch.nn.Parameter(
torch.ones((self.experts, 2 * self.input_size,
self.output_size // self.group_size),
dtype=torch.float32),
requires_grad=False)
layer.w13_weight_offset_second = torch.nn.Parameter(
torch.empty_like(layer.w13_weight_scale_second.data),
requires_grad=False)
layer.w2_weight_scale_second = torch.nn.Parameter(
torch.ones((self.experts, self.output_size,
self.input_size // self.group_size), self.input_size // self.group_size),
dtype=torch.bfloat16), dtype=torch.float32),
requires_grad=False) requires_grad=False)
new_layer = copy.deepcopy(layer) layer.w2_weight_offset_second = torch.nn.Parameter(
torch.empty_like(layer.w2_weight_scale_second.data),
requires_grad=False)
return layer
@patch('torch_npu.npu_format_cast')
@patch('torch_npu.npu_quantize')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize,
mock_npu_format_cast):
mock_npu.return_value = torch.Tensor() mock_npu.return_value = torch.Tensor()
mock_npu_quantize.return_value = torch.Tensor() mock_npu_quantize.return_value = torch.Tensor()
def func_by_args(weight, num_format):
return weight
mock_npu_format_cast.side_effect = func_by_args
# old quant version weight
layer = self.build_layer(is_new_quant_version=False)
self.quant_method.process_weights_after_loading(layer) self.quant_method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "w13_scale_bias")) self.assertTrue(hasattr(layer, "w13_scale_bias"))
self.assertEqual(layer.w13_scale_bias.data.shape, self.assertEqual(layer.w13_scale_bias.data.shape,
@@ -164,23 +210,17 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32) self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
# new quant version weight # new quant version weight
self.quant_method.new_quant_version = True self.quant_method.new_quant_version = True
new_layer.w13_weight.data = torch.zeros( new_layer = self.build_layer(is_new_quant_version=True)
(self.experts, self.input_size, self.output_size),
dtype=torch.int8)
new_layer.w2_weight.data = torch.zeros(
(self.experts, self.output_size // 2, self.input_size),
dtype=torch.int8)
w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
dtype=torch.float32)
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)
w2_scale_bias = torch.zeros(
(self.experts, self.output_size, 16 // self.quant_method.tp_size),
dtype=torch.float32)
new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
requires_grad=False)
self.quant_method.process_weights_after_loading(new_layer) self.quant_method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.w13_scale_bias.data.shape, self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size)) (self.experts, 2 * self.input_size))
self.assertEqual(new_layer.w2_scale_bias.data.shape, self.assertEqual(new_layer.w2_scale_bias.data.shape,
(self.experts, self.output_size)) (self.experts, self.output_size))
self.assertFalse(hasattr(new_layer, "w13_weight_scale_second"))
# per-channel weight
self.quant_method.is_per_channel_weight = True
per_channel_layer = self.build_layer(is_new_quant_version=True,
is_per_channel_weight=True)
self.quant_method.process_weights_after_loading(per_channel_layer)
self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))

View File

@@ -1,4 +1,3 @@
import copy
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import torch import torch
@@ -85,19 +84,19 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
# old quant version weight # old quant version weight
param_dict = self.quant_method.get_dynamic_quant_param( param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16) self.experts, self.input_size, self.output_size, torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16) self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32)
self.assertEqual(param_dict["w13_weight_scale"].shape, self.assertEqual(param_dict["w13_weight_scale"].shape,
(self.experts, 2 * self.input_size, 1)) (self.experts, 2 * self.input_size, 1))
self.assertEqual(param_dict["w13_weight_scale_second"].dtype, self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
torch.bfloat16) torch.float32)
self.assertEqual(param_dict["w13_weight_scale_second"].shape, self.assertEqual(param_dict["w13_weight_scale_second"].shape,
(self.experts, 2 * self.input_size, (self.experts, 2 * self.input_size,
self.output_size // self.group_size)) self.output_size // self.group_size))
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16) self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
self.assertEqual(param_dict["w2_weight_scale"].shape, self.assertEqual(param_dict["w2_weight_scale"].shape,
(self.experts, self.output_size, 1)) (self.experts, self.output_size, 1))
self.assertEqual(param_dict["w2_weight_scale_second"].dtype, self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
torch.bfloat16) torch.float32)
self.assertEqual(param_dict["w2_weight_scale_second"].shape, self.assertEqual(param_dict["w2_weight_scale_second"].shape,
(self.experts, self.output_size, (self.experts, self.output_size,
self.input_size // self.group_size)) self.input_size // self.group_size))
@@ -109,12 +108,40 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
self.assertEqual( self.assertEqual(
param_dict["w2_scale_bias"].shape, param_dict["w2_scale_bias"].shape,
(self.experts, self.output_size, 16 // self.quant_method.tp_size)) (self.experts, self.output_size, 16 // self.quant_method.tp_size))
# per-channel weight
self.quant_method.is_per_channel_weight = True
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
pergroup_param = [
"w13_weight_scale_second", "w13_weight_offset_second",
"w2_weight_scale_second", "w2_weight_offset_second"
]
is_contains = any(key in param_dict for key in pergroup_param)
self.assertFalse(is_contains)
@patch('torch_npu.npu_quantize') def build_layer(self,
@patch('torch.Tensor.npu') is_new_quant_version=True,
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize): is_per_channel_weight=False):
# old quant version weight
layer = torch.nn.Module() layer = torch.nn.Module()
if is_new_quant_version:
layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.input_size, self.output_size),
dtype=torch.int8),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.output_size // 2, self.input_size),
dtype=torch.int8),
requires_grad=False)
w13_scale_bias = torch.zeros(
(self.experts, 2 * self.input_size, 1), dtype=torch.float32)
layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)
w2_scale_bias = torch.zeros((self.experts, self.output_size,
16 // self.quant_method.tp_size),
dtype=torch.float32)
layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
requires_grad=False)
else:
layer.w13_weight = torch.nn.Parameter(torch.zeros( layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, 2 * self.input_size, self.output_size), (self.experts, 2 * self.input_size, self.output_size),
dtype=torch.int8), dtype=torch.int8),
@@ -124,25 +151,37 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
dtype=torch.int8), dtype=torch.int8),
requires_grad=False) requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(torch.ones( layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16), (self.experts, 2 * self.input_size, 1), dtype=torch.float32),
requires_grad=False)
layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
(self.experts, 2 * self.input_size,
self.output_size // self.group_size),
dtype=torch.bfloat16),
requires_grad=False) requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(torch.ones( layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, self.output_size, 1), dtype=torch.bfloat16), (self.experts, self.output_size, 1), dtype=torch.float32),
requires_grad=False) requires_grad=False)
layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones( if not is_per_channel_weight:
(self.experts, self.output_size, layer.w13_weight_scale_second = torch.nn.Parameter(
torch.ones((self.experts, 2 * self.input_size,
self.output_size // self.group_size),
dtype=torch.float32),
requires_grad=False)
layer.w13_weight_offset_second = torch.nn.Parameter(
torch.empty_like(layer.w13_weight_scale_second.data),
requires_grad=False)
layer.w2_weight_scale_second = torch.nn.Parameter(
torch.ones((self.experts, self.output_size,
self.input_size // self.group_size), self.input_size // self.group_size),
dtype=torch.bfloat16), dtype=torch.float32),
requires_grad=False) requires_grad=False)
new_layer = copy.deepcopy(layer) layer.w2_weight_offset_second = torch.nn.Parameter(
torch.empty_like(layer.w2_weight_scale_second.data),
requires_grad=False)
return layer
@patch('torch_npu.npu_quantize')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
mock_npu.return_value = torch.Tensor() mock_npu.return_value = torch.Tensor()
mock_npu_quantize.return_value = torch.Tensor() mock_npu_quantize.return_value = torch.Tensor()
# old quant version weight
layer = self.build_layer(is_new_quant_version=False)
self.quant_method.process_weights_after_loading(layer) self.quant_method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "w13_scale_bias")) self.assertTrue(hasattr(layer, "w13_scale_bias"))
self.assertEqual(layer.w13_scale_bias.data.shape, self.assertEqual(layer.w13_scale_bias.data.shape,
@@ -154,23 +193,17 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32) self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
# new quant version weight # new quant version weight
self.quant_method.new_quant_version = True self.quant_method.new_quant_version = True
new_layer.w13_weight.data = torch.zeros( new_layer = self.build_layer(is_new_quant_version=True)
(self.experts, self.input_size, self.output_size),
dtype=torch.int8)
new_layer.w2_weight.data = torch.zeros(
(self.experts, self.output_size // 2, self.input_size),
dtype=torch.int8)
w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
dtype=torch.float32)
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)
w2_scale_bias = torch.zeros(
(self.experts, self.output_size, 16 // self.quant_method.tp_size),
dtype=torch.float32)
new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
requires_grad=False)
self.quant_method.process_weights_after_loading(new_layer) self.quant_method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.w13_scale_bias.data.shape, self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size)) (self.experts, 2 * self.input_size))
self.assertEqual(new_layer.w2_scale_bias.data.shape, self.assertEqual(new_layer.w2_scale_bias.data.shape,
(self.experts, self.output_size)) (self.experts, self.output_size))
self.assertFalse(hasattr(new_layer, "w13_weight_scale_second"))
# per-channel weight
self.quant_method.is_per_channel_weight = True
per_channel_layer = self.build_layer(is_new_quant_version=True,
is_per_channel_weight=True)
self.quant_method.process_weights_after_loading(per_channel_layer)
self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))

View File

@@ -27,6 +27,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
class AscendW4A8DynamicLinearMethod: class AscendW4A8DynamicLinearMethod:
@@ -132,6 +133,8 @@ class AscendW4A8DynamicFusedMoEMethod:
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get( self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 256) "group_size", 256)
# NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
self.is_per_channel_weight = self.group_size == 0
quant_version = vllm_config.quant_config.quant_description.get( quant_version = vllm_config.quant_config.quant_description.get(
"version", "0") "version", "0")
# NOTE: new quantize weights: 2 int4 pack into int8 # NOTE: new quantize weights: 2 int4 pack into int8
@@ -182,44 +185,44 @@ class AscendW4A8DynamicFusedMoEMethod:
num_experts, num_experts,
2 * intermediate_size_per_partition, 2 * intermediate_size_per_partition,
1, 1,
dtype=params_dtype) dtype=torch.float32)
param_dict["w13_weight_offset"] = torch.empty( param_dict["w13_weight_offset"] = torch.empty(
num_experts, num_experts,
2 * intermediate_size_per_partition, 2 * intermediate_size_per_partition,
1, 1,
dtype=params_dtype) dtype=torch.float32)
param_dict["w13_weight_scale_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=params_dtype)
param_dict["w13_weight_offset_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=params_dtype)
param_dict["w2_weight_scale"] = torch.empty(num_experts, param_dict["w2_weight_scale"] = torch.empty(num_experts,
hidden_sizes, hidden_sizes,
1, 1,
dtype=params_dtype) dtype=torch.float32)
param_dict["w2_weight_offset"] = torch.empty(num_experts, param_dict["w2_weight_offset"] = torch.empty(num_experts,
hidden_sizes, hidden_sizes,
1, 1,
dtype=params_dtype) dtype=torch.float32)
if not self.is_per_channel_weight:
param_dict["w13_weight_scale_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.float32)
param_dict["w13_weight_offset_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.float32)
param_dict["w2_weight_scale_second"] = torch.empty( param_dict["w2_weight_scale_second"] = torch.empty(
num_experts, num_experts,
hidden_sizes, hidden_sizes,
intermediate_size_per_partition // self.group_size, intermediate_size_per_partition // self.group_size,
dtype=params_dtype) dtype=torch.float32)
param_dict["w2_weight_offset_second"] = torch.empty( param_dict["w2_weight_offset_second"] = torch.empty(
num_experts, num_experts,
hidden_sizes, hidden_sizes,
intermediate_size_per_partition // self.group_size, intermediate_size_per_partition // self.group_size,
dtype=params_dtype) dtype=torch.float32)
if self.new_quant_version: if self.new_quant_version:
param_dict["w13_scale_bias"] = torch.empty( param_dict["w13_scale_bias"] = torch.empty(
@@ -288,8 +291,8 @@ class AscendW4A8DynamicFusedMoEMethod:
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_second, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale_second, w2_scale=layer.w2_weight_scale,
w1_scale_bias=layer.w13_scale_bias, w1_scale_bias=layer.w13_scale_bias,
w2_scale_bias=layer.w2_scale_bias, w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights, topk_weights=topk_weights,
@@ -305,6 +308,14 @@ class AscendW4A8DynamicFusedMoEMethod:
dynamic_eplb=self.dynamic_eplb) dynamic_eplb=self.dynamic_eplb)
def process_scale(self, weight: torch.Tensor, scale, per_group_scale): def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
scale = scale.transpose(1, 2).contiguous()
if self.is_per_channel_weight:
scale_np = scale.cpu().numpy()
scale_np.dtype = np.uint32
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
np.int64)).npu()
return scale_uint64_tensor, None
per_group_scale = per_group_scale.transpose(1, 2).contiguous()
group_num, k, n = weight.shape group_num, k, n = weight.shape
# the weight of the new version is reduced by half by pack n, so it needs to be restored # the weight of the new version is reduced by half by pack n, so it needs to be restored
if self.new_quant_version: if self.new_quant_version:
@@ -347,13 +358,10 @@ class AscendW4A8DynamicFusedMoEMethod:
def pack_to_int32(self, weight: torch.Tensor): def pack_to_int32(self, weight: torch.Tensor):
if self.new_quant_version: if self.new_quant_version:
group_num, k, n = weight.shape
assert n % 4 == 0, "the last dim of weight needs to be divided by 4"
packed_n = n // 4
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
packed_weight = torch.from_numpy( assert weight.shape[
np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32)) -1] % 4 == 0, "the last dim of weight needs to be divided by 4"
return packed_weight.reshape(group_num, k, packed_n).npu() return weight.view(torch.int32).contiguous()
else: else:
return torch_npu.npu_quantize(weight.to(torch.float32), return torch_npu.npu_quantize(weight.to(torch.float32),
torch.tensor([1.]).npu(), None, torch.tensor([1.]).npu(), None,
@@ -365,23 +373,29 @@ class AscendW4A8DynamicFusedMoEMethod:
1, 2).contiguous() 1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose( layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous() 1, 2).contiguous()
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
1, 2).contiguous()
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
1, 2).contiguous()
layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose(
1, 2).contiguous()
layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose(
1, 2).contiguous()
layer.w13_weight_scale_second.data, w13_bias = self.process_scale( w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
layer, "w13_weight_scale_second") else None
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
layer, "w2_weight_scale_second") else None
layer.w13_weight_scale.data, w13_bias = self.process_scale(
layer.w13_weight, layer.w13_weight_scale.data, layer.w13_weight, layer.w13_weight_scale.data,
layer.w13_weight_scale_second.data) w13_weight_scale_second)
layer.w2_weight_scale_second.data, w2_bias = self.process_scale( layer.w2_weight_scale.data, w2_bias = self.process_scale(
layer.w2_weight, layer.w2_weight_scale.data, layer.w2_weight, layer.w2_weight_scale.data,
layer.w2_weight_scale_second.data) w2_weight_scale_second)
if hasattr(layer, "w13_weight_scale_second"):
# scale_second is no longer used, release this part of the memory
del layer.w13_weight_scale_second
del layer.w2_weight_scale_second
del layer.w13_weight_offset_second
del layer.w2_weight_offset_second
self.update_bias(layer, w13_bias, w2_bias) self.update_bias(layer, w13_bias, w2_bias)
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data) layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data) layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)

View File

@@ -139,6 +139,8 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get( self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 256) "group_size", 256)
# NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
self.is_per_channel_weight = self.group_size == 0
quant_version = vllm_config.quant_config.quant_description.get( quant_version = vllm_config.quant_config.quant_description.get(
"version", "0") "version", "0")
# NOTE: new quantize weights: 2 int4 pack into int8 # NOTE: new quantize weights: 2 int4 pack into int8
@@ -188,44 +190,45 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
num_experts, num_experts,
2 * intermediate_size_per_partition, 2 * intermediate_size_per_partition,
1, 1,
dtype=params_dtype) dtype=torch.float32)
param_dict["w13_weight_offset"] = torch.empty( param_dict["w13_weight_offset"] = torch.empty(
num_experts, num_experts,
2 * intermediate_size_per_partition, 2 * intermediate_size_per_partition,
1, 1,
dtype=params_dtype) dtype=torch.float32)
param_dict["w13_weight_scale_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=params_dtype)
param_dict["w13_weight_offset_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=params_dtype)
param_dict["w2_weight_scale"] = torch.empty(num_experts, param_dict["w2_weight_scale"] = torch.empty(num_experts,
hidden_sizes, hidden_sizes,
1, 1,
dtype=params_dtype) dtype=torch.float32)
param_dict["w2_weight_offset"] = torch.empty(num_experts, param_dict["w2_weight_offset"] = torch.empty(num_experts,
hidden_sizes, hidden_sizes,
1, 1,
dtype=params_dtype) dtype=torch.float32)
if not self.is_per_channel_weight:
param_dict["w13_weight_scale_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.float32)
param_dict["w13_weight_offset_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.float32)
param_dict["w2_weight_scale_second"] = torch.empty( param_dict["w2_weight_scale_second"] = torch.empty(
num_experts, num_experts,
hidden_sizes, hidden_sizes,
intermediate_size_per_partition // self.group_size, intermediate_size_per_partition // self.group_size,
dtype=params_dtype) dtype=torch.float32)
param_dict["w2_weight_offset_second"] = torch.empty( param_dict["w2_weight_offset_second"] = torch.empty(
num_experts, num_experts,
hidden_sizes, hidden_sizes,
intermediate_size_per_partition // self.group_size, intermediate_size_per_partition // self.group_size,
dtype=params_dtype) dtype=torch.float32)
if self.new_quant_version: if self.new_quant_version:
param_dict["w13_scale_bias"] = torch.empty( param_dict["w13_scale_bias"] = torch.empty(
@@ -318,8 +321,8 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_second, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale_second, w2_scale=layer.w2_weight_scale,
w1_scale_bias=layer.w13_scale_bias, w1_scale_bias=layer.w13_scale_bias,
w2_scale_bias=layer.w2_scale_bias, w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights, topk_weights=topk_weights,
@@ -343,8 +346,8 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale_second, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale_second, w2_scale=layer.w2_weight_scale,
w1_scale_bias=layer.w13_scale_bias, w1_scale_bias=layer.w13_scale_bias,
w2_scale_bias=layer.w2_scale_bias, w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights, topk_weights=topk_weights,
@@ -357,6 +360,14 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
) )
def process_scale(self, weight: torch.Tensor, scale, per_group_scale): def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
scale = scale.transpose(1, 2).contiguous()
if self.is_per_channel_weight:
scale_np = scale.cpu().numpy()
scale_np.dtype = np.uint32
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
np.int64)).npu()
return scale_uint64_tensor, None
per_group_scale = per_group_scale.transpose(1, 2).contiguous()
group_num, k, n = weight.shape group_num, k, n = weight.shape
# the weight of the new version is reduced by half by pack n, so it needs to be restored # the weight of the new version is reduced by half by pack n, so it needs to be restored
if self.new_quant_version: if self.new_quant_version:
@@ -399,13 +410,10 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
def pack_to_int32(self, weight: torch.Tensor): def pack_to_int32(self, weight: torch.Tensor):
if self.new_quant_version: if self.new_quant_version:
group_num, k, n = weight.shape
assert n % 4 == 0, "the last dim of weight needs to be divided by 4"
packed_n = n // 4
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
packed_weight = torch.from_numpy( assert weight.shape[
np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32)) -1] % 4 == 0, "the last dim of weight needs to be divided by 4"
return packed_weight.reshape(group_num, k, packed_n).npu() return weight.view(torch.int32).contiguous()
else: else:
return torch_npu.npu_quantize(weight.to(torch.float32), return torch_npu.npu_quantize(weight.to(torch.float32),
torch.tensor([1.]).npu(), None, torch.tensor([1.]).npu(), None,
@@ -417,21 +425,22 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
1, 2).contiguous() 1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose( layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous() 1, 2).contiguous()
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose( w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
1, 2).contiguous() layer, "w13_weight_scale_second") else None
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose( w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
1, 2).contiguous() layer, "w2_weight_scale_second") else None
layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose( layer.w13_weight_scale.data, w13_bias = self.process_scale(
1, 2).contiguous()
layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose(
1, 2).contiguous()
layer.w13_weight_scale_second.data, w13_bias = self.process_scale(
layer.w13_weight, layer.w13_weight_scale.data, layer.w13_weight, layer.w13_weight_scale.data,
layer.w13_weight_scale_second.data) w13_weight_scale_second)
layer.w2_weight_scale_second.data, w2_bias = self.process_scale( layer.w2_weight_scale.data, w2_bias = self.process_scale(
layer.w2_weight, layer.w2_weight_scale.data, layer.w2_weight, layer.w2_weight_scale.data,
layer.w2_weight_scale_second.data) w2_weight_scale_second)
if hasattr(layer, "w13_weight_scale_second"):
# scale_second is no longer used, release this part of the memory
del layer.w13_weight_scale_second
del layer.w2_weight_scale_second
del layer.w13_weight_offset_second
del layer.w2_weight_offset_second
self.update_bias(layer, w13_bias, w2_bias) self.update_bias(layer, w13_bias, w2_bias)