[Quantization]300I Duo support w8a8 quantization (#1560)

### What this PR does / why we need it?
This pr supports w8a8 on 300I Duo platform. The main change is to use
`npu_quant_grouped_matmul_dequant` to replace `npu_grouped_matmul`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
offline inference on 310p runs normally.

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: tianyitang <tangtianyi4@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: tianyitang <tangtianyi4@huawei.com>
This commit is contained in:
Angazenn
2025-07-03 22:12:46 +08:00
committed by GitHub
parent 6d7cb14a24
commit 9fbd8017c0
5 changed files with 369 additions and 41 deletions

View File

@@ -10,7 +10,8 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod,
fused_experts, native_grouped_topk,
fused_experts, fused_experts_310p,
native_grouped_topk,
quant_per_tensor, select_experts)
@@ -111,6 +112,25 @@ class TestAscendW8A8LinearMethod(TestBase):
expected_y_output += bias
self.assertTrue(torch.equal(output, expected_y_output))
@patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
@patch("torch_npu.npu_quant_matmul")
def test_apply_with_x_is_310p(self, mock_npu_quant_matmul, mock_is_310p):
layer = MagicMock()
layer.aclnn_input_scale = 0.1
layer.aclnn_input_offset = 0.2
layer.weight = torch.randn(128, 256)
layer.deq_scale = 0.3
x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
bias = torch.randn(256)
expected_y_output = torch.randn(32, 256)
mock_npu_quant_matmul.return_value = expected_y_output
output = self.method.apply(layer, x, bias)
expected_y_output += bias
self.assertTrue(torch.equal(output, expected_y_output))
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading(self, mock_npu_format_cast):
layer = MagicMock()
@@ -221,6 +241,36 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
mock_fused_experts.assert_called_once()
self.assertEqual(result.shape, (32, self.hidden_size))
@patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
@patch('vllm_ascend.quantization.w8a8.select_experts')
@patch('vllm_ascend.quantization.w8a8.fused_experts_310p')
def test_apply_is_310p(self, mock_fused_experts_310p, mock_select_experts,
mock_is_310p):
# Setup
mock_layer = MagicMock()
x = torch.randn(32, self.hidden_size)
router_logits = torch.randn(32, 128) # 128 experts
top_k = 2
# Mock return values
mock_select_experts.return_value = (torch.randn(32, top_k),
torch.randint(0, 128, (32, top_k)))
mock_fused_experts_310p.return_value = torch.randn(
32, self.hidden_size)
# Test
result = self.moe_method.apply(layer=mock_layer,
x=x,
router_logits=router_logits,
top_k=top_k,
renormalize=True,
global_num_experts=128)
# Assertions
mock_select_experts.assert_called_once()
mock_fused_experts_310p.assert_called_once()
self.assertEqual(result.shape, (32, self.hidden_size))
class TestAscendC8KVCacheMethod(TestBase):
@@ -255,7 +305,22 @@ class TestAscendC8KVCacheMethod(TestBase):
expected_shape = (self.layer.num_kv_heads * self.layer.head_size, )
self.assertEqual(param.shape, expected_shape)
def test_process_weights_after_loading(self):
@patch("vllm_ascend.quantization.w8a8.is_310p", return_value=False)
def test_process_weights_after_loading_not_310p(self, mock_is_310p):
key_data = torch.ones(4 * 64)
value_data = torch.ones(4 * 64) * 2
self.layer.key_antiquant_scale.data = key_data
self.layer.value_antiquant_scale.data = value_data
self.method.process_weights_after_loading(self.layer)
self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256))
self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1))
self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2))
@patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
def test_process_weights_after_loading_is_310p(self, mock_is_310p):
key_data = torch.ones(4 * 64)
value_data = torch.ones(4 * 64) * 2
@@ -527,6 +592,67 @@ class TestFusedExperts(TestBase):
)
class TestFusedExperts310(TestBase):
@patch('torch_npu.npu_quant_grouped_matmul_dequant')
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch('vllm_ascend.quantization.w8a8.get_ep_group')
@patch('torch_npu.npu_swiglu')
def test_fused_experts_310p_with_expert_map(self, mock_swiglu,
mock_get_ep_group,
mock_quant_per_tensor,
mock_matmul_dequant):
num_tokens = 32
hidden_size = 128
intermediate_size = 256
num_experts = 4
top_k = 1
hidden_states = torch.randn(num_tokens, hidden_size)
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
w1_scale = torch.tensor([0.1])
w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]])
w2 = torch.randn(num_experts, hidden_size, intermediate_size)
w2_scale = torch.tensor([0.1])
w2_input_scale = torch.tensor([0.2])
topk_weights = torch.rand(num_tokens, top_k)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
expert_map = torch.arange(num_experts)
mock_get_ep_group.return_value.world_size = 1
mock_quant_per_tensor.return_value = torch.randint(-128,
127,
hidden_states.shape,
dtype=torch.int8)
mock_swiglu.return_value = torch.randn(num_tokens * top_k,
intermediate_size)
mock_matmul_dequant.return_value = hidden_states
output = fused_experts_310p(
hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w1_input_scale=w1_input_scale,
w2=w2,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=num_experts,
expert_map=expert_map,
)
self.assertEqual(output.shape, (num_tokens, hidden_size))
self.assertEqual(mock_matmul_dequant.call_count, 2)
class TestSelectExperts(TestBase):
def setUp(self):

