[Quantization] 300I Duo support w8a8 quantization (#1560)
### What this PR does / why we need it?
This PR adds w8a8 quantization support on the 300I Duo (310P) platform. The main change is to use `npu_quant_grouped_matmul_dequant` in place of `npu_grouped_matmul`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Offline inference on 310P runs normally.

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: tianyitang <tangtianyi4@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: tianyitang <tangtianyi4@huawei.com>
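As a quick orientation for reviewers, the sketch below condenses the core replacement: on 310P, each grouped expert matmul and its dequantization are handled by a single `torch_npu.npu_quant_grouped_matmul_dequant` call, using the same kwargs that `fused_experts_310p` passes in the diff, instead of `npu_grouped_matmul` followed by a separate dequant step. The wrapper name `grouped_matmul_dequant_310p` and the tensor shapes are illustrative placeholders, not code from this PR; on other platforms `AscendW8A8FusedMoEMethod.apply` keeps dispatching to the existing `fused_experts` path via `is_310p()`.

```python
import torch
import torch_npu


def grouped_matmul_dequant_310p(x_int8: torch.Tensor, w_int8: torch.Tensor,
                                w_scale: torch.Tensor, x_scale: torch.Tensor,
                                group_list: torch.Tensor) -> torch.Tensor:
    """Hypothetical wrapper around the fused op used twice in
    fused_experts_310p (gate_up projection and down projection)."""
    # x_int8:     per-tensor quantized activations, already sorted by expert id
    # w_int8:     int8 expert weights, one weight group per expert
    # w_scale:    weight dequantization scales, one row per expert
    # x_scale:    per-tensor activation scale (cf. w13/w2_input_scale.max())
    # group_list: cumulative number of tokens routed to each expert (int64)
    return torch_npu.npu_quant_grouped_matmul_dequant(
        x=x_int8,
        quantized_weight=w_int8,
        weight_scale=w_scale,
        group_list=group_list,
        x_scale=x_scale,
        quant_mode="pertensor")
```

Because the eager-mode op does not accept `ACL_FORMAT_FRACTAL_NZ` weights, the diff also skips the NZ format cast for MoE weights on 310P outside of torchair graph mode (see `maybe_converting_weight_acl_format` in `vllm_ascend/utils.py`).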
@@ -10,7 +10,8 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
                                            AscendW8A8FusedMoEMethod,
                                            AscendW8A8LinearMethod,
-                                           fused_experts, native_grouped_topk,
+                                           fused_experts, fused_experts_310p,
+                                           native_grouped_topk,
                                            quant_per_tensor, select_experts)


@@ -111,6 +112,25 @@ class TestAscendW8A8LinearMethod(TestBase):
         expected_y_output += bias
         self.assertTrue(torch.equal(output, expected_y_output))

+    @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
+    @patch("torch_npu.npu_quant_matmul")
+    def test_apply_with_x_is_310p(self, mock_npu_quant_matmul, mock_is_310p):
+        layer = MagicMock()
+        layer.aclnn_input_scale = 0.1
+        layer.aclnn_input_offset = 0.2
+        layer.weight = torch.randn(128, 256)
+        layer.deq_scale = 0.3
+
+        x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
+        bias = torch.randn(256)
+
+        expected_y_output = torch.randn(32, 256)
+        mock_npu_quant_matmul.return_value = expected_y_output
+
+        output = self.method.apply(layer, x, bias)
+        expected_y_output += bias
+        self.assertTrue(torch.equal(output, expected_y_output))
+
     @patch('torch_npu.npu_format_cast')
     def test_process_weights_after_loading(self, mock_npu_format_cast):
         layer = MagicMock()
@@ -221,6 +241,36 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
         mock_fused_experts.assert_called_once()
         self.assertEqual(result.shape, (32, self.hidden_size))

+    @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
+    @patch('vllm_ascend.quantization.w8a8.select_experts')
+    @patch('vllm_ascend.quantization.w8a8.fused_experts_310p')
+    def test_apply_is_310p(self, mock_fused_experts_310p, mock_select_experts,
+                           mock_is_310p):
+        # Setup
+        mock_layer = MagicMock()
+        x = torch.randn(32, self.hidden_size)
+        router_logits = torch.randn(32, 128)  # 128 experts
+        top_k = 2
+
+        # Mock return values
+        mock_select_experts.return_value = (torch.randn(32, top_k),
+                                            torch.randint(0, 128, (32, top_k)))
+        mock_fused_experts_310p.return_value = torch.randn(
+            32, self.hidden_size)
+
+        # Test
+        result = self.moe_method.apply(layer=mock_layer,
+                                       x=x,
+                                       router_logits=router_logits,
+                                       top_k=top_k,
+                                       renormalize=True,
+                                       global_num_experts=128)
+
+        # Assertions
+        mock_select_experts.assert_called_once()
+        mock_fused_experts_310p.assert_called_once()
+        self.assertEqual(result.shape, (32, self.hidden_size))
+

 class TestAscendC8KVCacheMethod(TestBase):

@@ -255,7 +305,22 @@ class TestAscendC8KVCacheMethod(TestBase):
         expected_shape = (self.layer.num_kv_heads * self.layer.head_size, )
         self.assertEqual(param.shape, expected_shape)

-    def test_process_weights_after_loading(self):
+    @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=False)
+    def test_process_weights_after_loading_not_310p(self, mock_is_310p):
         key_data = torch.ones(4 * 64)
         value_data = torch.ones(4 * 64) * 2

+        self.layer.key_antiquant_scale.data = key_data
+        self.layer.value_antiquant_scale.data = value_data
+
+        self.method.process_weights_after_loading(self.layer)
+
+        self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256))
+        self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1))
+        self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2))
+
+    @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
+    def test_process_weights_after_loading_is_310p(self, mock_is_310p):
+        key_data = torch.ones(4 * 64)
+        value_data = torch.ones(4 * 64) * 2
+
@@ -527,6 +592,67 @@ class TestFusedExperts(TestBase):
         )


+class TestFusedExperts310(TestBase):
+
+    @patch('torch_npu.npu_quant_grouped_matmul_dequant')
+    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
+    @patch('vllm_ascend.quantization.w8a8.get_ep_group')
+    @patch('torch_npu.npu_swiglu')
+    def test_fused_experts_310p_with_expert_map(self, mock_swiglu,
+                                                mock_get_ep_group,
+                                                mock_quant_per_tensor,
+                                                mock_matmul_dequant):
+        num_tokens = 32
+        hidden_size = 128
+        intermediate_size = 256
+        num_experts = 4
+        top_k = 1
+
+        hidden_states = torch.randn(num_tokens, hidden_size)
+
+        w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
+        w1_scale = torch.tensor([0.1])
+        w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]])
+
+        w2 = torch.randn(num_experts, hidden_size, intermediate_size)
+        w2_scale = torch.tensor([0.1])
+        w2_input_scale = torch.tensor([0.2])
+
+        topk_weights = torch.rand(num_tokens, top_k)
+        topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
+        expert_map = torch.arange(num_experts)
+
+        mock_get_ep_group.return_value.world_size = 1
+
+        mock_quant_per_tensor.return_value = torch.randint(-128,
+                                                           127,
+                                                           hidden_states.shape,
+                                                           dtype=torch.int8)
+
+        mock_swiglu.return_value = torch.randn(num_tokens * top_k,
+                                               intermediate_size)
+
+        mock_matmul_dequant.return_value = hidden_states
+
+        output = fused_experts_310p(
+            hidden_states=hidden_states,
+            w1=w1,
+            w1_scale=w1_scale,
+            w1_input_scale=w1_input_scale,
+            w2=w2,
+            w2_scale=w2_scale,
+            w2_input_scale=w2_input_scale,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            top_k=top_k,
+            global_num_experts=num_experts,
+            expert_map=expert_map,
+        )
+
+        self.assertEqual(output.shape, (num_tokens, hidden_size))
+        self.assertEqual(mock_matmul_dequant.call_count, 2)
+
+
 class TestSelectExperts(TestBase):

     def setUp(self):
@@ -1,3 +1,5 @@
+import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import]  # isort: skip  # noqa
+
 import math
 import os
 import unittest
@@ -102,6 +104,79 @@ class TestUtils(unittest.TestCase):
         output_tensor = utils.aligned_16(input_tensor)
         self.assertEqual(output_tensor.shape[0], 32)

+    @mock.patch('torch_npu.get_npu_format')
+    @mock.patch('torch_npu.npu_format_cast')
+    @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
+                new=mock.MagicMock)
+    @mock.patch('vllm_ascend.utils.is_310p')
+    @mock.patch('vllm_ascend.utils.get_ascend_config')
+    def test_maybe_converting_weight_acl_format(self, mock_get_config,
+                                                mock_310p, mock_npu_cast,
+                                                mock_get_format):
+        ACL_FORMAT_FRACTAL_NZ = 29
+        mock_310p.return_value = True
+
+        mock_config = mock.MagicMock()
+        mock_config.torchair_graph_config.enabled = True
+        mock_get_config.return_value = mock_config
+        mock_get_format.return_value = 1
+
+        mock_npu_cast.return_value = 1
+
+        fused_moe = mock.MagicMock()
+        fused_moe.w13_weight = mock.MagicMock()
+        fused_moe.w2_weight = mock.MagicMock()
+        fused_moe.w13_weight.data = torch.randn(128, 256)
+        fused_moe.w2_weight.data = torch.randn(256, 128)
+        model = mock.MagicMock()
+        model.modules.return_value = [fused_moe]
+
+        utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
+        self.assertEqual(fused_moe.w13_weight.data, 1)
+
+    @mock.patch('torch_npu.get_npu_format')
+    @mock.patch('torch_npu.npu_format_cast')
+    @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
+                new=mock.MagicMock)
+    @mock.patch('vllm_ascend.utils.is_310p')
+    @mock.patch('vllm_ascend.utils.get_ascend_config')
+    def test_maybe_converting_weight_acl_format_format_true(
+            self, mock_get_config, mock_310p, mock_npu_cast, mock_get_format):
+        ACL_FORMAT_FRACTAL_NZ = 29
+        mock_310p.return_value = True
+
+        mock_config = mock.MagicMock()
+        mock_config.torchair_graph_config.enabled = True
+        mock_get_config.return_value = mock_config
+        mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
+
+        mock_npu_cast.return_value = 1
+
+        fused_moe = mock.MagicMock()
+        fused_moe.w13_weight = mock.MagicMock()
+        fused_moe.w2_weight = mock.MagicMock()
+        fused_moe.w13_weight.data = torch.randn(128, 256)
+        fused_moe.w2_weight.data = torch.randn(256, 128)
+        model = mock.MagicMock()
+        model.modules.return_value = [fused_moe]
+
+        mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
+
+        utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
+
+    @mock.patch('vllm_ascend.utils.get_ascend_config')
+    @mock.patch('vllm_ascend.utils.is_310p', return_value=False)
+    def test_maybe_converting_weight_acl_format_not_310_not_graph(
+            self, mock_310p, mock_get_config):
+        mock_config = mock.MagicMock()
+        mock_config.torchair_graph_config.enabled = False
+        mock_get_config.return_value = mock_config
+
+        mock_constant = mock.MagicMock()
+
+        mock_model = mock.MagicMock()
+        utils.maybe_converting_weight_acl_format(mock_model, mock_constant)
+
     @mock.patch('importlib.util.find_spec')
     @mock.patch('importlib.import_module')
     def test_try_register_lib(self, mock_import_module, mock_find_spec):
@@ -111,23 +186,17 @@ class TestUtils(unittest.TestCase):
        lib_name = "existing_lib"
        lib_info = "Library found and imported successfully"
        utils.try_register_lib(lib_name, lib_info)
        mock_find_spec.assert_called_once_with(lib_name)
        mock_import_module.assert_called_once_with(lib_name)

        # Can't find lib
        mock_find_spec.return_value = None
        lib_name = "non_existing_lib"
        utils.try_register_lib(lib_name)
        self.assertEqual(2, mock_find_spec.call_count)
        self.assertEqual(1, mock_import_module.call_count)

        # import error
        mock_find_spec.return_value = mock.MagicMock()
        mock_import_module.side_effect = ImportError("import error")
        lib_name = "error_lib"
        utils.try_register_lib(lib_name)
        self.assertEqual(3, mock_find_spec.call_count)
        self.assertEqual(2, mock_import_module.call_count)

    def test_enable_custom_op(self):
        result = utils.enable_custom_op()
@@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import AttentionType

 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.distributed.parallel_state import get_ep_group
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p


 def quant_per_tensor(in_tensor: torch.Tensor,
@@ -42,7 +43,7 @@ class AscendW8A8LinearMethod:

     def __init__(self) -> None:
         # aclnn quant matmul requires to transpose matrix B, set to true by default.
-        self.transpose_weight = True
+        self.transpose_weight = not is_310p()

     @staticmethod
     def get_weight(
@@ -95,13 +96,24 @@ class AscendW8A8LinearMethod:
         x = quant_per_tensor(x, layer.aclnn_input_scale,
                              layer.aclnn_input_offset)
         quant_bias = layer.quant_bias if tp_rank == 0 else None
-        output = torch_npu.npu_quant_matmul(
-            x,
-            layer.weight,
-            layer.deq_scale,
-            bias=quant_bias,
-            output_dtype=original_dtype,
-        )
+        if is_310p():
+            # On 300I Duo platform, we need transpose again if
+            # using nz. This transpose can be skipped in torchair.
+            output = torch_npu.npu_quant_matmul(
+                x,
+                layer.weight.data.transpose(1, 0),
+                layer.deq_scale,
+                bias=quant_bias,
+                output_dtype=original_dtype,
+            )
+        else:
+            output = torch_npu.npu_quant_matmul(
+                x,
+                layer.weight,
+                layer.deq_scale,
+                bias=quant_bias,
+                output_dtype=original_dtype,
+            )
         return output

     def process_weights_after_loading(self, layer):
@@ -114,7 +126,8 @@ class AscendW8A8LinearMethod:
             requires_grad=False).to(layer.aclnn_input_scale.dtype)
         if self.transpose_weight:
             layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
+                                                      ACL_FORMAT_FRACTAL_NZ)
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)

@@ -232,6 +245,19 @@ class AscendW8A8FusedMoEMethod:
             global_num_experts=global_num_experts,
         )

+        if is_310p():
+            return fused_experts_310p(hidden_states=x,
+                                      w1=layer.w13_weight,
+                                      w1_scale=layer.w13_weight_scale,
+                                      w1_input_scale=layer.w13_input_scale,
+                                      w2=layer.w2_weight,
+                                      w2_scale=layer.w2_weight_scale,
+                                      w2_input_scale=layer.w2_input_scale,
+                                      topk_weights=topk_weights,
+                                      topk_ids=topk_ids,
+                                      top_k=top_k,
+                                      global_num_experts=global_num_experts,
+                                      expert_map=expert_map)
         return fused_experts(hidden_states=x,
                              w1=layer.w13_weight,
                              w1_scale=layer.w13_weight_scale,
@@ -248,41 +274,48 @@ class AscendW8A8FusedMoEMethod:
                              expert_map=expert_map)

     def process_weights_after_loading(self, layer):
         # torch.npu.config.allow_internal_format = True
-        layer.w13_weight.data = layer.w13_weight.data.transpose(
-            1, 2).contiguous()
-        layer.w2_weight.data = layer.w2_weight.data.transpose(1,
-                                                              2).contiguous()
+        if not is_310p():
+            layer.w13_weight.data = layer.w13_weight.data.transpose(
+                1, 2).contiguous()
+            layer.w2_weight.data = layer.w2_weight.data.transpose(
+                1, 2).contiguous()
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
-            layer.w13_weight_scale.data.shape[0], -1).to(torch.float32)
+            layer.w13_weight_scale.data.shape[0], -1)

         layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
-            layer.w13_weight_offset.data.shape[0], -1).to(torch.float16)
+            layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
-            layer.w2_weight_scale.data.shape[0], -1).to(torch.float32)
+            layer.w2_weight_scale.data.shape[0], -1)
         layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
-            layer.w2_weight_offset.data.shape[0], -1).to(torch.float16)
+            layer.w2_weight_offset.data.shape[0], -1)
         expanding_factor_w13 = layer.w13_weight.data.shape[1]
         expanding_factor_w2 = layer.w2_weight.data.shape[1]
-        layer.w13_input_scale.data = torch.nn.Parameter(
-            layer.w13_input_scale.data.repeat(
-                1, expanding_factor_w13)[0:1]).to(torch.float16)
-
-        layer.w2_input_scale.data = torch.nn.Parameter(
-            layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1]).to(
-                torch.float16)
+        if is_310p():
+            layer.w13_input_scale.data = torch.nn.Parameter(
+                layer.w13_input_scale.data.max())
+            layer.w2_input_scale.data = torch.nn.Parameter(
+                layer.w2_input_scale.data.max())
+        else:
+            layer.w13_input_scale.data = torch.nn.Parameter(
+                layer.w13_input_scale.data.repeat(1,
+                                                  expanding_factor_w13)[0:1])
+            layer.w2_input_scale.data = torch.nn.Parameter(
+                layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
+
         layer.w13_input_offset.data = torch.nn.Parameter(
-            layer.w13_input_scale.data.repeat(
-                1, expanding_factor_w13)[0:1]).to(torch.int8)
+            layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1])
         layer.w2_input_offset.data = torch.nn.Parameter(
-            layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1]).to(
-                torch.int8)
+            layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])

-        # NZ
-        layer.w13_weight.data = torch_npu.npu_format_cast(
-            layer.w13_weight.data, 29).contiguous()
-        layer.w2_weight.data = torch_npu.npu_format_cast(
-            layer.w2_weight.data, 29).contiguous()
+        # converting ACL_FORMAT_FRACTAL_NZ.
+        # npu_quant_grouped_matmul_dequant in eager mode does not accept
+        # ACL_FORMAT_FRACTAL_NZ.
+        if not is_310p():
+            layer.w13_weight.data = torch_npu.npu_format_cast(
+                layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
+            layer.w2_weight.data = torch_npu.npu_format_cast(
+                layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()


 class AscendC8KVCacheMethod:
@@ -407,6 +440,69 @@ class AscendC8KVCacheMethod:
         return output


+def fused_experts_310p(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w1_input_scale: torch.Tensor,
+    w2: torch.Tensor,
+    w2_scale: torch.Tensor,
+    w2_input_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    global_num_experts: int,
+    expert_map: torch.Tensor = None,
+) -> torch.Tensor:
+    ep_size = get_ep_group().world_size
+    local_num_experts = global_num_experts // ep_size
+    local_num_group = top_k // ep_size
+
+    bsz, _ = hidden_states.shape
+    flatten_topk_ids = topk_ids.view(-1)
+    sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
+    sorted_topk_ids = sorted_topk_ids.to(torch.int32)
+    sorted_hidden_states = hidden_states.index_select(
+        0, sorted_topk_ids // local_num_group)
+
+    experts_id = torch.arange(0,
+                              local_num_experts,
+                              dtype=topk_ids.dtype,
+                              device=topk_ids.device)
+    num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
+        torch.float32).sum(0)
+    topk_scales = topk_weights.view(-1).index_select(
+        0, sorted_topk_ids).unsqueeze(-1)
+    group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
+
+    gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant(
+        x=sorted_hidden_states,
+        quantized_weight=w1,
+        weight_scale=w1_scale,
+        group_list=group_list,
+        x_scale=w1_input_scale,
+        quant_mode="pertensor")
+
+    gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
+        torch.float16)
+    gate_up_out *= topk_scales
+
+    down_out = torch_npu.npu_quant_grouped_matmul_dequant(
+        x=gate_up_out,
+        quantized_weight=w2,
+        weight_scale=w2_scale,
+        group_list=group_list,
+        x_scale=w2_input_scale,
+        quant_mode="pertensor")
+
+    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
+    unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids)
+    final_hidden_states = unsorted_hidden_states.reshape(
+        bsz, top_k // ep_size, -1).sum(1)
+
+    return final_hidden_states
+
+
 def fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -31,6 +31,7 @@ from torch_npu.npu.streams import Event
 from vllm.logger import logger

 import vllm_ascend.envs as envs
+from vllm_ascend.ascend_config import get_ascend_config

 try:
     # Recent release of torchair has moved these ops to `.scope`.
@@ -175,6 +176,28 @@ def aligned_16(tensor: torch.Tensor):
     return new_tensor


+def maybe_converting_weight_acl_format(model, format=ACL_FORMAT_FRACTAL_NZ):
+    # currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ
+    # in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ
+    # is much more preferred than ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
+    # conversion when using torchair graph mode on 300I Duo platform.
+    # TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant
+    # accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode.
+    from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+
+    use_torchair = get_ascend_config().torchair_graph_config.enabled
+    if not is_310p() or not use_torchair:
+        return
+    for module in model.modules():
+        if isinstance(module, FusedMoE):
+            if torch_npu.get_npu_format(module.w13_weight.data) == format:
+                return
+            module.w13_weight.data = torch_npu.npu_format_cast(
+                module.w13_weight.data, format)
+            module.w2_weight.data = torch_npu.npu_format_cast(
+                module.w2_weight.data, format)
+
+
 def try_register_lib(lib_name: str, lib_info: str = ""):
     import importlib
     import importlib.util
@@ -77,6 +77,7 @@ from vllm_ascend.pool.metadata import PoolingMetadata
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration, is_310p,
+                               maybe_converting_weight_acl_format,
                                vllm_version_is)
 from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
@@ -1196,6 +1197,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 model_kwargs["kv_caches"] = self.kv_caches
                 model_kwargs["attn_metadata"] = attn_metadata
             if self.torchair_graph_enabled and not with_prefill:
+                maybe_converting_weight_acl_format(self.model,
+                                                   ACL_FORMAT_FRACTAL_NZ)
+
                 compiled_model = self._get_torchair_lazy_compiled_model(
                     padded_batch_size)
                 hidden_states = compiled_model(
@@ -1207,6 +1211,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 )
             else:
                 assert self.model is not None
+                maybe_converting_weight_acl_format(self.model,
+                                                   ACL_FORMAT_FRACTAL_ND)
+
                 hidden_states = self.model(
                     input_ids=input_ids,
                     positions=positions,
@@ -1878,6 +1885,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                         kv, tuple), "kv_cache must be a tuple"
                     torch._dynamo.mark_static(kv[0])
                     torch._dynamo.mark_static(kv[1])
+
+            maybe_converting_weight_acl_format(self.model,
+                                               ACL_FORMAT_FRACTAL_NZ)
+
             compiled_model = self._get_torchair_lazy_compiled_model(
                 num_tokens)
             hidden_states = compiled_model(
@@ -1889,6 +1900,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 attn_metadata=attn_metadata,
             )
         else:
+            maybe_converting_weight_acl_format(self.model,
+                                               ACL_FORMAT_FRACTAL_ND)
+
             hidden_states = model(
                 input_ids=input_ids,
                 positions=positions,