diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py
index 21e0626..a27e921 100644
--- a/tests/ut/quantization/test_w8a8.py
+++ b/tests/ut/quantization/test_w8a8.py
@@ -10,7 +10,8 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
                                            AscendW8A8FusedMoEMethod,
                                            AscendW8A8LinearMethod,
-                                           fused_experts, native_grouped_topk,
+                                           fused_experts, fused_experts_310p,
+                                           native_grouped_topk,
                                            quant_per_tensor, select_experts)
@@ -111,6 +112,25 @@ class TestAscendW8A8LinearMethod(TestBase):
         expected_y_output += bias
         self.assertTrue(torch.equal(output, expected_y_output))
 
+    @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
+    @patch("torch_npu.npu_quant_matmul")
+    def test_apply_with_x_is_310p(self, mock_npu_quant_matmul, mock_is_310p):
+        layer = MagicMock()
+        layer.aclnn_input_scale = 0.1
+        layer.aclnn_input_offset = 0.2
+        layer.weight = torch.randn(128, 256)
+        layer.deq_scale = 0.3
+
+        x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
+        bias = torch.randn(256)
+
+        expected_y_output = torch.randn(32, 256)
+        mock_npu_quant_matmul.return_value = expected_y_output
+
+        output = self.method.apply(layer, x, bias)
+        expected_y_output += bias
+        self.assertTrue(torch.equal(output, expected_y_output))
+
     @patch('torch_npu.npu_format_cast')
     def test_process_weights_after_loading(self, mock_npu_format_cast):
         layer = MagicMock()
@@ -221,6 +241,36 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
         mock_fused_experts.assert_called_once()
         self.assertEqual(result.shape, (32, self.hidden_size))
 
+    @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
+    @patch('vllm_ascend.quantization.w8a8.select_experts')
+    @patch('vllm_ascend.quantization.w8a8.fused_experts_310p')
+    def test_apply_is_310p(self, mock_fused_experts_310p, mock_select_experts,
+                           mock_is_310p):
+        # Setup
+        mock_layer = MagicMock()
+        x = torch.randn(32, self.hidden_size)
+        router_logits = torch.randn(32, 128)  # 128 experts
+        top_k = 2
+
+        # Mock return values
+        mock_select_experts.return_value = (torch.randn(32, top_k),
+                                            torch.randint(0, 128, (32, top_k)))
+        mock_fused_experts_310p.return_value = torch.randn(
+            32, self.hidden_size)
+
+        # Test
+        result = self.moe_method.apply(layer=mock_layer,
+                                       x=x,
+                                       router_logits=router_logits,
+                                       top_k=top_k,
+                                       renormalize=True,
+                                       global_num_experts=128)
+
+        # Assertions
+        mock_select_experts.assert_called_once()
+        mock_fused_experts_310p.assert_called_once()
+        self.assertEqual(result.shape, (32, self.hidden_size))
+
 
 class TestAscendC8KVCacheMethod(TestBase):
 
@@ -255,7 +305,22 @@ class TestAscendC8KVCacheMethod(TestBase):
         expected_shape = (self.layer.num_kv_heads * self.layer.head_size, )
         self.assertEqual(param.shape, expected_shape)
 
-    def test_process_weights_after_loading(self):
+    @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=False)
+    def test_process_weights_after_loading_not_310p(self, mock_is_310p):
+        key_data = torch.ones(4 * 64)
+        value_data = torch.ones(4 * 64) * 2
+
+        self.layer.key_antiquant_scale.data = key_data
+        self.layer.value_antiquant_scale.data = value_data
+
+        self.method.process_weights_after_loading(self.layer)
+
+        self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256))
+        self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1))
+        self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2))
+
+    @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
+    def test_process_weights_after_loading_is_310p(self, mock_is_310p):
         key_data = torch.ones(4 * 64)
         value_data = torch.ones(4 * 64) * 2
 
@@ -527,6 +592,67 @@ class TestFusedExperts(TestBase):
         )
 
 