View File

@@ -1,3 +1,5 @@
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
import math
import os
import unittest
@@ -102,6 +104,79 @@ class TestUtils(unittest.TestCase):
output_tensor = utils.aligned_16(input_tensor)
self.assertEqual(output_tensor.shape[0], 32)
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
@mock.patch('vllm_ascend.utils.is_310p')
@mock.patch('vllm_ascend.utils.get_ascend_config')
def test_maybe_converting_weight_acl_format(self, mock_get_config,
mock_310p, mock_npu_cast,
mock_get_format):
ACL_FORMAT_FRACTAL_NZ = 29
mock_310p.return_value = True
mock_config = mock.MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_config.return_value = mock_config
mock_get_format.return_value = 1
mock_npu_cast.return_value = 1
fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
fused_moe.w2_weight = mock.MagicMock()
fused_moe.w13_weight.data = torch.randn(128, 256)
fused_moe.w2_weight.data = torch.randn(256, 128)
model = mock.MagicMock()
model.modules.return_value = [fused_moe]
utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
self.assertEqual(fused_moe.w13_weight.data, 1)
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
@mock.patch('vllm_ascend.utils.is_310p')
@mock.patch('vllm_ascend.utils.get_ascend_config')
def test_maybe_converting_weight_acl_format_format_true(
self, mock_get_config, mock_310p, mock_npu_cast, mock_get_format):
ACL_FORMAT_FRACTAL_NZ = 29
mock_310p.return_value = True
mock_config = mock.MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_config.return_value = mock_config
mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
mock_npu_cast.return_value = 1
fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
fused_moe.w2_weight = mock.MagicMock()
fused_moe.w13_weight.data = torch.randn(128, 256)
fused_moe.w2_weight.data = torch.randn(256, 128)
model = mock.MagicMock()
model.modules.return_value = [fused_moe]
mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
@mock.patch('vllm_ascend.utils.get_ascend_config')
@mock.patch('vllm_ascend.utils.is_310p', return_value=False)
def test_maybe_converting_weight_acl_format_not_310_not_graph(
self, mock_310p, mock_get_config):
mock_config = mock.MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_config.return_value = mock_config
mock_constant = mock.MagicMock()
mock_model = mock.MagicMock()
utils.maybe_converting_weight_acl_format(mock_model, mock_constant)
@mock.patch('importlib.util.find_spec')
@mock.patch('importlib.import_module')
def test_try_register_lib(self, mock_import_module, mock_find_spec):
@@ -111,23 +186,17 @@ class TestUtils(unittest.TestCase):
lib_name = "existing_lib"
lib_info = "Library found and imported successfully"
utils.try_register_lib(lib_name, lib_info)
mock_find_spec.assert_called_once_with(lib_name)
mock_import_module.assert_called_once_with(lib_name)
# Can't find lib
mock_find_spec.return_value = None
lib_name = "non_existing_lib"
utils.try_register_lib(lib_name)
self.assertEqual(2, mock_find_spec.call_count)
self.assertEqual(1, mock_import_module.call_count)
# import error
mock_find_spec.return_value = mock.MagicMock()
mock_import_module.side_effect = ImportError("import error")
lib_name = "error_lib"
utils.try_register_lib(lib_name)
self.assertEqual(3, mock_find_spec.call_count)
self.assertEqual(2, mock_import_module.call_count)
def test_enable_custom_op(self):
result = utils.enable_custom_op()

