Files
xc-llm-ascend/tests/ut/quantization/test_w4a8_dynamic.py
Wang Kunpeng 8a59367d0c [main][Feature] Support deepseek w4a8 quantization (#2172)
### What this PR does / why we need it?
Support DeepSeek-R1 w4a8 quantization.
DeepSeek-R1 w4a8 uses mixed quantization in which only the MoE layers use
w4a8_dynamic quantization, so this PR adds the w4a8_dynamic.py file, which
includes the AscendW4A8DynamicFusedMoEMethod class (alongside
AscendW4A8DynamicLinearMethod).
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Added UT cases in `tests/ut/quantization/test_w4a8_dynamic.py` and
`tests/ut/quantization/test_quantizer.py`.
Added an e2e case,
`tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC`,
to test the DeepSeek w4a8_dynamic quantized model.

#### 1. How to get weights using msModelSlim
##### Installation steps
Use the master branch at commit id:
298e175d69b3b855111a1e09bbe2fcd12fdb4e24
git clone https://gitee.com/ascend/msit.git
cd msit/msmodelslim
bash install.sh

##### The required transformers environment
transformers>=4.48.2

##### Generate w4a8 weights
cd example/DeepSeek
Command reference: msmodelslim/example/DeepSeek/README.md. Follow the
[pre-check](https://gitee.com/ascend/msit/blob/master/msmodelslim/example/DeepSeek/README.md#%E8%BF%90%E8%A1%8C%E5%89%8D%E5%BF%85%E6%A3%80)
and [DeepSeek-R1 w4a8 mixed
quantization](https://gitee.com/ascend/msit/blob/master/msmodelslim/example/DeepSeek/README.md#deepseek-r1-w4a8-%E6%B7%B7%E5%90%88%E9%87%8F%E5%8C%96%E5%89%8D%E4%B8%89%E5%B1%82-mlpw8a8-dynamic-%E9%87%8F%E5%8C%96mla%E5%85%B1%E4%BA%AB%E4%B8%93%E5%AE%B6w8a8%E9%87%8F%E5%8C%96%E8%B7%AF%E7%94%B1%E4%B8%93%E5%AE%B6w4a8-dynamic%E9%87%8F%E5%8C%96)
chapters.
Reference command: python3 quant_deepseek_w4a8.py --model_path {original
weight path} --save_path {generated weight path} --mindie_format

##### Adapt to vllm-ascend
Because `--mindie_format` produces weights in MindIE format, a few adaptations
are needed before vllm-ascend can use them: rename
`quant_model_description_w8a8_dynamic.json` to
`quant_model_description.json` and add `"group_size": 256` to it.
In `config.json`, change `"model_type":deepseekv2` to
`"model_type":deepseek_v3` and remove `quantization_config`.
Tip: `group_size` must match the weights. If the w4a8 weights were not
generated with msmodelslim, you can find the `group_size` under
`quantization_config` in `config.json`.

#### 2. How to run w4a8
##### a. How to run in eager mode
export VLLM_USE_V1=1 # v1

python -m vllm.entrypoints.openai.api_server --model=$1
--trust-remote-code -tp $2 -dp $3 --enable_expert_parallel
--quantization ascend --port $4 --max-model-len $5 --max-num-seqs $6
--enforce-eager
Example: python -m vllm.entrypoints.openai.api_server
--model=/weightpath/w4a8_4_layer --trust-remote-code -tp 4 -dp 4
--enable_expert_parallel --quantization ascend --port 8002
--max-model-len 5120 --max-num-seqs 128 --enforce-eager

##### b. How to run in graph mode
export VLLM_USE_V1=1 # v1
export HCCL_BUFFSIZE=1024

python -m vllm.entrypoints.openai.api_server --model=$1
--trust-remote-code -tp $2 -dp $3 --enable_expert_parallel
--quantization ascend --port $4 --max-model-len $5
--additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'
Example: python -m vllm.entrypoints.openai.api_server
--model=/weight/dsr1_w4a8_vllm --trust-remote-code -tp 4 -dp 4
--enable_expert_parallel --quantization ascend --port 8002
--max-model-len 5120
--additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'


- vLLM version: v0.10.0
- vLLM main:
c494f96fbc

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
2025-08-06 10:17:44 +08:00

110 lines
5.5 KiB
Python

from unittest.mock import Mock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.w4a8_dynamic import (
AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod)
class TestAscendW4A8DynamicLinearMethod(TestBase):
    """Check the dtypes and shapes of tensors built by the W4A8 linear method."""

    def setUp(self):
        self.method = AscendW4A8DynamicLinearMethod()
        # group_size is set directly on the instance rather than via config.
        # NOTE(review): with input_size 8 this presumably yields a single
        # quantization group per row (last dim == 8 // 8 == 1) — confirm.
        self.method.group_size = 8

    def test_get_weight(self):
        """get_weight(8, 32, bf16) yields an int8 "weight" of shape (32, 8)."""
        # presumably the positional args are (input_size, output_size,
        # params_dtype) — TODO confirm against the method signature.
        qweight = self.method.get_weight(8, 32, torch.bfloat16)["weight"]
        self.assertEqual(qweight.dtype, torch.int8)
        self.assertEqual(qweight.shape, (32, 8))

    def test_get_pergroup_param(self):
        """Every per-group parameter is bfloat16 with shape (32, 1)."""
        params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
        expected_shape = (32, 1)
        # Checked in a fixed order; the first mismatch fails the test.
        for key in ("weight_scale", "weight_offset", "weight_scale_second",
                    "weight_offset_second"):
            tensor = params[key]
            self.assertEqual(tensor.dtype, torch.bfloat16)
            self.assertEqual(tensor.shape, expected_shape)
class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
    """Check parameter creation and post-load processing of the W4A8 MoE method."""

    # `patch` decorators are applied bottom-up, so the mocks are passed in
    # reverse decorator order: get_rank first, get_ep_group last.
    @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
    @patch("vllm_ascend.ascend_config.get_ascend_config")
    @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
    @patch('torch.distributed.get_rank', return_value=0)
    def setUp(self, _rank_mock, _mc2_group_mock, ascend_config_mock,
              _ep_group_mock):
        # The mocked Ascend config reports torchair graph mode as disabled.
        ascend_config_mock.return_value = Mock(
            torchair_graph_config=Mock(enabled=False))
        self.quant_method = AscendW4A8DynamicFusedMoEMethod()

    def test_get_weight(self):
        """get_weight(8, 4, 14, bf16) yields an int8 w13_weight of (8, 8, 14)."""
        # presumably args are (num_experts, intermediate_size, hidden_size,
        # params_dtype), giving w13 a middle dim of 2 * 4 — TODO confirm.
        w13 = self.quant_method.get_weight(8, 4, 14,
                                           torch.bfloat16)["w13_weight"]
        self.assertEqual(w13.dtype, torch.int8)
        self.assertEqual(w13.shape, (8, 8, 14))

    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
    def test_get_dynamic_quant_param(self, current_config_mock):
        """All scale tensors are bfloat16 with the expected shapes."""
        # Expose group_size == 2 through the patched vLLM config.
        # NOTE(review): the "_second" last dims (7 and 2) look like
        # 14 // 2 and 4 // 2 — confirm against get_dynamic_quant_param.
        current_config_mock.return_value = Mock(
            quant_config=Mock(quant_description={"group_size": 2}))
        params = self.quant_method.get_dynamic_quant_param(
            8, 4, 14, torch.bfloat16)
        expected_shapes = (
            ("w13_weight_scale", (8, 8, 1)),
            ("w13_weight_scale_second", (8, 8, 7)),
            ("w2_weight_scale", (8, 14, 1)),
            ("w2_weight_scale_second", (8, 14, 2)),
        )
        for name, shape in expected_shapes:
            self.assertEqual(params[name].dtype, torch.bfloat16)
            self.assertEqual(params[name].shape, shape)

    @patch('torch_npu.npu_quantize')
    @patch('torch.Tensor.npu')
    def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
        """Post-load processing attaches float32 w13/w2 scale-bias tensors."""

        def frozen(tensor):
            # Wrap a tensor as a Parameter that does not require grad.
            return torch.nn.Parameter(tensor, requires_grad=False)

        bf16 = torch.bfloat16
        layer = torch.nn.Module()
        layer.w13_weight = frozen(torch.zeros((8, 8, 14), dtype=torch.int8))
        layer.w2_weight = frozen(torch.zeros((8, 14, 4), dtype=torch.int8))
        layer.w13_weight_scale = frozen(torch.ones((8, 8, 1), dtype=bf16))
        layer.w13_weight_offset = frozen(torch.zeros((8, 8, 1), dtype=bf16))
        layer.w13_weight_scale_second = frozen(
            torch.ones((8, 8, 7), dtype=bf16))
        layer.w2_weight_scale = frozen(torch.ones((8, 14, 1), dtype=bf16))
        layer.w2_weight_offset = frozen(torch.zeros((8, 14, 1), dtype=bf16))
        layer.w2_weight_scale_second = frozen(
            torch.ones((8, 14, 2), dtype=bf16))

        # `.npu()` and `torch_npu.npu_quantize` are stubbed to return empty
        # tensors, presumably so the test does not need NPU hardware.
        mock_npu.return_value = torch.Tensor()
        mock_npu_quantize.return_value = torch.Tensor()

        self.quant_method.process_weights_after_loading(layer)

        for attr, shape in (("w13_scale_bias", (8, 8)),
                            ("w2_scale_bias", (8, 14))):
            self.assertTrue(hasattr(layer, attr))
            bias = getattr(layer, attr).data
            self.assertEqual(bias.shape, shape)
            self.assertEqual(bias.dtype, torch.float32)