[main][quantization] Support deepseek w4a8 per-channel quantization (#3011)
### What this PR does / why we need it?
1.Support deepseek w4a8 per-channel quantization
2.The eager mode supports converting weights to the NZ format
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
#### How to get weights using Modelslim
##### Installation steps
git clone https://gitcode.com/Ascend/msit.git
cd msit/msmodelslim
bash install.sh
##### Generate w4a8 per-channel weights
cd /example/DeepSeek
Command reference: msmodelslim/example/DeepSeek/README.md
- vLLM version: v0.10.2
- vLLM main:
f225ea7dd9
---------
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
@@ -108,18 +108,19 @@ Please convert DeepSeek series models using `br_release_MindStudio_8.1.RC2_TR5_2
|
||||
|
||||
### 3. When converting deepseek series models with modelslim, what should you pay attention?
|
||||
|
||||
When using the weight generated by modelslim with the `--dynamic` parameter, if torchair graph mode is enabled, please modify the configuration file in the CANN package to prevent incorrect inference results.
|
||||
When the mla portion of the weights used `W8A8_DYNAMIC` quantization, if torchair graph mode is enabled, please modify the configuration file in the CANN package to prevent incorrect inference results.
|
||||
|
||||
The operation steps are as follows:
|
||||
|
||||
1. Search in the CANN package directory used, for example:
|
||||
find /usr/local/Ascend/ -name fusion_config.json
|
||||
|
||||
2. Add `"AddRmsNormDynamicQuantFusionPass":"off",` to the fusion_config.json you find, the location is as follows:
|
||||
2. Add `"AddRmsNormDynamicQuantFusionPass":"off",` and `"MultiAddRmsNormDynamicQuantFusionPass":"off",` to the fusion_config.json you find, the location is as follows:
|
||||
|
||||
```bash
|
||||
{
|
||||
"Switch":{
|
||||
"GraphFusion":{
|
||||
"AddRmsNormDynamicQuantFusionPass":"off",
|
||||
"MultiAddRmsNormDynamicQuantFusionPass":"off",
|
||||
```
|
||||
|
||||
@@ -35,6 +35,11 @@ QWEN_DENSE_MODELS = [
|
||||
"vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8"
|
||||
]
|
||||
|
||||
DEEPSEEK_W4A8_MODELS = [
|
||||
"vllm-ascend/DeepSeek-V3-W4A8-Pruing",
|
||||
"vllm-ascend/DeepSeek-V3.1-W4A8-puring"
|
||||
]
|
||||
|
||||
|
||||
def test_models_distributed_QwQ():
|
||||
example_prompts = [
|
||||
@@ -109,14 +114,15 @@ def test_models_distributed_Qwen3_W4A8DYNAMIC():
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", DEEPSEEK_W4A8_MODELS)
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"})
|
||||
def test_models_distributed_DeepSeek_W4A8DYNAMIC():
|
||||
def test_models_distributed_DeepSeek_W4A8DYNAMIC(model):
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
]
|
||||
max_tokens = 5
|
||||
with VllmRunner(
|
||||
snapshot_download("vllm-ascend/DeepSeek-V3-W4A8-Pruing"),
|
||||
snapshot_download(model),
|
||||
dtype="auto",
|
||||
tensor_parallel_size=2,
|
||||
quantization="ascend",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import torch
|
||||
@@ -95,19 +94,19 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
# old quant version weight
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.experts, self.input_size, self.output_size, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].shape,
|
||||
(self.experts, 2 * self.input_size, 1))
|
||||
self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
|
||||
torch.bfloat16)
|
||||
torch.float32)
|
||||
self.assertEqual(param_dict["w13_weight_scale_second"].shape,
|
||||
(self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size))
|
||||
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
|
||||
self.assertEqual(param_dict["w2_weight_scale"].shape,
|
||||
(self.experts, self.output_size, 1))
|
||||
self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
|
||||
torch.bfloat16)
|
||||
torch.float32)
|
||||
self.assertEqual(param_dict["w2_weight_scale_second"].shape,
|
||||
(self.experts, self.output_size,
|
||||
self.input_size // self.group_size))
|
||||
@@ -119,40 +118,87 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
self.assertEqual(
|
||||
param_dict["w2_scale_bias"].shape,
|
||||
(self.experts, self.output_size, 16 // self.quant_method.tp_size))
|
||||
# per-channel weight
|
||||
self.quant_method.is_per_channel_weight = True
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.experts, self.input_size, self.output_size, torch.bfloat16)
|
||||
pergroup_param = [
|
||||
"w13_weight_scale_second", "w13_weight_offset_second",
|
||||
"w2_weight_scale_second", "w2_weight_offset_second"
|
||||
]
|
||||
is_contains = any(key in param_dict for key in pergroup_param)
|
||||
self.assertFalse(is_contains)
|
||||
|
||||
def build_layer(self,
|
||||
is_new_quant_version=True,
|
||||
is_per_channel_weight=False):
|
||||
layer = torch.nn.Module()
|
||||
if is_new_quant_version:
|
||||
layer.w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.input_size, self.output_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.output_size // 2, self.input_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
w13_scale_bias = torch.zeros(
|
||||
(self.experts, 2 * self.input_size, 1), dtype=torch.float32)
|
||||
layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
|
||||
requires_grad=False)
|
||||
w2_scale_bias = torch.zeros((self.experts, self.output_size,
|
||||
16 // self.quant_method.tp_size),
|
||||
dtype=torch.float32)
|
||||
layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
|
||||
requires_grad=False)
|
||||
else:
|
||||
layer.w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, 2 * self.input_size, self.output_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.output_size, self.input_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size, 1), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size, 1), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
if not is_per_channel_weight:
|
||||
layer.w13_weight_scale_second = torch.nn.Parameter(
|
||||
torch.ones((self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size),
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_offset_second = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w13_weight_scale_second.data),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale_second = torch.nn.Parameter(
|
||||
torch.ones((self.experts, self.output_size,
|
||||
self.input_size // self.group_size),
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_offset_second = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w2_weight_scale_second.data),
|
||||
requires_grad=False)
|
||||
return layer
|
||||
|
||||
@patch('torch_npu.npu_format_cast')
|
||||
@patch('torch_npu.npu_quantize')
|
||||
@patch('torch.Tensor.npu')
|
||||
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
|
||||
# old quant version weight
|
||||
layer = torch.nn.Module()
|
||||
layer.w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, 2 * self.input_size, self.output_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.output_size, self.input_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size),
|
||||
dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size, 1), dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size,
|
||||
self.input_size // self.group_size),
|
||||
dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
new_layer = copy.deepcopy(layer)
|
||||
|
||||
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize,
|
||||
mock_npu_format_cast):
|
||||
mock_npu.return_value = torch.Tensor()
|
||||
mock_npu_quantize.return_value = torch.Tensor()
|
||||
|
||||
def func_by_args(weight, num_format):
|
||||
return weight
|
||||
|
||||
mock_npu_format_cast.side_effect = func_by_args
|
||||
# old quant version weight
|
||||
layer = self.build_layer(is_new_quant_version=False)
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
self.assertTrue(hasattr(layer, "w13_scale_bias"))
|
||||
self.assertEqual(layer.w13_scale_bias.data.shape,
|
||||
@@ -164,23 +210,17 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
|
||||
# new quant version weight
|
||||
self.quant_method.new_quant_version = True
|
||||
new_layer.w13_weight.data = torch.zeros(
|
||||
(self.experts, self.input_size, self.output_size),
|
||||
dtype=torch.int8)
|
||||
new_layer.w2_weight.data = torch.zeros(
|
||||
(self.experts, self.output_size // 2, self.input_size),
|
||||
dtype=torch.int8)
|
||||
w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
|
||||
dtype=torch.float32)
|
||||
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
|
||||
requires_grad=False)
|
||||
w2_scale_bias = torch.zeros(
|
||||
(self.experts, self.output_size, 16 // self.quant_method.tp_size),
|
||||
dtype=torch.float32)
|
||||
new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
|
||||
requires_grad=False)
|
||||
new_layer = self.build_layer(is_new_quant_version=True)
|
||||
self.quant_method.process_weights_after_loading(new_layer)
|
||||
self.assertEqual(new_layer.w13_scale_bias.data.shape,
|
||||
(self.experts, 2 * self.input_size))
|
||||
self.assertEqual(new_layer.w2_scale_bias.data.shape,
|
||||
(self.experts, self.output_size))
|
||||
self.assertFalse(hasattr(new_layer, "w13_weight_scale_second"))
|
||||
# per-channel weight
|
||||
self.quant_method.is_per_channel_weight = True
|
||||
per_channel_layer = self.build_layer(is_new_quant_version=True,
|
||||
is_per_channel_weight=True)
|
||||
self.quant_method.process_weights_after_loading(per_channel_layer)
|
||||
self.assertEqual(new_layer.w13_scale_bias.data.shape,
|
||||
(self.experts, 2 * self.input_size))
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import torch
|
||||
@@ -85,19 +84,19 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
# old quant version weight
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.experts, self.input_size, self.output_size, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].shape,
|
||||
(self.experts, 2 * self.input_size, 1))
|
||||
self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
|
||||
torch.bfloat16)
|
||||
torch.float32)
|
||||
self.assertEqual(param_dict["w13_weight_scale_second"].shape,
|
||||
(self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size))
|
||||
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
|
||||
self.assertEqual(param_dict["w2_weight_scale"].shape,
|
||||
(self.experts, self.output_size, 1))
|
||||
self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
|
||||
torch.bfloat16)
|
||||
torch.float32)
|
||||
self.assertEqual(param_dict["w2_weight_scale_second"].shape,
|
||||
(self.experts, self.output_size,
|
||||
self.input_size // self.group_size))
|
||||
@@ -109,40 +108,80 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
self.assertEqual(
|
||||
param_dict["w2_scale_bias"].shape,
|
||||
(self.experts, self.output_size, 16 // self.quant_method.tp_size))
|
||||
# per-channel weight
|
||||
self.quant_method.is_per_channel_weight = True
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.experts, self.input_size, self.output_size, torch.bfloat16)
|
||||
pergroup_param = [
|
||||
"w13_weight_scale_second", "w13_weight_offset_second",
|
||||
"w2_weight_scale_second", "w2_weight_offset_second"
|
||||
]
|
||||
is_contains = any(key in param_dict for key in pergroup_param)
|
||||
self.assertFalse(is_contains)
|
||||
|
||||
def build_layer(self,
|
||||
is_new_quant_version=True,
|
||||
is_per_channel_weight=False):
|
||||
layer = torch.nn.Module()
|
||||
if is_new_quant_version:
|
||||
layer.w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.input_size, self.output_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.output_size // 2, self.input_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
w13_scale_bias = torch.zeros(
|
||||
(self.experts, 2 * self.input_size, 1), dtype=torch.float32)
|
||||
layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
|
||||
requires_grad=False)
|
||||
w2_scale_bias = torch.zeros((self.experts, self.output_size,
|
||||
16 // self.quant_method.tp_size),
|
||||
dtype=torch.float32)
|
||||
layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
|
||||
requires_grad=False)
|
||||
else:
|
||||
layer.w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, 2 * self.input_size, self.output_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.output_size, self.input_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size, 1), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size, 1), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
if not is_per_channel_weight:
|
||||
layer.w13_weight_scale_second = torch.nn.Parameter(
|
||||
torch.ones((self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size),
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_offset_second = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w13_weight_scale_second.data),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale_second = torch.nn.Parameter(
|
||||
torch.ones((self.experts, self.output_size,
|
||||
self.input_size // self.group_size),
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_offset_second = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w2_weight_scale_second.data),
|
||||
requires_grad=False)
|
||||
return layer
|
||||
|
||||
@patch('torch_npu.npu_quantize')
|
||||
@patch('torch.Tensor.npu')
|
||||
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
|
||||
# old quant version weight
|
||||
layer = torch.nn.Module()
|
||||
layer.w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, 2 * self.input_size, self.output_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.output_size, self.input_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size),
|
||||
dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size, 1), dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size,
|
||||
self.input_size // self.group_size),
|
||||
dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
new_layer = copy.deepcopy(layer)
|
||||
|
||||
mock_npu.return_value = torch.Tensor()
|
||||
mock_npu_quantize.return_value = torch.Tensor()
|
||||
# old quant version weight
|
||||
layer = self.build_layer(is_new_quant_version=False)
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
self.assertTrue(hasattr(layer, "w13_scale_bias"))
|
||||
self.assertEqual(layer.w13_scale_bias.data.shape,
|
||||
@@ -154,23 +193,17 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
|
||||
# new quant version weight
|
||||
self.quant_method.new_quant_version = True
|
||||
new_layer.w13_weight.data = torch.zeros(
|
||||
(self.experts, self.input_size, self.output_size),
|
||||
dtype=torch.int8)
|
||||
new_layer.w2_weight.data = torch.zeros(
|
||||
(self.experts, self.output_size // 2, self.input_size),
|
||||
dtype=torch.int8)
|
||||
w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
|
||||
dtype=torch.float32)
|
||||
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
|
||||
requires_grad=False)
|
||||
w2_scale_bias = torch.zeros(
|
||||
(self.experts, self.output_size, 16 // self.quant_method.tp_size),
|
||||
dtype=torch.float32)
|
||||
new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
|
||||
requires_grad=False)
|
||||
new_layer = self.build_layer(is_new_quant_version=True)
|
||||
self.quant_method.process_weights_after_loading(new_layer)
|
||||
self.assertEqual(new_layer.w13_scale_bias.data.shape,
|
||||
(self.experts, 2 * self.input_size))
|
||||
self.assertEqual(new_layer.w2_scale_bias.data.shape,
|
||||
(self.experts, self.output_size))
|
||||
self.assertFalse(hasattr(new_layer, "w13_weight_scale_second"))
|
||||
# per-channel weight
|
||||
self.quant_method.is_per_channel_weight = True
|
||||
per_channel_layer = self.build_layer(is_new_quant_version=True,
|
||||
is_per_channel_weight=True)
|
||||
self.quant_method.process_weights_after_loading(per_channel_layer)
|
||||
self.assertEqual(new_layer.w13_scale_bias.data.shape,
|
||||
(self.experts, 2 * self.input_size))
|
||||
|
||||
@@ -27,6 +27,7 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
||||
|
||||
|
||||
class AscendW4A8DynamicLinearMethod:
|
||||
@@ -132,6 +133,8 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 256)
|
||||
# NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
|
||||
self.is_per_channel_weight = self.group_size == 0
|
||||
quant_version = vllm_config.quant_config.quant_description.get(
|
||||
"version", "0")
|
||||
# NOTE: new quantize weights: 2 int4 pack into int8
|
||||
@@ -182,44 +185,44 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w13_weight_offset"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w13_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w13_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=params_dtype)
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=params_dtype)
|
||||
dtype=torch.float32)
|
||||
if not self.is_per_channel_weight:
|
||||
param_dict["w13_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w2_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
if self.new_quant_version:
|
||||
param_dict["w13_scale_bias"] = torch.empty(
|
||||
@@ -288,8 +291,8 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale_second,
|
||||
w2_scale=layer.w2_weight_scale_second,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
@@ -305,6 +308,14 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
dynamic_eplb=self.dynamic_eplb)
|
||||
|
||||
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
||||
scale = scale.transpose(1, 2).contiguous()
|
||||
if self.is_per_channel_weight:
|
||||
scale_np = scale.cpu().numpy()
|
||||
scale_np.dtype = np.uint32
|
||||
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
|
||||
np.int64)).npu()
|
||||
return scale_uint64_tensor, None
|
||||
per_group_scale = per_group_scale.transpose(1, 2).contiguous()
|
||||
group_num, k, n = weight.shape
|
||||
# the weight of the new version is reduced by half by pack n, so it needs to be restored
|
||||
if self.new_quant_version:
|
||||
@@ -347,13 +358,10 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
|
||||
def pack_to_int32(self, weight: torch.Tensor):
|
||||
if self.new_quant_version:
|
||||
group_num, k, n = weight.shape
|
||||
assert n % 4 == 0, "the last dim of weight needs to be divided by 4"
|
||||
packed_n = n // 4
|
||||
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
|
||||
packed_weight = torch.from_numpy(
|
||||
np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32))
|
||||
return packed_weight.reshape(group_num, k, packed_n).npu()
|
||||
assert weight.shape[
|
||||
-1] % 4 == 0, "the last dim of weight needs to be divided by 4"
|
||||
return weight.view(torch.int32).contiguous()
|
||||
else:
|
||||
return torch_npu.npu_quantize(weight.to(torch.float32),
|
||||
torch.tensor([1.]).npu(), None,
|
||||
@@ -365,23 +373,29 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose(
|
||||
1, 2).contiguous()
|
||||
|
||||
layer.w13_weight_scale_second.data, w13_bias = self.process_scale(
|
||||
w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
|
||||
layer, "w13_weight_scale_second") else None
|
||||
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
|
||||
layer, "w2_weight_scale_second") else None
|
||||
layer.w13_weight_scale.data, w13_bias = self.process_scale(
|
||||
layer.w13_weight, layer.w13_weight_scale.data,
|
||||
layer.w13_weight_scale_second.data)
|
||||
layer.w2_weight_scale_second.data, w2_bias = self.process_scale(
|
||||
w13_weight_scale_second)
|
||||
layer.w2_weight_scale.data, w2_bias = self.process_scale(
|
||||
layer.w2_weight, layer.w2_weight_scale.data,
|
||||
layer.w2_weight_scale_second.data)
|
||||
w2_weight_scale_second)
|
||||
if hasattr(layer, "w13_weight_scale_second"):
|
||||
# scale_second is no longer used, release this part of the memory
|
||||
del layer.w13_weight_scale_second
|
||||
del layer.w2_weight_scale_second
|
||||
del layer.w13_weight_offset_second
|
||||
del layer.w2_weight_offset_second
|
||||
|
||||
self.update_bias(layer, w13_bias, w2_bias)
|
||||
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
|
||||
|
||||
@@ -139,6 +139,8 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 256)
|
||||
# NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
|
||||
self.is_per_channel_weight = self.group_size == 0
|
||||
quant_version = vllm_config.quant_config.quant_description.get(
|
||||
"version", "0")
|
||||
# NOTE: new quantize weights: 2 int4 pack into int8
|
||||
@@ -188,44 +190,45 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w13_weight_offset"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w13_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w13_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=params_dtype)
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=params_dtype)
|
||||
dtype=torch.float32)
|
||||
|
||||
if not self.is_per_channel_weight:
|
||||
param_dict["w13_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w2_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
if self.new_quant_version:
|
||||
param_dict["w13_scale_bias"] = torch.empty(
|
||||
@@ -318,8 +321,8 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale_second,
|
||||
w2_scale=layer.w2_weight_scale_second,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
@@ -343,8 +346,8 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale_second,
|
||||
w2_scale=layer.w2_weight_scale_second,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
@@ -357,6 +360,14 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
)
|
||||
|
||||
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
||||
scale = scale.transpose(1, 2).contiguous()
|
||||
if self.is_per_channel_weight:
|
||||
scale_np = scale.cpu().numpy()
|
||||
scale_np.dtype = np.uint32
|
||||
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
|
||||
np.int64)).npu()
|
||||
return scale_uint64_tensor, None
|
||||
per_group_scale = per_group_scale.transpose(1, 2).contiguous()
|
||||
group_num, k, n = weight.shape
|
||||
# the weight of the new version is reduced by half by pack n, so it needs to be restored
|
||||
if self.new_quant_version:
|
||||
@@ -399,13 +410,10 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
|
||||
def pack_to_int32(self, weight: torch.Tensor):
|
||||
if self.new_quant_version:
|
||||
group_num, k, n = weight.shape
|
||||
assert n % 4 == 0, "the last dim of weight needs to be divided by 4"
|
||||
packed_n = n // 4
|
||||
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
|
||||
packed_weight = torch.from_numpy(
|
||||
np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32))
|
||||
return packed_weight.reshape(group_num, k, packed_n).npu()
|
||||
assert weight.shape[
|
||||
-1] % 4 == 0, "the last dim of weight needs to be divided by 4"
|
||||
return weight.view(torch.int32).contiguous()
|
||||
else:
|
||||
return torch_npu.npu_quantize(weight.to(torch.float32),
|
||||
torch.tensor([1.]).npu(), None,
|
||||
@@ -417,21 +425,22 @@ class TorchairAscendW4A8DynamicFusedMoEMethod:
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose(
|
||||
1, 2).contiguous()
|
||||
|
||||
layer.w13_weight_scale_second.data, w13_bias = self.process_scale(
|
||||
w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
|
||||
layer, "w13_weight_scale_second") else None
|
||||
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
|
||||
layer, "w2_weight_scale_second") else None
|
||||
layer.w13_weight_scale.data, w13_bias = self.process_scale(
|
||||
layer.w13_weight, layer.w13_weight_scale.data,
|
||||
layer.w13_weight_scale_second.data)
|
||||
layer.w2_weight_scale_second.data, w2_bias = self.process_scale(
|
||||
w13_weight_scale_second)
|
||||
layer.w2_weight_scale.data, w2_bias = self.process_scale(
|
||||
layer.w2_weight, layer.w2_weight_scale.data,
|
||||
layer.w2_weight_scale_second.data)
|
||||
w2_weight_scale_second)
|
||||
if hasattr(layer, "w13_weight_scale_second"):
|
||||
# scale_second is no longer used, release this part of the memory
|
||||
del layer.w13_weight_scale_second
|
||||
del layer.w2_weight_scale_second
|
||||
del layer.w13_weight_offset_second
|
||||
del layer.w2_weight_offset_second
|
||||
|
||||
self.update_bias(layer, w13_bias, w2_bias)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user