View File

@@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import AttentionType
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
def quant_per_tensor(in_tensor: torch.Tensor,
@@ -42,7 +43,7 @@ class AscendW8A8LinearMethod:
def __init__(self) -> None:
# aclnn quant matmul requires to transpose matrix B, set to true by default.
self.transpose_weight = True
self.transpose_weight = not is_310p()
@staticmethod
def get_weight(
@@ -95,13 +96,24 @@ class AscendW8A8LinearMethod:
x = quant_per_tensor(x, layer.aclnn_input_scale,
layer.aclnn_input_offset)
quant_bias = layer.quant_bias if tp_rank == 0 else None
output = torch_npu.npu_quant_matmul(
x,
layer.weight,
layer.deq_scale,
bias=quant_bias,
output_dtype=original_dtype,
)
if is_310p():
# On 300I Duo platform, we need transpose again if
# using nz. This transpose can be skipped in torchair.
output = torch_npu.npu_quant_matmul(
x,
layer.weight.data.transpose(1, 0),
layer.deq_scale,
bias=quant_bias,
output_dtype=original_dtype,
)
else:
output = torch_npu.npu_quant_matmul(
x,
layer.weight,
layer.deq_scale,
bias=quant_bias,
output_dtype=original_dtype,
)
return output
def process_weights_after_loading(self, layer):
@@ -114,7 +126,8 @@ class AscendW8A8LinearMethod:
requires_grad=False).to(layer.aclnn_input_scale.dtype)
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
ACL_FORMAT_FRACTAL_NZ)
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
@@ -232,6 +245,19 @@ class AscendW8A8FusedMoEMethod:
global_num_experts=global_num_experts,
)
if is_310p():
return fused_experts_310p(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w1_input_scale=layer.w13_input_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
w2_input_scale=layer.w2_input_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
expert_map=expert_map)
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
@@ -248,41 +274,48 @@ class AscendW8A8FusedMoEMethod:
expert_map=expert_map)
def process_weights_after_loading(self, layer):
# torch.npu.config.allow_internal_format = True
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
2).contiguous()
if not is_310p():
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1).to(torch.float32)
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
layer.w13_weight_offset.data.shape[0], -1).to(torch.float16)
layer.w13_weight_offset.data.shape[0], -1)
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
layer.w2_weight_scale.data.shape[0], -1).to(torch.float32)
layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
layer.w2_weight_offset.data.shape[0], -1).to(torch.float16)
layer.w2_weight_offset.data.shape[0], -1)
expanding_factor_w13 = layer.w13_weight.data.shape[1]
expanding_factor_w2 = layer.w2_weight.data.shape[1]
layer.w13_input_scale.data = torch.nn.Parameter(
layer.w13_input_scale.data.repeat(
1, expanding_factor_w13)[0:1]).to(torch.float16)
layer.w2_input_scale.data = torch.nn.Parameter(
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1]).to(
torch.float16)
if is_310p():
layer.w13_input_scale.data = torch.nn.Parameter(
layer.w13_input_scale.data.max())
layer.w2_input_scale.data = torch.nn.Parameter(
layer.w2_input_scale.data.max())
else:
layer.w13_input_scale.data = torch.nn.Parameter(
layer.w13_input_scale.data.repeat(1,
expanding_factor_w13)[0:1])
layer.w2_input_scale.data = torch.nn.Parameter(
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
layer.w13_input_offset.data = torch.nn.Parameter(
layer.w13_input_scale.data.repeat(
1, expanding_factor_w13)[0:1]).to(torch.int8)
layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1])
layer.w2_input_offset.data = torch.nn.Parameter(
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1]).to(
torch.int8)
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
# NZ
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, 29).contiguous()
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, 29).contiguous()
# converting ACL_FORMAT_FRACTAL_NZ.
# npu_quant_grouped_matmul_dequant in eager mode does not accept
# ACL_FORMAT_FRACTAL_NZ.
if not is_310p():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
class AscendC8KVCacheMethod:
@@ -407,6 +440,69 @@ class AscendC8KVCacheMethod:
return output
def fused_experts_310p(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w1_input_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
global_num_experts: int,
expert_map: torch.Tensor = None,
) -> torch.Tensor:
ep_size = get_ep_group().world_size
local_num_experts = global_num_experts // ep_size
local_num_group = top_k // ep_size
bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
sorted_hidden_states = hidden_states.index_select(
0, sorted_topk_ids // local_num_group)
experts_id = torch.arange(0,
local_num_experts,
dtype=topk_ids.dtype,
device=topk_ids.device)
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
torch.float32).sum(0)
topk_scales = topk_weights.view(-1).index_select(
0, sorted_topk_ids).unsqueeze(-1)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant(
x=sorted_hidden_states,
quantized_weight=w1,
weight_scale=w1_scale,
group_list=group_list,
x_scale=w1_input_scale,
quant_mode="pertensor")
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
gate_up_out *= topk_scales
down_out = torch_npu.npu_quant_grouped_matmul_dequant(
x=gate_up_out,
quantized_weight=w2,
weight_scale=w2_scale,
group_list=group_list,
x_scale=w2_input_scale,
quant_mode="pertensor")
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
bsz, top_k // ep_size, -1).sum(1)
return final_hidden_states
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,