+class TestFusedExperts310(TestBase):
+
+    @patch('torch_npu.npu_quant_grouped_matmul_dequant')
+    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
+    @patch('vllm_ascend.quantization.w8a8.get_ep_group')
+    @patch('torch_npu.npu_swiglu')
+    def test_fused_experts_310p_with_expert_map(self, mock_swiglu,
+                                                mock_get_ep_group,
+                                                mock_quant_per_tensor,
+                                                mock_matmul_dequant):
+        num_tokens = 32
+        hidden_size = 128
+        intermediate_size = 256
+        num_experts = 4
+        top_k = 1
+
+        hidden_states = torch.randn(num_tokens, hidden_size)
+
+        w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
+        w1_scale = torch.tensor([0.1])
+        w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]])
+
+        w2 = torch.randn(num_experts, hidden_size, intermediate_size)
+        w2_scale = torch.tensor([0.1])
+        w2_input_scale = torch.tensor([0.2])
+
+        topk_weights = torch.rand(num_tokens, top_k)
+        topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
+        expert_map = torch.arange(num_experts)
+
+        mock_get_ep_group.return_value.world_size = 1
+
+        mock_quant_per_tensor.return_value = torch.randint(
+            -128, 127, hidden_states.shape, dtype=torch.int8)
+
+        mock_swiglu.return_value = torch.randn(num_tokens * top_k,
+                                               intermediate_size)
+
+        mock_matmul_dequant.return_value = hidden_states
+
+        output = fused_experts_310p(
+            hidden_states=hidden_states,
+            w1=w1,
+            w1_scale=w1_scale,
+            w1_input_scale=w1_input_scale,
+            w2=w2,
+            w2_scale=w2_scale,
+            w2_input_scale=w2_input_scale,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            top_k=top_k,
+            global_num_experts=num_experts,
+            expert_map=expert_map,
+        )
+
+        self.assertEqual(output.shape, (num_tokens, hidden_size))
+        self.assertEqual(mock_matmul_dequant.call_count, 2)
+
 
 class TestSelectExperts(TestBase):
 
     def setUp(self):
diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py
index 83e65b0..577a7ab 100644
--- a/tests/ut/test_utils.py
+++ b/tests/ut/test_utils.py
@@ -1,3 +1,5 @@
+import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import]  # isort: skip  # noqa
+
 import math
 import os
 import unittest
@@ -102,6 +104,79 @@ class TestUtils(unittest.TestCase):
         output_tensor = utils.aligned_16(input_tensor)
         self.assertEqual(output_tensor.shape[0], 32)
 
+    @mock.patch('torch_npu.get_npu_format')
+    @mock.patch('torch_npu.npu_format_cast')
+    @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
+                new=mock.MagicMock)
+    @mock.patch('vllm_ascend.utils.is_310p')
+    @mock.patch('vllm_ascend.utils.get_ascend_config')
+    def test_maybe_converting_weight_acl_format(self, mock_get_config,
+                                                mock_310p, mock_npu_cast,
+                                                mock_get_format):
+        ACL_FORMAT_FRACTAL_NZ = 29
+        mock_310p.return_value = True
+
+        mock_config = mock.MagicMock()
+        mock_config.torchair_graph_config.enabled = True
+        mock_get_config.return_value = mock_config
+        mock_get_format.return_value = 1
+
+        mock_npu_cast.return_value = 1
+
+        fused_moe = mock.MagicMock()
+        fused_moe.w13_weight = mock.MagicMock()
+        fused_moe.w2_weight = mock.MagicMock()
+        fused_moe.w13_weight.data = torch.randn(128, 256)
+        fused_moe.w2_weight.data = torch.randn(256, 128)
+        model = mock.MagicMock()
+        model.modules.return_value = [fused_moe]
+
+        utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
+        self.assertEqual(fused_moe.w13_weight.data, 1)
+
+    @mock.patch('torch_npu.get_npu_format')
+    @mock.patch('torch_npu.npu_format_cast')
+    @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
+                new=mock.MagicMock)
+    @mock.patch('vllm_ascend.utils.is_310p')
+    @mock.patch('vllm_ascend.utils.get_ascend_config')
+    def test_maybe_converting_weight_acl_format_format_true(
+            self, mock_get_config, mock_310p, mock_npu_cast, mock_get_format):
+        ACL_FORMAT_FRACTAL_NZ = 29
+        mock_310p.return_value = True
+
+        mock_config = mock.MagicMock()
+        mock_config.torchair_graph_config.enabled = True
+        mock_get_config.return_value = mock_config
+        mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
+
+        mock_npu_cast.return_value = 1
+
+        fused_moe = mock.MagicMock()
+        fused_moe.w13_weight = mock.MagicMock()
+        fused_moe.w2_weight = mock.MagicMock()
+        fused_moe.w13_weight.data = torch.randn(128, 256)
+        fused_moe.w2_weight.data = torch.randn(256, 128)
+        model = mock.MagicMock()
+        model.modules.return_value = [fused_moe]
+
+        mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
+
+        utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
+
+    @mock.patch('vllm_ascend.utils.get_ascend_config')
+    @mock.patch('vllm_ascend.utils.is_310p', return_value=False)
+    def test_maybe_converting_weight_acl_format_not_310_not_graph(
+            self, mock_310p, mock_get_config):
+        mock_config = mock.MagicMock()
+        mock_config.torchair_graph_config.enabled = False
+        mock_get_config.return_value = mock_config
+
+        mock_constant = mock.MagicMock()
+
+        mock_model = mock.MagicMock()
+        utils.maybe_converting_weight_acl_format(mock_model, mock_constant)
+
     @mock.patch('importlib.util.find_spec')
     @mock.patch('importlib.import_module')
     def test_try_register_lib(self, mock_import_module, mock_find_spec):
