diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 5061ff37..3ee8116c 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -4,7 +4,8 @@ from unittest.mock import MagicMock, patch import torch from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.distributed.parallel_state import GroupCoordinator -from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) from tests.ut.base import TestBase from vllm_ascend.ascend_config import init_ascend_config @@ -972,16 +973,13 @@ class TestAscendMLAImpl(TestBase): def test_process_weights_after_loading(self, mock_format_cast): layer = MagicMock(spec=LinearBase) layer.input_size_per_partition = 10 - quant_method = MagicMock() - apply = MagicMock() - quant_method.apply = apply + quant_method = MagicMock(spec=UnquantizedLinearMethod) layer.quant_method = quant_method shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim + self.impl.v_head_dim) shape_1 = self.impl.kv_lora_rank layer.weight = torch.randn(shape_0, shape_1) self.impl.kv_b_proj = layer - apply.return_value = layer.weight.T mock_format_cast.return_value = layer.weight self.impl.process_weights_after_loading(torch.bfloat16) diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py index 995a69f4..4718002c 100644 --- a/tests/ut/ops/test_linear.py +++ b/tests/ut/ops/test_linear.py @@ -1,3 +1,4 @@ +import os import unittest from unittest import mock from unittest.mock import MagicMock, patch @@ -61,22 +62,24 @@ class TestAscendUnquantizedLinearMethod(TestBase): mock_dtype = mock.PropertyMock(return_value=torch.float16) type(self.layer.weight.data).dtype = mock_dtype - @mock.patch("vllm_ascend.ops.linear.is_enable_nz") + @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"}) @mock.patch("torch_npu.npu_format_cast") - def test_process_weights_after_loading_enable_nz(self, mock_format_cast, - mock_is_nz): - mock_is_nz.return_value = 1 - self.method.process_weights_after_loading(self.layer) - mock_format_cast.assert_called_once() - - @mock.patch("vllm_ascend.ops.linear.is_enable_nz") - @mock.patch("torch_npu.npu_format_cast") - def test_process_weights_after_loading_disable_nz(self, mock_format_cast, - mock_is_nz): - mock_is_nz.return_value = 0 + def test_process_weights_after_loading_with_nz0(self, mock_format_cast): self.method.process_weights_after_loading(self.layer) mock_format_cast.assert_not_called() + @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"}) + @mock.patch("torch_npu.npu_format_cast") + def test_process_weights_after_loading_with_nz1(self, mock_format_cast): + self.method.process_weights_after_loading(self.layer) + mock_format_cast.assert_not_called() + + @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"}) + @mock.patch("torch_npu.npu_format_cast") + def test_process_weights_after_loading_with_nz2(self, mock_format_cast): + self.method.process_weights_after_loading(self.layer) + mock_format_cast.assert_called_once() + class TestAscendRowParallelLinear(BaseLinearTest): diff --git a/tests/ut/quantization/test_w4a4_flatquant_dynamic.py b/tests/ut/quantization/test_w4a4_flatquant_dynamic.py index ce8cfec6..d02ad6bd 100644 --- a/tests/ut/quantization/test_w4a4_flatquant_dynamic.py +++ b/tests/ut/quantization/test_w4a4_flatquant_dynamic.py @@ -199,7 +199,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase): (self.output_size, self.input_size // 8), dtype=torch.int32) 
mock_pack_weights.return_value = mock_packed - self.method.transpose_weight = False self.method.process_weights_after_loading(layer) mock_pack_weights.assert_called_once() self.assertFalse(hasattr(layer, 'weight')) @@ -212,35 +211,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase): self.assertEqual(layer.left_trans.shape, (24, 24)) self.assertTrue(layer.left_trans.is_contiguous()) - @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_weights') - def test_process_weights_after_loading_with_transpose( - self, mock_pack_weights): - """Tests weight processing after loading, with transpose.""" - layer = nn.Module() - layer.weight = torch.randint(-8, - 7, (self.output_size, self.input_size), - dtype=torch.int8) - layer.weight_scale = torch.randn(self.output_size, - 1, - dtype=torch.bfloat16) - layer.weight_offset = torch.randn(self.output_size, - 1, - dtype=torch.bfloat16) - layer.left_trans = torch.randn(24, 24) - layer.right_trans = torch.randn(32, 32) - layer.clip_ratio = torch.tensor([0.9]) - mock_packed = torch.randint(0, - 100, - (self.output_size, self.input_size // 8), - dtype=torch.int32) - mock_pack_weights.return_value = mock_packed - self.method.transpose_weight = True - self.method.process_weights_after_loading(layer) - self.assertTrue(hasattr(layer, 'weight_packed')) - self.assertEqual(layer.weight_packed.shape, - (self.input_size // 8, self.output_size)) - self.assertTrue(layer.weight_packed.is_contiguous()) - if __name__ == '__main__': unittest.main(argv=['first-arg-is-ignored'], exit=False) diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py index 2116b0c1..3ed2a877 100644 --- a/tests/ut/quantization/test_w4a8_dynamic.py +++ b/tests/ut/quantization/test_w4a8_dynamic.py @@ -62,7 +62,8 @@ class TestAscendW4A8DynamicLinearMethod(TestBase): @patch('torch_npu.npu_convert_weight_to_int4pack') @patch('torch.Tensor.npu') - def test_process_weights_after_loading(self, mock_npu, + @patch("torch_npu.npu_format_cast") + def test_process_weights_after_loading(self, mock_format_cast, mock_npu, mock_npu_convert_weight): mock_npu.side_effect = lambda: torch.zeros( (1, 32), dtype=torch.float32) @@ -85,6 +86,8 @@ class TestAscendW4A8DynamicLinearMethod(TestBase): layer.weight_offset_second = torch.nn.Parameter(torch.empty_like( layer.weight_scale_second.data), requires_grad=False) + mock_format_cast.return_value = layer.weight.data.transpose( + 0, 1).contiguous() self.method.process_weights_after_loading(layer) self.assertTrue(hasattr(layer, "weight_scale_bias")) self.assertEqual(layer.weight_scale_bias.data.shape, (32, )) @@ -110,6 +113,8 @@ class TestAscendW4A8DynamicLinearMethod(TestBase): new_layer.scale_bias = torch.nn.Parameter(torch.zeros( (32, 1), dtype=torch.float32), requires_grad=False) + mock_format_cast.return_value = new_layer.weight.data.transpose( + 0, 1).contiguous() self.method.process_weights_after_loading(new_layer) self.assertEqual(new_layer.scale_bias.data.shape, (32, )) self.assertTrue(hasattr(new_layer, "weight_scale_second")) diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py index ba006dfc..9fa549b2 100644 --- a/tests/ut/quantization/test_w8a8.py +++ b/tests/ut/quantization/test_w8a8.py @@ -1,3 +1,4 @@ +import os from unittest.mock import MagicMock, patch import torch @@ -132,20 +133,21 @@ class TestAscendW8A8LinearMethod(TestBase): expected_y_output += bias self.assertTrue(torch.equal(output, expected_y_output)) - @patch("vllm_ascend.quantization.w8a8.is_enable_nz") + 
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"}) @patch('torch_npu.npu_format_cast') - def test_process_weights_after_loading_not_nz(self, mock_npu_format_cast, - mock_is_nz): + def test_process_weights_after_loading_with_nz0(self, + mock_npu_format_cast): layer = MagicMock() - layer.weight.data = torch.randn(128, 256) + layer.weight.data = torch.randint(-127, + 128, (128, 256), + dtype=torch.int8) layer.input_scale.data = torch.tensor([0.1]) layer.input_offset.data = torch.tensor([0]) layer.deq_scale = torch.tensor([0.5]) layer.weight_scale.data = torch.randn(128, 1) layer.weight_offset.data = torch.randn(128, 1) - mock_is_nz.return_value = 0 mock_npu_format_cast.return_value = MagicMock self.method.process_weights_after_loading(layer) @@ -160,20 +162,50 @@ class TestAscendW8A8LinearMethod(TestBase): self.assertEqual(layer.weight_offset.data.shape, (128, )) mock_npu_format_cast.assert_not_called() - @patch("vllm_ascend.quantization.w8a8.is_enable_nz") + @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"}) @patch('torch_npu.npu_format_cast') - def test_process_weights_after_loading_nz(self, mock_npu_format_cast, - mock_is_nz): + def test_process_weights_after_loading_with_nz1(self, + mock_npu_format_cast): layer = MagicMock() - layer.weight.data = torch.randn(128, 256) + layer.weight.data = torch.randint(-127, + 128, (128, 256), + dtype=torch.int8) + layer.input_scale.data = torch.tensor([0.1]) + layer.input_offset.data = torch.tensor([0]) + layer.deq_scale = torch.tensor([0.5]) + layer.weight_scale.data = torch.randn(128, 1) + layer.weight_offset.data = torch.randn(128, 1) + + mock_npu_format_cast.return_value = MagicMock + self.method.process_weights_after_loading(layer) + + expected_offset = torch.tensor([0]).repeat(256).to(torch.int8) + self.assertTrue( + torch.equal(layer.aclnn_input_offset.data, expected_offset)) + self.assertFalse(layer.aclnn_input_offset.requires_grad) + + self.assertFalse(layer.deq_scale.requires_grad) + + self.assertEqual(layer.weight_scale.data.shape, (128, )) + self.assertEqual(layer.weight_offset.data.shape, (128, )) + mock_npu_format_cast.assert_called_once() + + @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"}) + @patch('torch_npu.npu_format_cast') + def test_process_weights_after_loading_with_nz2(self, + mock_npu_format_cast): + layer = MagicMock() + + layer.weight.data = torch.randint(-127, + 128, (128, 256), + dtype=torch.int8) layer.input_scale.data = torch.tensor([0.1]) layer.input_offset.data = torch.tensor([0]) layer.deq_scale = torch.tensor([0.5]) layer.weight_scale.data = torch.randn(128, 1) layer.weight_offset.data = torch.randn(128, 1) - mock_is_nz.return_value = 1 mock_npu_format_cast.return_value = MagicMock self.method.process_weights_after_loading(layer) diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index 3be5e783..6219e686 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -35,14 +35,6 @@ class TestUtils(TestBase): from vllm_ascend import platform importlib.reload(platform) - def test_is_enable_nz(self): - with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ", - 1): - self.assertTrue(utils.is_enable_nz()) - with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ", - 0): - self.assertFalse(utils.is_enable_nz()) - def test_nd_to_nz_2d(self): # can be divided by 16 input_tensor = torch.randn(32, 64) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 0646b41c..17a5f04e 100644 --- a/vllm_ascend/attention/mla_v1.py +++ 
b/vllm_ascend/attention/mla_v1.py @@ -14,8 +14,7 @@ from vllm.distributed import (get_decode_context_model_parallel_rank, get_pcp_group) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import logger -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.utils.math_utils import cdiv, round_down from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import MLAAttentionSpec @@ -38,8 +37,8 @@ from vllm_ascend.ops.shared_weight_layer import ( register_layer_to_shared_weight_series) from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - flashcomm2_o_shared_enabled, is_enable_nz, +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, + flashcomm2_o_shared_enabled, maybe_trans_nz, weak_ref_tensors) from vllm_ascend.worker.npu_input_batch import NPUInputBatch @@ -796,40 +795,11 @@ class AscendMLAImpl(MLAAttentionImpl): return ql_nope.transpose(0, 1), q_pe def process_weights_after_loading(self, act_dtype: torch.dtype): - - def get_layer_weight(layer): - WEIGHT_NAMES = ("weight", "qweight", "weight_packed") - for attr in WEIGHT_NAMES: - try: - return getattr(layer, attr) - except AttributeError: - pass - raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") - - def get_and_maybe_dequant_weights(layer: LinearBase): - if not isinstance(layer.quant_method, UnquantizedLinearMethod): - # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) - del eye - # standardize to (output, input) - return dequant_weights.T - # Weight will be reshaped next. To be on the safe side, the format - # of the weight should be reverted to FRACTAL_AND. - layer.weight.data = torch_npu.npu_format_cast( - layer.weight.data, ACL_FORMAT_FRACTAL_ND) - return layer.weight - - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + # NOTE: We currently do not support quant kv_b_proj. + assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod) + # NOTE: Weight will be reshaped next, we need to revert and transpose it. + kv_b_proj_weight = torch_npu.npu_format_cast( + self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( @@ -852,15 +822,8 @@ class AscendMLAImpl(MLAAttentionImpl): # Convert from (L, N, P) to (N, P, L) self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() - # Function `get_and_maybe_dequant_weights` will cast the weights to - # FRACTAL_AND. So we need to cast to FRACTAL_NZ again. 
- if is_enable_nz(): - self.kv_b_proj.weight.data = torch_npu.npu_format_cast( - self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ) - - # Waiting for BMM NZ support - # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) - # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + # TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz + # self.W_UV = maybe_trans_nz(self.W_UV) if self.enable_mlapo: # Currently mlapo only supports W8A8 quantization in MLA scenario @@ -875,6 +838,9 @@ class AscendMLAImpl(MLAAttentionImpl): "thus mlapo is disabled for these layers.") if self.enable_mlapo: self._process_weights_for_fused_mlapo(act_dtype) + else: + # if mlapo, W_UK_T can't trans nz + self.W_UK_T = maybe_trans_nz(self.W_UK_T) if self.fc2_o_shared_enable and is_hidden_layer( self.vllm_config, self.o_proj): diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index a53e9423..c2fd9dd5 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -9,7 +9,7 @@ from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group from vllm.forward_context import get_forward_context from vllm.logger import logger -from vllm.model_executor.layers.linear import (LinearBase, ReplicatedLinear, +from vllm.model_executor.layers.linear import (ReplicatedLinear, UnquantizedLinearMethod) from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.utils import AttentionCGSupport @@ -29,9 +29,8 @@ from vllm_ascend.ops.shared_weight_layer import ( from vllm_ascend.ops.triton.rope import rope_forward_triton from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - _round_up, dispose_layer, enable_sp, - is_enable_nz, replace_layer) +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer, + enable_sp, maybe_trans_nz, replace_layer) from vllm_ascend.worker.npu_input_batch import NPUInputBatch if TYPE_CHECKING: @@ -404,40 +403,11 @@ class AscendSFAImpl(MLAAttentionImpl): self.cp_size = 1 def process_weights_after_loading(self, act_dtype: torch.dtype): - - def get_layer_weight(layer): - WEIGHT_NAMES = ("weight", "qweight", "weight_packed") - for attr in WEIGHT_NAMES: - try: - return getattr(layer, attr) - except AttributeError: - pass - raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") - - def get_and_maybe_dequant_weights(layer: LinearBase): - if not isinstance(layer.quant_method, UnquantizedLinearMethod): - # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) - del eye - # standardize to (output, input) - return dequant_weights.T - # Weight will be reshaped next. To be on the safe side, the format - # of the weight should be reverted to FRACTAL_AND. 
- layer.weight.data = torch_npu.npu_format_cast( - layer.weight.data, ACL_FORMAT_FRACTAL_ND) - return layer.weight - - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + # NOTE: We currently do not support quant kv_b_proj. + assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod) + # NOTE: Weight will be reshaped next, we need to revert and transpose it. + kv_b_proj_weight = torch_npu.npu_format_cast( + self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( @@ -460,15 +430,9 @@ class AscendSFAImpl(MLAAttentionImpl): # Convert from (L, N, P) to (N, P, L) self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() - # Function `get_and_maybe_dequant_weights` will cast the weights to - # FRACTAL_AND. So we need to cast to FRACTAL_NZ again. - if is_enable_nz(): - self.kv_b_proj.weight.data = torch_npu.npu_format_cast( - self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ) + # TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz + # self.W_UV = maybe_trans_nz(self.W_UV) - # Waiting for BMM NZ support - # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) - # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) # Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory dispose_layer(self.kv_b_proj) @@ -502,6 +466,9 @@ class AscendSFAImpl(MLAAttentionImpl): logger.warning_once(msg) else: self._process_weights_for_fused_mlapo(act_dtype) + if not self.enable_mlapo: + # if mlapo, W_UK_T can't trans nz + self.W_UK_T = maybe_trans_nz(self.W_UK_T) def _v_up_proj(self, x): forward_context = get_forward_context() diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 4e92d800..c3c5e967 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -123,7 +123,10 @@ env_variables: Dict[str, Callable[[], Any]] = { lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))), "VLLM_ASCEND_ENABLE_MLAPO": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))), - # Whether to enable transpose weight and cast format to FRACTAL_NZ. + # Whether to enable weight cast format to FRACTAL_NZ. + # 0: close nz; + # 1: only quant case enable nz; + # 2: enable nz as long as possible. "VLLM_ASCEND_ENABLE_NZ": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)), # Decide whether we should enable CP parallelism. 
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 2d0e7afc..2a331ed8 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -19,7 +19,6 @@ from typing import Any, Callable, Optional import torch import torch.nn.functional as F -import torch_npu from vllm.config import get_current_vllm_config from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group, tensor_model_parallel_all_reduce) @@ -48,8 +47,8 @@ from vllm_ascend.quantization.w4a8_dynamic import \ AscendW4A8DynamicFusedMoEMethod from vllm_ascend.quantization.w8a8_dynamic import \ AscendW8A8DynamicFusedMoEMethod -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, - enable_sp, get_ascend_device_type, is_enable_nz, +from vllm_ascend.utils import (AscendDeviceType, enable_sp, + get_ascend_device_type, maybe_trans_nz, npu_stream_switch, shared_expert_dp_enabled, shared_experts_calculation_stream) @@ -73,12 +72,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): 1, 2).contiguous() layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) - if get_ascend_device_type() != AscendDeviceType._310P and is_enable_nz( - ): - layer.w13_weight.data = torch_npu.npu_format_cast( - layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) - layer.w2_weight.data = torch_npu.npu_format_cast( - layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) + if get_ascend_device_type() != AscendDeviceType._310P: + layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data) + layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data) def apply(self, layer: torch.nn.Module, diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index 53a3b26b..b5fc2bcc 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -24,7 +24,6 @@ from typing import Optional, Union import torch import torch.nn as nn -import torch_npu from torch.nn.parameter import Parameter from vllm.config import get_current_vllm_config from vllm.distributed import divide @@ -37,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import \ from vllm.model_executor.utils import set_weight_attrs from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz +from vllm_ascend.utils import maybe_trans_nz class AscendUnquantizedLinearMethod(UnquantizedLinearMethod): @@ -45,11 +44,8 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - if "conv1d" not in layer.prefix and ( - is_enable_nz() and layer.weight.data.dtype - in [torch.float16, torch.bfloat16]): - layer.weight.data = torch_npu.npu_format_cast( - layer.weight.data, ACL_FORMAT_FRACTAL_NZ) + if "conv1d" not in layer.prefix: + layer.weight.data = maybe_trans_nz(layer.weight.data) # TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group diff --git a/vllm_ascend/quantization/w4a4_flatquant_dynamic.py b/vllm_ascend/quantization/w4a4_flatquant_dynamic.py index 326980fc..f13dae2f 100644 --- a/vllm_ascend/quantization/w4a4_flatquant_dynamic.py +++ b/vllm_ascend/quantization/w4a4_flatquant_dynamic.py @@ -86,7 +86,6 @@ class AscendW4A4FlatQuantDynamicLinearMethod: input_size = 0 def __init__(self): - self.transpose_weight = False self.sym = True @staticmethod @@ -176,9 +175,8 @@ class AscendW4A4FlatQuantDynamicLinearMethod: return output def 
process_weights_after_loading(self, layer): + # NOTE: Currently, w4a4 can't support weight nz weight_packed = pack_int4_weights(layer.weight.data) - if self.transpose_weight: - weight_packed = weight_packed.transpose(0, 1).contiguous() layer.register_parameter( 'weight_packed', torch.nn.Parameter(weight_packed, requires_grad=False)) diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index dc5f580d..45a7bc18 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe.experts_selector import select_experts -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz +from vllm_ascend.utils import maybe_trans_nz class AscendW4A8DynamicLinearMethod: @@ -35,8 +35,6 @@ class AscendW4A8DynamicLinearMethod: """ def __init__(self): - self.transpose_weight = True - vllm_config = get_current_vllm_config() self.group_size = vllm_config.quant_config.quant_description.get( "group_size", 256) @@ -170,8 +168,8 @@ class AscendW4A8DynamicLinearMethod: ) def process_weights_after_loading(self, layer: torch.nn.Module): - if self.transpose_weight: - layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight.data = maybe_trans_nz(layer.weight.data) layer.weight_scale.data = layer.weight_scale.data.flatten().to( torch.float32) layer.weight_offset.data = layer.weight_offset.data.flatten() @@ -214,8 +212,6 @@ class AscendW4A8DynamicFusedMoEMethod: """ def __init__(self): - self.transpose_weight = True - self.ep_group = get_ep_group() vllm_config = get_current_vllm_config() @@ -462,11 +458,10 @@ class AscendW4A8DynamicFusedMoEMethod: torch.quint4x2, -1, False) def process_weights_after_loading(self, layer): - if self.transpose_weight: - layer.w13_weight.data = layer.w13_weight.data.transpose( - 1, 2).contiguous() - layer.w2_weight.data = layer.w2_weight.data.transpose( - 1, 2).contiguous() + layer.w13_weight.data = layer.w13_weight.data.transpose( + 1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose(1, + 2).contiguous() w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr( layer, "w13_weight_scale_second") else None @@ -487,10 +482,7 @@ class AscendW4A8DynamicFusedMoEMethod: self.update_bias(layer, w13_bias, w2_bias) - if is_enable_nz(): - layer.w13_weight.data = torch_npu.npu_format_cast( - layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) - layer.w2_weight.data = torch_npu.npu_format_cast( - layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data) + layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data) layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data) layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data) diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index 349946d5..30846a3c 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -21,9 +21,8 @@ import torch import torch_npu from vllm.forward_context import get_forward_context -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, - COMPRESSED_TENSORS_METHOD, AscendDeviceType, - get_ascend_device_type, is_enable_nz) +from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, 
AscendDeviceType, + get_ascend_device_type, maybe_trans_nz) def quant_per_tensor(in_tensor: torch.Tensor, @@ -42,9 +41,7 @@ class AscendW8A8LinearMethod: """ def __init__(self) -> None: - # aclnn quant matmul requires to transpose matrix B, set to true by default. - self.transpose_weight = get_ascend_device_type( - ) != AscendDeviceType._310P + pass @staticmethod def get_weight( @@ -189,11 +186,9 @@ class AscendW8A8LinearMethod: layer.aclnn_input_offset = torch.nn.Parameter( layer.input_offset.data.repeat(expanding_factor), requires_grad=False).to(layer.aclnn_input_scale.dtype) - if self.transpose_weight: + if get_ascend_device_type() != AscendDeviceType._310P: layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() - if is_enable_nz(): - layer.weight.data = torch_npu.npu_format_cast( - layer.weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.weight.data = maybe_trans_nz(layer.weight.data) layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_offset.data = torch.flatten(layer.weight_offset.data) ascend_quant_method = getattr(layer, "ascend_quant_method", "") diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 8952d3cf..e32360ce 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -29,7 +29,7 @@ from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.flash_common3_context import get_flash_common3_context from vllm_ascend.ops.fused_moe.experts_selector import select_experts -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz +from vllm_ascend.utils import maybe_trans_nz class AscendW8A8DynamicLinearMethod: @@ -37,7 +37,7 @@ class AscendW8A8DynamicLinearMethod: """ def __init__(self): - self.transpose_weight = True + pass @staticmethod def get_weight(input_size: int, output_size: int, @@ -91,12 +91,9 @@ class AscendW8A8DynamicLinearMethod: return output def process_weights_after_loading(self, layer): - if self.transpose_weight: - layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() # cast quantized weight tensors in NZ format for higher inference speed - if is_enable_nz(): - layer.weight.data = torch_npu.npu_format_cast( - layer.weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.weight.data = maybe_trans_nz(layer.weight.data) layer.weight_scale.data = layer.weight_scale.data.flatten() layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) layer.weight_offset.data = layer.weight_offset.data.flatten() @@ -107,8 +104,6 @@ class AscendW8A8DynamicFusedMoEMethod: """ def __init__(self): - self.transpose_weight = True - self.ep_group = get_ep_group() vllm_config = get_current_vllm_config() @@ -270,14 +265,12 @@ class AscendW8A8DynamicFusedMoEMethod: mc2_mask=kwargs.get("mc2_mask", None)) def process_weights_after_loading(self, layer): - if self.transpose_weight: - layer.w13_weight.data = layer.w13_weight.data.transpose( - 1, 2).contiguous() - layer.w2_weight.data = layer.w2_weight.data.transpose( - 1, 2).contiguous() - if is_enable_nz(): - torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ) - torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ) + layer.w13_weight.data = layer.w13_weight.data.transpose( + 1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose(1, + 2).contiguous() + layer.w13_weight.data = 
maybe_trans_nz(layer.w13_weight.data) + layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data) layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( layer.w13_weight_scale.data.shape[0], -1) layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to( diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 9e2431bb..e7b3b8ec 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -122,8 +122,22 @@ def _unregister_print_streams_on_exit(): atexit.register(_unregister_print_streams_on_exit) -def is_enable_nz(): - return envs_ascend.VLLM_ASCEND_ENABLE_NZ +def maybe_trans_nz(weight: torch.Tensor): + if not envs_ascend.VLLM_ASCEND_ENABLE_NZ: + # NZ is not enabled + return weight + if weight.dtype == torch.float: + # fp32 can not support NZ + return weight + elif weight.dtype in {torch.bfloat16, torch.float16}: + # bf16/fp16 will trans nz when VLLM_ASCEND_ENABLE_NZ is 2 + if envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2: + return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ) + else: + return weight + else: + # quant weight will trans nz by default + return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ) def _round_up(x: int, align: int): diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 327a369a..94ae0916 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -114,10 +114,10 @@ from vllm_ascend.spec_decode import get_spec_decode_method from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.interface import SpecDcodeType from vllm_ascend.spec_decode.mtp_proposer import MtpProposer -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - AscendDeviceType, ProfileExecuteDuration, - enable_sp, get_ascend_device_type, is_enable_nz, - is_moe_model, lmhead_tp_enable, vllm_version_is) +from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration, + enable_sp, get_ascend_device_type, is_moe_model, + lmhead_tp_enable, maybe_trans_nz, + vllm_version_is) from vllm_ascend.worker.npu_input_batch import NPUInputBatch from vllm_ascend.ascend_forward_context import ( # isort: skip @@ -137,9 +137,6 @@ torch.npu.config.allow_internal_format = True if get_ascend_device_type() == AscendDeviceType._310P: torch_npu.npu.set_compile_mode(jit_compile=False) - ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ -else: - ACL_FORMAT = ACL_FORMAT_FRACTAL_ND @dataclass @@ -2225,16 +2222,6 @@ class NPUModelRunner(GPUModelRunner): self.model = get_model(vllm_config=self.vllm_config) if self.dynamic_eplb: model_register(self.model, self.model_config) - if get_ascend_device_type() == AscendDeviceType._310P: - from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, QKVParallelLinear, - RowParallelLinear) - for module in self.model.modules(): - if isinstance(module, - (MergedColumnParallelLinear, - QKVParallelLinear, RowParallelLinear)): - module.weight.data = self._convert_torch_format( - module.weight.data) if self.drafter: logger.info("Loading drafter model...") self.drafter.load_model(self.model) @@ -2255,13 +2242,6 @@ class NPUModelRunner(GPUModelRunner): self.vllm_config, runtime_mode=CUDAGraphMode.FULL) - def _convert_torch_format(self, tensor): - if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \ - and not is_enable_nz(): - return tensor - tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT) - return tensor - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on 
`kv_cache_config`. @@ -2534,9 +2514,10 @@ class NPUModelRunner(GPUModelRunner): self.model_config.hf_text_config.qk_rope_head_dim ] k_cache = raw_k_tensor.view(dtype).view(k_shape) - k_cache = self._convert_torch_format(k_cache) v_cache = raw_v_tensor.view(dtype).view(v_shape) - v_cache = self._convert_torch_format(v_cache) + if get_ascend_device_type() == AscendDeviceType._310P: + k_cache = maybe_trans_nz(k_cache) + v_cache = maybe_trans_nz(v_cache) if self.use_sparse and raw_dsa_k_tensor is not None: dsa_k_cache_shape = (num_blocks, kv_cache_spec.block_size, 1, 128) diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index de23efa7..324cbed0 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -55,7 +55,7 @@ from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import (AscendDeviceType, check_ascend_device_type, - enable_sp, get_ascend_device_type, is_enable_nz, + enable_sp, get_ascend_device_type, register_ascend_customop) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner @@ -160,7 +160,7 @@ class NPUWorker(WorkerBase): used_bytes / GiB_bytes) def wake_up(self, tags: Optional[list[str]] = None) -> None: - if is_enable_nz(): + if envs_ascend.VLLM_ASCEND_ENABLE_NZ: raise ValueError( "FRACTAL_NZ mode is enabled. This may cause model parameter precision issues " "in the RL scenarios. Please set VLLM_ASCEND_ENABLE_NZ=0.") diff --git a/vllm_ascend/xlite/xlite.py b/vllm_ascend/xlite/xlite.py index 73ae515a..6ffac6db 100644 --- a/vllm_ascend/xlite/xlite.py +++ b/vllm_ascend/xlite/xlite.py @@ -26,10 +26,10 @@ from vllm.logger import logger from vllm.sequence import IntermediateTensors from xlite._C import AttnMHA, Model, ModelAttnMeta, ModelConfig, Runtime +import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import (AscendAttentionState, AscendMetadata) -from vllm_ascend.utils import is_enable_nz class XliteModel: @@ -134,7 +134,7 @@ class LlamaXliteModel(XliteModel): config.moe_tp_size = 1 config.attn_type = AttnMHA - config.weight_nz = is_enable_nz() + config.weight_nz = envs_ascend.VLLM_ASCEND_ENABLE_NZ scheduler_config = vllm_config.scheduler_config max_batch_size = scheduler_config.max_num_seqs max_seq_len = vllm_config.model_config.max_model_len
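Reviewer sketch (not part of the patch) of how the new helper can be exercised with the same `patch.dict(os.environ, ...)` pattern the updated unit tests adopt in place of mocking `is_enable_nz`; the test class name, tensor shapes, and expected behavior shown here are illustrative and assume the lazy env lookup in `vllm_ascend/envs.py`.

```python
# Illustrative test sketch: patch VLLM_ASCEND_ENABLE_NZ in the process
# environment, then check whether torch_npu.npu_format_cast is invoked.
import os
import unittest
from unittest import mock

import torch


class ExampleMaybeTransNzTest(unittest.TestCase):

    @mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"})
    @mock.patch("torch_npu.npu_format_cast")
    def test_fp16_weight_not_cast_in_mode_1(self, mock_format_cast):
        # Import inside the test so the env var is read after patching.
        from vllm_ascend.utils import maybe_trans_nz

        weight = torch.randn(16, 32, dtype=torch.float16)
        out = maybe_trans_nz(weight)

        # Mode 1 only casts quantized weights, so fp16 stays in ND untouched.
        self.assertIs(out, weight)
        mock_format_cast.assert_not_called()
```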