View File

@@ -31,6 +31,7 @@ from torch_npu.npu.streams import Event
from vllm.logger import logger
import vllm_ascend.envs as envs
from vllm_ascend.ascend_config import get_ascend_config
try:
# Recent release of torchair has moved these ops to `.scope`.
@@ -175,6 +176,28 @@ def aligned_16(tensor: torch.Tensor):
return new_tensor
def maybe_converting_weight_acl_format(model, format=ACL_FORMAT_FRACTAL_NZ):
# currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ
# in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ
# is much more preferred than ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
# conversion when using torchair graph mode on 300I Duo platform.
# TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant
# accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode.
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
use_torchair = get_ascend_config().torchair_graph_config.enabled
if not is_310p() or not use_torchair:
return
for module in model.modules():
if isinstance(module, FusedMoE):
if torch_npu.get_npu_format(module.w13_weight.data) == format:
return
module.w13_weight.data = torch_npu.npu_format_cast(
module.w13_weight.data, format)
module.w2_weight.data = torch_npu.npu_format_cast(
module.w2_weight.data, format)
def try_register_lib(lib_name: str, lib_info: str = ""):
import importlib
import importlib.util

View File

@@ -77,6 +77,7 @@ from vllm_ascend.pool.metadata import PoolingMetadata
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
ProfileExecuteDuration, is_310p,
maybe_converting_weight_acl_format,
vllm_version_is)
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
@@ -1196,6 +1197,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
if self.torchair_graph_enabled and not with_prefill:
maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_NZ)
compiled_model = self._get_torchair_lazy_compiled_model(
padded_batch_size)
hidden_states = compiled_model(
@@ -1207,6 +1211,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
else:
assert self.model is not None
maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_ND)
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
@@ -1878,6 +1885,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv, tuple), "kv_cache must be a tuple"
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])
maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_NZ)
compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
hidden_states = compiled_model(
@@ -1889,6 +1900,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
attn_metadata=attn_metadata,
)
else:
maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_ND)
hidden_states = model(
input_ids=input_ids,
positions=positions,