@@ -111,23 +186,17 @@ class TestUtils(unittest.TestCase):
         lib_name = "existing_lib"
         lib_info = "Library found and imported successfully"
         utils.try_register_lib(lib_name, lib_info)
-        mock_find_spec.assert_called_once_with(lib_name)
-        mock_import_module.assert_called_once_with(lib_name)
 
         # Can't find lib
         mock_find_spec.return_value = None
         lib_name = "non_existing_lib"
         utils.try_register_lib(lib_name)
-        self.assertEqual(2, mock_find_spec.call_count)
-        self.assertEqual(1, mock_import_module.call_count)
 
         # import error
         mock_find_spec.return_value = mock.MagicMock()
        mock_import_module.side_effect = ImportError("import error")
         lib_name = "error_lib"
         utils.try_register_lib(lib_name)
-        self.assertEqual(3, mock_find_spec.call_count)
-        self.assertEqual(2, mock_import_module.call_count)
 
     def test_enable_custom_op(self):
         result = utils.enable_custom_op()
diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py
index 0746150..e41e4cb 100644
--- a/vllm_ascend/quantization/w8a8.py
+++ b/vllm_ascend/quantization/w8a8.py
@@ -23,6 +23,7 @@ from vllm.attention.backends.abstract import AttentionType
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.distributed.parallel_state import get_ep_group
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
 
 
 def quant_per_tensor(in_tensor: torch.Tensor,
@@ -42,7 +43,7 @@ class AscendW8A8LinearMethod:
 
     def __init__(self) -> None:
         # aclnn quant matmul requires to transpose matrix B, set to true by default.
-        self.transpose_weight = True
+        self.transpose_weight = not is_310p()
 
     @staticmethod
     def get_weight(
@@ -95,13 +96,24 @@ class AscendW8A8LinearMethod:
             x = quant_per_tensor(x, layer.aclnn_input_scale,
                                  layer.aclnn_input_offset)
         quant_bias = layer.quant_bias if tp_rank == 0 else None
-        output = torch_npu.npu_quant_matmul(
-            x,
-            layer.weight,
-            layer.deq_scale,
-            bias=quant_bias,
-            output_dtype=original_dtype,
-        )
+        if is_310p():
+            # On the 300I Duo platform, the weight must be transposed again
+            # when using the NZ format. This transpose can be skipped in torchair.
+            output = torch_npu.npu_quant_matmul(
+                x,
+                layer.weight.data.transpose(1, 0),
+                layer.deq_scale,
+                bias=quant_bias,
+                output_dtype=original_dtype,
+            )
+        else:
+            output = torch_npu.npu_quant_matmul(
+                x,
+                layer.weight,
+                layer.deq_scale,
+                bias=quant_bias,
+                output_dtype=original_dtype,
+            )
         return output
 
     def process_weights_after_loading(self, layer):
@@ -114,7 +126,8 @@ class AscendW8A8LinearMethod:
             requires_grad=False).to(layer.aclnn_input_scale.dtype)
         if self.transpose_weight:
             layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
+                                                      ACL_FORMAT_FRACTAL_NZ)
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
 
@@ -232,6 +245,19 @@ class AscendW8A8FusedMoEMethod:
             global_num_experts=global_num_experts,
         )
 
+        if is_310p():
+            return fused_experts_310p(hidden_states=x,
+                                      w1=layer.w13_weight,
+                                      w1_scale=layer.w13_weight_scale,
+                                      w1_input_scale=layer.w13_input_scale,
+                                      w2=layer.w2_weight,
+                                      w2_scale=layer.w2_weight_scale,
+                                      w2_input_scale=layer.w2_input_scale,
+                                      topk_weights=topk_weights,
+                                      topk_ids=topk_ids,
+                                      top_k=top_k,
+                                      global_num_experts=global_num_experts,
+                                      expert_map=expert_map)
         return fused_experts(hidden_states=x,
                              w1=layer.w13_weight,
                              w1_scale=layer.w13_weight_scale,
@@ -248,41 +274,48 @@ class AscendW8A8FusedMoEMethod:
                              expert_map=expert_map)
 
     def process_weights_after_loading(self, layer):
-        # torch.npu.config.allow_internal_format = True
-        layer.w13_weight.data = layer.w13_weight.data.transpose(
-            1, 2).contiguous()
-        layer.w2_weight.data = layer.w2_weight.data.transpose(1,
-                                                              2).contiguous()
+        if not is_310p():
+            layer.w13_weight.data = layer.w13_weight.data.transpose(
+                1, 2).contiguous()
+            layer.w2_weight.data = layer.w2_weight.data.transpose(
+                1, 2).contiguous()
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
-            layer.w13_weight_scale.data.shape[0], -1).to(torch.float32)
+            layer.w13_weight_scale.data.shape[0], -1)
         layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
-            layer.w13_weight_offset.data.shape[0], -1).to(torch.float16)
+            layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
-            layer.w2_weight_scale.data.shape[0], -1).to(torch.float32)
+            layer.w2_weight_scale.data.shape[0], -1)
         layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
-            layer.w2_weight_offset.data.shape[0], -1).to(torch.float16)
+            layer.w2_weight_offset.data.shape[0], -1)
         expanding_factor_w13 = layer.w13_weight.data.shape[1]
         expanding_factor_w2 = layer.w2_weight.data.shape[1]
-        layer.w13_input_scale.data = torch.nn.Parameter(
-            layer.w13_input_scale.data.repeat(
-                1, expanding_factor_w13)[0:1]).to(torch.float16)
-        layer.w2_input_scale.data = torch.nn.Parameter(
-            layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1]).to(
-                torch.float16)
+        if is_310p():
+            layer.w13_input_scale.data = torch.nn.Parameter(
+                layer.w13_input_scale.data.max())
+            layer.w2_input_scale.data = torch.nn.Parameter(
+                layer.w2_input_scale.data.max())
+        else:
+            layer.w13_input_scale.data = torch.nn.Parameter(
+                layer.w13_input_scale.data.repeat(
+                    1, expanding_factor_w13)[0:1])
+            layer.w2_input_scale.data = torch.nn.Parameter(
+                layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
+
         layer.w13_input_offset.data = torch.nn.Parameter(
-            layer.w13_input_scale.data.repeat(
-                1, expanding_factor_w13)[0:1]).to(torch.int8)
+            layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1])
         layer.w2_input_offset.data = torch.nn.Parameter(
-            layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1]).to(
-                torch.int8)
+            layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
 
-        # NZ
-        layer.w13_weight.data = torch_npu.npu_format_cast(
-            layer.w13_weight.data, 29).contiguous()
-        layer.w2_weight.data = torch_npu.npu_format_cast(
-            layer.w2_weight.data, 29).contiguous()
+        # Skip converting to ACL_FORMAT_FRACTAL_NZ on 310P:
+        # npu_quant_grouped_matmul_dequant in eager mode does not accept
+        # ACL_FORMAT_FRACTAL_NZ weights.
+        if not is_310p():
+            layer.w13_weight.data = torch_npu.npu_format_cast(
+                layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
+            layer.w2_weight.data = torch_npu.npu_format_cast(
+                layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
 
 
 class AscendC8KVCacheMethod:
@@ -407,6 +440,69 @@
         return output
 
 
+def fused_experts_310p(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w1_input_scale: torch.Tensor,
+    w2: torch.Tensor,
+    w2_scale: torch.Tensor,
+    w2_input_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    global_num_experts: int,
+    expert_map: torch.Tensor = None,
+) -> torch.Tensor:
+    ep_size = get_ep_group().world_size
+    local_num_experts = global_num_experts // ep_size
+    local_num_group = top_k // ep_size
+
+    bsz, _ = hidden_states.shape
+    flatten_topk_ids = topk_ids.view(-1)
+    sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
+    sorted_topk_ids = sorted_topk_ids.to(torch.int32)
+    sorted_hidden_states = hidden_states.index_select(
+        0, sorted_topk_ids // local_num_group)
+
+    experts_id = torch.arange(0,
+                              local_num_experts,
+                              dtype=topk_ids.dtype,
+                              device=topk_ids.device)
+    num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
+        torch.float32).sum(0)
+    topk_scales = topk_weights.view(-1).index_select(
+        0, sorted_topk_ids).unsqueeze(-1)
+    group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
+
+    gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant(
+        x=sorted_hidden_states,
+        quantized_weight=w1,
+        weight_scale=w1_scale,
+        group_list=group_list,
+        x_scale=w1_input_scale,
+        quant_mode="pertensor")
+
+    gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
+        torch.float16)
+    gate_up_out *= topk_scales
+
+    down_out = torch_npu.npu_quant_grouped_matmul_dequant(
+        x=gate_up_out,
+        quantized_weight=w2,
+        weight_scale=w2_scale,
+        group_list=group_list,
+        x_scale=w2_input_scale,
+        quant_mode="pertensor")
+
+    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
+    unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids)
+    final_hidden_states = unsorted_hidden_states.reshape(
+        bsz, top_k // ep_size, -1).sum(1)
+
+    return final_hidden_states
+
+
 def fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index a667b38..d02c2ec 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -31,6 +31,7 @@ from torch_npu.npu.streams import Event
 from vllm.logger import logger
 
 import vllm_ascend.envs as envs
+from vllm_ascend.ascend_config import get_ascend_config
 
 try:
     # Recent release of torchair has moved these ops to `.scope`.
@@ -175,6 +176,28 @@ def aligned_16(tensor: torch.Tensor):
     return new_tensor
 
 
+def maybe_converting_weight_acl_format(model, format=ACL_FORMAT_FRACTAL_NZ):
+    # Currently, some operations do not support ACL_FORMAT_FRACTAL_NZ in eager
+    # mode but do support it in torchair graph mode. Since ACL_FORMAT_FRACTAL_NZ
+    # is preferred over ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
+    # conversion when using torchair graph mode on the 300I Duo platform.
+    # TODO: remove this conversion once npu_quant_grouped_matmul_dequant
+    # accepts ACL_FORMAT_FRACTAL_NZ weights in eager mode.
+    from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+
+    use_torchair = get_ascend_config().torchair_graph_config.enabled
+    if not is_310p() or not use_torchair:
+        return
+    for module in model.modules():
+        if isinstance(module, FusedMoE):
+            if torch_npu.get_npu_format(module.w13_weight.data) == format:
+                return
+            module.w13_weight.data = torch_npu.npu_format_cast(
+                module.w13_weight.data, format)
+            module.w2_weight.data = torch_npu.npu_format_cast(
+                module.w2_weight.data, format)
+
+
 def try_register_lib(lib_name: str, lib_info: str = ""):
     import importlib
     import importlib.util
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 78e05ed..aac66f7 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -77,6 +77,7 @@ from vllm_ascend.pool.metadata import PoolingMetadata
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration, is_310p,
+                               maybe_converting_weight_acl_format,
                                vllm_version_is)
 from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
@@ -1196,6 +1197,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 model_kwargs["kv_caches"] = self.kv_caches
                 model_kwargs["attn_metadata"] = attn_metadata
             if self.torchair_graph_enabled and not with_prefill:
+                maybe_converting_weight_acl_format(self.model,
+                                                   ACL_FORMAT_FRACTAL_NZ)
+
                 compiled_model = self._get_torchair_lazy_compiled_model(
                     padded_batch_size)
                 hidden_states = compiled_model(
@@ -1207,6 +1211,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 )
             else:
                 assert self.model is not None
+                maybe_converting_weight_acl_format(self.model,
+                                                   ACL_FORMAT_FRACTAL_ND)
+
                 hidden_states = self.model(
                     input_ids=input_ids,
                     positions=positions,
@@ -1878,6 +1885,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                         kv, tuple), "kv_cache must be a tuple"
                     torch._dynamo.mark_static(kv[0])
                     torch._dynamo.mark_static(kv[1])
+
+                maybe_converting_weight_acl_format(self.model,
+                                                   ACL_FORMAT_FRACTAL_NZ)
+
                 compiled_model = self._get_torchair_lazy_compiled_model(
                     num_tokens)
                 hidden_states = compiled_model(
@@ -1889,6 +1900,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     attn_metadata=attn_metadata,
                 )
             else:
+                maybe_converting_weight_acl_format(self.model,
+                                                   ACL_FORMAT_FRACTAL_ND)
+
                 hidden_states = model(
                     input_ids=input_ids,
                     positions=positions,