diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 1a982ad..25f543e 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -376,7 +376,8 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(q_pe.shape[1], self.impl.num_heads)
         self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)
 
-    def test_process_weights_after_loading(self):
+    @patch('torch_npu.npu_format_cast')
+    def test_process_weights_after_loading(self, mock_format_cast):
         layer = MagicMock(spec=LinearBase)
         layer.input_size_per_partition = 10
         quant_method = MagicMock()
@@ -389,6 +390,7 @@ class TestAscendMLAImpl(TestBase):
         layer.weight = torch.randn(shape_0, shape_1)
         self.impl.kv_b_proj = layer
         apply.return_value = layer.weight.T
+        mock_format_cast.return_value = layer.weight
         self.impl.process_weights_after_loading(torch.bfloat16)
 
         self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads)
diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py
index 693aea5..e33a39a 100644
--- a/tests/ut/models/test_deepseek_v2.py
+++ b/tests/ut/models/test_deepseek_v2.py
@@ -12,7 +12,7 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
 
 import pytest
 import torch
@@ -20,6 +20,7 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 
+from vllm_ascend import ascend_config
 from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention,
                                             CustomDeepseekV2RowParallelLinear)
 
@@ -46,6 +47,13 @@ def test_row_parallel_linear(cls, mock_distributed):
 def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
                                           mock_distributed, base_config):
     mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
+    # Make a fake ascend config because AscendLinearBase requires one
+    vllm_config = MagicMock()
+    vllm_config.additional_config = None
+    vllm_config.parallel_config.enable_expert_parallel = False
+    vllm_config.parallel_config.tensor_parallel_size = 1
+    vllm_config.kv_transfer_config = None
+    ascend_config.init_ascend_config(vllm_config)
 
     attn = CustomDeepseekV2MLAAttention(config=base_config,
                                         hidden_size=128,
@@ -78,6 +86,7 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
                                         kv_lora_rank=16,
                                         prefix="layers.1.self_attn")
     assert hasattr(attn, "q_proj")
+    ascend_config._ASCEND_CONFIG = None
 
 
 def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
@@ -90,6 +99,14 @@ def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
 
     config = SimpleConfig()
 
+    # Make a fake ascend config because AscendLinearBase requires one
+    vllm_config = MagicMock()
+    vllm_config.additional_config = None
+    vllm_config.parallel_config.enable_expert_parallel = False
+    vllm_config.parallel_config.tensor_parallel_size = 1
+    vllm_config.kv_transfer_config = None
+    ascend_config.init_ascend_config(vllm_config)
+
     # Create lmhead and logits_processor directly
     lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
     logits_processor = LogitsProcessor(config.vocab_size)
@@ -105,3 +122,4 @@ def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
                 return_value=mock_logits):
         logits = logits_processor(lmhead, mock_output)
         assert logits.shape == (2, 4, config.vocab_size)
+    ascend_config._ASCEND_CONFIG = None
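Both tests above repeat the same init-and-reset dance around the global ascend config. A pytest fixture could factor this out; a minimal sketch, assuming `init_ascend_config` only inspects the attributes mocked above (the fixture name `fake_ascend_config` is hypothetical, not part of this diff):

```python
from unittest.mock import MagicMock

import pytest

from vllm_ascend import ascend_config


@pytest.fixture
def fake_ascend_config():
    # Build the minimal vllm_config that init_ascend_config() inspects.
    vllm_config = MagicMock()
    vllm_config.additional_config = None
    vllm_config.parallel_config.enable_expert_parallel = False
    vllm_config.parallel_config.tensor_parallel_size = 1
    vllm_config.kv_transfer_config = None
    ascend_config.init_ascend_config(vllm_config)
    yield
    # Reset the module-level singleton so later tests start clean.
    ascend_config._ASCEND_CONFIG = None
```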
diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py
index e22d7ca..e2b0eff 100644
--- a/tests/ut/ops/test_linear.py
+++ b/tests/ut/ops/test_linear.py
@@ -5,10 +5,13 @@ from unittest.mock import MagicMock, patch
 
 import torch
 
+from tests.ut.base import TestBase
 from vllm_ascend import ascend_config
 from vllm_ascend.distributed import parallel_state
 from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
-                                    AscendRowParallelLinear)
+                                    AscendReplicatedLinear,
+                                    AscendRowParallelLinear,
+                                    AscendUnquantizedLinearMethod)
 
 
 class BaseLinearTest(unittest.TestCase):
@@ -49,6 +52,47 @@ class BaseLinearTest(unittest.TestCase):
         p.stop()
 
 
+class TestAscendUnquantizedLinearMethod(TestBase):
+
+    def setUp(self):
+        self.method = AscendUnquantizedLinearMethod()
+
+    @mock.patch("vllm_ascend.ops.linear.is_enable_nz")
+    @mock.patch("torch_npu.npu_format_cast")
+    @mock.patch("torch.version")
+    def test_process_weights_after_loading_is_8_3_enable_nz(
+            self, mock_version, mock_format_cast, mock_is_nz):
+        layer = mock.MagicMock()
+
+        mock_version.cann = "8.3.RC1"
+        mock_is_nz.return_value = 1
+        self.method.process_weights_after_loading(layer)
+        mock_format_cast.assert_called_once()
+
+    @mock.patch("vllm_ascend.ops.linear.is_enable_nz")
+    @mock.patch("torch_npu.npu_format_cast")
+    @mock.patch("torch.version")
+    def test_process_weights_after_loading_is_8_3_disable_nz(
+            self, mock_version, mock_format_cast, mock_is_nz):
+        layer = mock.MagicMock()
+
+        mock_version.cann = "8.3.RC1"
+        mock_is_nz.return_value = 0
+        self.method.process_weights_after_loading(layer)
+        mock_format_cast.assert_not_called()
+
+    @mock.patch("vllm_ascend.ops.linear.is_enable_nz")
+    @mock.patch("torch.version")
+    def test_process_weights_after_loading_not_8_3(self, mock_version,
+                                                   mock_is_nz):
+        layer = mock.MagicMock()
+
+        mock_version.cann = "8.2.RC1"
+        mock_is_nz.return_value = 1
+        # Should not raise an exception
+        self.method.process_weights_after_loading(layer)
+
+
 class TestAscendRowParallelLinear(BaseLinearTest):
 
     def test_mlp_optimize(self):
@@ -92,5 +136,25 @@ class TestAscendMergedColumnParallelLinear(BaseLinearTest):
         self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)
 
 
+class TestAscendReplicatedLinear(BaseLinearTest):
+
+    def test_init_disable_tp(self):
+        linear = AscendReplicatedLinear(
+            input_size=16,
+            output_size=8,
+            disable_tp=True,
+        )
+        self.assertTrue(
+            isinstance(linear.quant_method, AscendUnquantizedLinearMethod))
+
+    def test_init_without_disable_tp(self):
+        linear = AscendReplicatedLinear(
+            input_size=16,
+            output_size=8,
+        )
+        self.assertTrue(
+            isinstance(linear.quant_method, AscendUnquantizedLinearMethod))
+
+
 if __name__ == '__main__':
     unittest.main()
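The three tests above pin down the decision matrix for the new method: the NZ cast fires only when the env switch is on and the CANN toolkit is 8.3. A condensed sketch of the gate being exercised (it mirrors the `process_weights_after_loading` added to `vllm_ascend/ops/linear.py` later in this diff; `torch.version.cann` is only populated when torch_npu is installed):

```python
import torch

from vllm_ascend.utils import is_enable_nz


def should_cast_to_nz() -> bool:
    # Both conditions must hold; if either fails, weights stay in ND.
    return bool(is_enable_nz()) and torch.version.cann.startswith("8.3")
```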
diff --git a/tests/ut/quantization/test_quant_config.py b/tests/ut/quantization/test_quant_config.py
index 5a119b4..4622692 100644
--- a/tests/ut/quantization/test_quant_config.py
+++ b/tests/ut/quantization/test_quant_config.py
@@ -4,10 +4,10 @@ import torch
 from vllm.attention.layer import Attention
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
-from vllm.model_executor.layers.linear import (LinearBase,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import LinearBase
 
 from tests.ut.base import TestBase
+from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
 from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
                                                    AscendQuantConfig)
 from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
@@ -82,7 +82,7 @@ class TestAscendQuantConfig(TestBase):
         with patch.object(self.ascend_config,
                           'is_layer_skipped_ascend',
                           return_value=True):
             method = self.ascend_config.get_quant_method(linear_layer, ".attn")
-            self.assertIsInstance(method, UnquantizedLinearMethod)
+            self.assertIsInstance(method, AscendUnquantizedLinearMethod)
 
         # Test quantized layer
         with patch.object(self.ascend_config, 'is_layer_skipped_ascend',
                           return_value=False), \
diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py
index 69b33a9..98dd8f4 100644
--- a/tests/ut/quantization/test_w8a8.py
+++ b/tests/ut/quantization/test_w8a8.py
@@ -137,8 +137,10 @@ class TestAscendW8A8LinearMethod(TestBase):
             expected_y_output += bias
         self.assertTrue(torch.equal(output, expected_y_output))
 
+    @patch("vllm_ascend.quantization.w8a8.is_enable_nz")
     @patch('torch_npu.npu_format_cast')
-    def test_process_weights_after_loading(self, mock_npu_format_cast):
+    def test_process_weights_after_loading_not_nz(self, mock_npu_format_cast,
+                                                  mock_is_nz):
         layer = MagicMock()
 
         layer.weight.data = torch.randn(128, 256)
@@ -148,6 +150,7 @@ class TestAscendW8A8LinearMethod(TestBase):
         layer.weight_scale.data = torch.randn(128, 1)
         layer.weight_offset.data = torch.randn(128, 1)
 
+        mock_is_nz.return_value = 0
         mock_npu_format_cast.return_value = MagicMock
         self.method.process_weights_after_loading(layer)
 
@@ -160,6 +163,35 @@ class TestAscendW8A8LinearMethod(TestBase):
 
         self.assertEqual(layer.weight_scale.data.shape, (128, ))
         self.assertEqual(layer.weight_offset.data.shape, (128, ))
+        mock_npu_format_cast.assert_not_called()
+
+    @patch("vllm_ascend.quantization.w8a8.is_enable_nz")
+    @patch('torch_npu.npu_format_cast')
+    def test_process_weights_after_loading_nz(self, mock_npu_format_cast,
+                                              mock_is_nz):
+        layer = MagicMock()
+
+        layer.weight.data = torch.randn(128, 256)
+        layer.input_scale.data = torch.tensor([0.1])
+        layer.input_offset.data = torch.tensor([0])
+        layer.deq_scale = torch.tensor([0.5])
+        layer.weight_scale.data = torch.randn(128, 1)
+        layer.weight_offset.data = torch.randn(128, 1)
+
+        mock_is_nz.return_value = 1
+        mock_npu_format_cast.return_value = MagicMock
+        self.method.process_weights_after_loading(layer)
+
+        expected_offset = torch.tensor([0]).repeat(256).to(torch.int8)
+        self.assertTrue(
+            torch.equal(layer.aclnn_input_offset.data, expected_offset))
+        self.assertFalse(layer.aclnn_input_offset.requires_grad)
+
+        self.assertFalse(layer.deq_scale.requires_grad)
+
+        self.assertEqual(layer.weight_scale.data.shape, (128, ))
+        self.assertEqual(layer.weight_offset.data.shape, (128, ))
+        mock_npu_format_cast.assert_called_once()
 
 
 class TestAscendW8A8FusedMoEMethod(TestBase):
diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py
index b2b3c32..7bc8f5b 100644
--- a/tests/ut/test_utils.py
+++ b/tests/ut/test_utils.py
@@ -39,6 +39,14 @@ class TestUtils(TestBase):
                         "Ascend910P1"):
             self.assertFalse(utils.is_310p())
 
+    def test_is_enable_nz(self):
+        with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
+                        1):
+            self.assertTrue(utils.is_enable_nz())
+        with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
+                        0):
+            self.assertFalse(utils.is_enable_nz())
+
     def test_sleep_mode_enabled(self):
         utils._SLEEP_MODE_ENABLED = None
         with mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__",
diff --git a/tests/ut/torchair/test_utils.py b/tests/ut/torchair/test_utils.py
index 8965738..0252851 100644
--- a/tests/ut/torchair/test_utils.py
+++ b/tests/ut/torchair/test_utils.py
@@ -96,15 +96,17 @@ class TestTorchairUtils(TestBase):
         self.assertEqual(args[0], expected_name)
         self.assertEqual(args[1], expected_path)
 
+    @mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
     @mock.patch('torch_npu.get_npu_format')
    @mock.patch('torch_npu.npu_format_cast')
     @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
                 new=mock.MagicMock)
-    def test_converting_weight_acl_format(self, mock_npu_cast,
-                                          mock_get_format):
+    def test_converting_weight_acl_format_to_nz(self, mock_npu_cast,
+                                                mock_get_format, mock_is_nz):
         ACL_FORMAT_FRACTAL_NZ = 29
         mock_get_format.return_value = 1
         mock_npu_cast.return_value = 1
+        mock_is_nz.return_value = 1
 
         fused_moe = mock.MagicMock()
         fused_moe.w13_weight = mock.MagicMock()
@@ -137,3 +139,26 @@ class TestTorchairUtils(TestBase):
 
         utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
         mock_npu_cast.assert_not_called()
+
+    @mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
+    @mock.patch('torch_npu.get_npu_format')
+    @mock.patch('torch_npu.npu_format_cast')
+    @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
+                new=mock.MagicMock)
+    def test_converting_weight_acl_format_no_nz(self, mock_npu_cast,
+                                                mock_get_format, mock_is_nz):
+        ACL_FORMAT_FRACTAL_NZ = 29
+        mock_get_format.return_value = 1
+        mock_npu_cast.return_value = 1
+        mock_is_nz.return_value = 0
+
+        fused_moe = mock.MagicMock()
+        fused_moe.w13_weight = mock.MagicMock()
+        fused_moe.w2_weight = mock.MagicMock()
+        fused_moe.w13_weight.data = torch.randn(128, 256)
+        fused_moe.w2_weight.data = torch.randn(256, 128)
+        model = mock.MagicMock()
+        model.modules.return_value = [fused_moe]
+
+        utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
+        mock_npu_cast.assert_not_called()
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 5bf3262..ac27231 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -27,6 +27,8 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
+                               is_enable_nz)
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -595,6 +597,10 @@ class AscendMLAImpl(MLAAttentionImpl):
                 del eye
                 # standardize to (output, input)
                 return dequant_weights.T
+            # Weight will be reshaped next. To be on the safe side, the format
+            # of the weight should be reverted to FRACTAL_ND.
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, ACL_FORMAT_FRACTAL_ND)
             return layer.weight
 
         # we currently do not have quantized bmm's which are needed for
@@ -623,6 +629,12 @@ class AscendMLAImpl(MLAAttentionImpl):
         # Convert from (L, N, P) to (N, P, L)
         self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
 
+        # Function `get_and_maybe_dequant_weights` casts the weights to
+        # FRACTAL_ND, so we need to cast them back to FRACTAL_NZ here.
+        if is_enable_nz():
+            self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
+                self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)
+
         # Waiting for BMM NZ support
         # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
         # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
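For context, the two mla_v1.py hunks implement an ND/NZ round-trip: weights may arrive in FRACTAL_NZ, are reverted to FRACTAL_ND so the reshape/permute arithmetic stays well-defined, and are cast back to NZ at the end only when the switch is on. A standalone sketch of that pattern (illustrative helper, not code from this patch):

```python
import torch_npu

from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                               is_enable_nz)


def reshape_weight_safely(weight, new_shape):
    # NZ is a blocked layout, so reshape/view semantics are only safe in ND.
    weight = torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_ND)
    weight = weight.reshape(new_shape)
    if is_enable_nz():
        # Restore the NZ layout for faster matmuls afterwards.
        weight = torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)
    return weight
```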
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 2db4515..2a00496 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -169,6 +169,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
     lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
     "VLLM_ASCEND_ENABLE_MLAPO":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
+    # Whether to transpose weights and cast their format to FRACTAL_NZ.
+    "VLLM_ASCEND_ENABLE_NZ":
+    lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)),
 }
 
 # end-env-vars-definition
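The new setting defaults to on, so opting out is a plain environment switch. A minimal sketch (set the variable before any vllm_ascend import so every call site sees the same value; the env table above maps the name to a lambda, so the value is read when the setting is accessed):

```python
import os

# Keep weights in ND format everywhere.
os.environ["VLLM_ASCEND_ENABLE_NZ"] = "0"

from vllm_ascend.utils import is_enable_nz

assert not is_enable_nz()
```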
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index a2d86a3..4e0ead1 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -32,13 +32,15 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+from vllm.distributed import (divide, get_pp_group,
+                              get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, split_tensor_along_last_dim,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
+                                               ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -57,16 +59,81 @@ from vllm.model_executor.models.deepseek_v2 import (
 from vllm.model_executor.models.utils import (PPMissingLayer,
                                               is_pp_missing_parameter,
                                               maybe_prefix)
+from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.models.layers.mla import AscendMLAModules
 from vllm_ascend.models.layers.sfa import (AscendSFAModules,
                                            AscendSparseFlashAttention,
                                            Indexer)
 from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
+from vllm_ascend.ops.linear import AscendLinearBase
 
 
 class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
 
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        input_is_parallel: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+        disable_tp: bool = False,
+    ):
+        # Divide the weight matrix along the first dimension.
+        self.tp_rank = (get_tensor_model_parallel_rank()
+                        if not disable_tp else 0)
+        self.tp_size = (get_tensor_model_parallel_world_size()
+                        if not disable_tp else 1)
+        self.input_size_per_partition = divide(input_size, self.tp_size)
+        self.output_size_per_partition = output_size
+        self.output_partition_sizes = [output_size]
+
+        AscendLinearBase.__init__(self,
+                                  input_size,
+                                  output_size,
+                                  skip_bias_add,
+                                  params_dtype,
+                                  quant_config,
+                                  prefix,
+                                  return_bias=return_bias,
+                                  disable_tp=disable_tp)
+
+        self.input_is_parallel = input_is_parallel
+        self.reduce_results = reduce_results
+
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            layer=self,
+            input_size_per_partition=self.input_size_per_partition,
+            output_partition_sizes=self.output_partition_sizes,
+            input_size=self.input_size,
+            output_size=self.output_size,
+            params_dtype=self.params_dtype,
+            weight_loader=(
+                self.weight_loader_v2 if self.quant_method.__class__.__name__
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError("When not reduce the results, adding bias to the "
+                             "results can lead to incorrect results")
+
+        if bias:
+            self.bias = nn.Parameter(
+                torch.empty(self.output_size, dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+        self.update_param_tp_status()
+
     def forward(
         self,
         input_,
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 64ff76c..b82c931 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -37,7 +37,8 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe.experts_selector import select_experts
 from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz,
+                               npu_stream_switch)
 
 
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -83,7 +84,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         w2_data = self._maybe_pad_weight(layer.w2_weight.data)
         layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
 
-        if not is_310p():
+        if not is_310p() and is_enable_nz():
             layer.w13_weight.data = torch_npu.npu_format_cast(
                 layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
             layer.w2_weight.data = torch_npu.npu_format_cast(
diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
index 0861940..665ac74 100644
--- a/vllm_ascend/ops/linear.py
+++ b/vllm_ascend/ops/linear.py
@@ -24,17 +24,29 @@ from typing import Optional, Union
 
 import torch
 import torch.nn as nn
+import torch_npu
 from torch.nn.parameter import Parameter
 from vllm.distributed import divide
 from vllm.model_executor.layers.linear import (  # noqa
     WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
     MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
-    RowParallelLinear, UnquantizedLinearMethod)
+    ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 
-from vllm_ascend.ops.linear_op import get_parallel_op
+from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
+
+
+class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
+    """Linear method without quantization."""
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer)
+        if is_enable_nz() and torch.version.cann.startswith("8.3"):
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
 
 
 # TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
@@ -65,7 +77,7 @@ class AscendLinearBase(LinearBase):
         self.prefix = prefix
         if quant_config is None:
             self.quant_method: Optional[
-                QuantizeMethodBase] = UnquantizedLinearMethod()
+                QuantizeMethodBase] = AscendUnquantizedLinearMethod()
         else:
             self.quant_method = quant_config.get_quant_method(self,
                                                               prefix=prefix)
@@ -364,3 +376,81 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
             return self.custom_op.apply(input_)
 
         return super().forward(input_)
+
+
+class AscendReplicatedLinear(ReplicatedLinear):
+    """Ascend replicated linear layer.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration.
+        prefix: The name of the layer in the state dict, including all parents
+            (e.g. model.layers.0.qkv_proj)
+        return_bias: If true, return bias together with outputs in forward pass.
+        disable_tp: Takes no effect for replicated linear layers.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+        disable_tp: bool = False,
+    ):
+        self.custom_op = get_replicated_op(disable_tp, prefix, self)
+        # If MergedReplicatedLinear, use the output size of each partition.
+        if hasattr(self, "output_sizes"):
+            self.output_partition_sizes = self.output_sizes
+        else:
+            self.output_partition_sizes = [output_size]
+
+        AscendLinearBase.__init__(self,
+                                  input_size,
+                                  output_size,
+                                  skip_bias_add,
+                                  params_dtype,
+                                  quant_config,
+                                  prefix=prefix,
+                                  return_bias=return_bias,
+                                  disable_tp=disable_tp)
+
+        # All linear layers support a quant method.
+        assert self.quant_method is not None
+        self.quant_method.create_weights(self,
+                                         self.input_size, [self.output_size],
+                                         self.input_size,
+                                         self.output_size,
+                                         self.params_dtype,
+                                         weight_loader=self.weight_loader)
+
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.output_size, dtype=self.params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+        if self.custom_op is not None:
+            self.custom_op.update_attrs()
+
+    def forward(
+        self,
+        input_,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        if self.custom_op is not None:
+            return self.custom_op.apply(input_)
+
+        return super().forward(input_)
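With `quant_config=None` every Ascend linear layer now picks up this method, so the NZ cast happens exactly once per layer, right after checkpoint loading. A minimal usage sketch (mirroring the unit tests, which initialize the ascend config and TP groups beforehand; sizes are arbitrary, and the cast itself needs an NPU with CANN 8.3 and the env switch on):

```python
from vllm_ascend.ops.linear import (AscendReplicatedLinear,
                                    AscendUnquantizedLinearMethod)

# With quant_config=None, AscendLinearBase installs the Ascend method.
linear = AscendReplicatedLinear(input_size=16, output_size=8)
assert isinstance(linear.quant_method, AscendUnquantizedLinearMethod)

# The model loader invokes this hook after weights are loaded; on CANN 8.3
# with VLLM_ASCEND_ENABLE_NZ=1 it casts linear.weight to FRACTAL_NZ in place.
linear.quant_method.process_weights_after_loading(linear)
```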
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 9ceeb29..663d28e 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -17,16 +17,16 @@ This file extends the functionality of linear operations by encapsulating
 custom communication groups and forward functions into classes (linear ops).
 
 Current class inheritance structure:
-CustomTensorParallelOp
+CustomLinearOp
 ├── CustomColumnParallelOp
 │   ├── MLPColumnParallelOp
 │   ├── SequenceColumnParallelOp
-└── CustomRowParallelOp
-    ├── MLPRowParallelOp
-    ├── OProjRowParallelOp
-    ├── MatmulAllreduceRowParallelOp
-    └── SequenceRowParallelOp
-
+├── CustomRowParallelOp
+│   ├── MLPRowParallelOp
+│   ├── OProjRowParallelOp
+│   ├── MatmulAllreduceRowParallelOp
+│   └── SequenceRowParallelOp
+└── CustomReplicatedOp
 
 How to extend a new linear op? Taking column parallel op as an example:
 1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
 2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group method
@@ -52,7 +52,7 @@ from vllm_ascend.utils import (dense_optim_enable, enable_sp,
                                oproj_tp_enable)
 
 
-class CustomTensorParallelOp:
+class CustomLinearOp:
 
     def __init__(self, layer):
         self.layer = layer
@@ -95,7 +95,7 @@ class CustomTensorParallelOp:
         return output, output_bias
 
 
-class CustomColumnParallelOp(CustomTensorParallelOp):
+class CustomColumnParallelOp(CustomLinearOp):
 
     def __init__(self, layer):
         super().__init__(layer)
@@ -106,7 +106,7 @@ class CustomColumnParallelOp(CustomTensorParallelOp):
         self.gather_output = self.layer.gather_output
 
 
-class CustomRowParallelOp(CustomTensorParallelOp):
+class CustomRowParallelOp(CustomLinearOp):
 
     def __init__(self, layer):
         super().__init__(layer)
@@ -129,6 +129,18 @@ class CustomRowParallelOp(CustomTensorParallelOp):
         return output, output_bias
 
 
+class CustomReplicatedOp(CustomLinearOp):
+
+    def apply_impl(self, input_):
+        bias = self.bias if not self.skip_bias_add else None
+        assert self.quant_method is not None
+
+        output = self.quant_method.apply(self.layer, input_, bias)
+        output_bias = self.bias if self.skip_bias_add else None
+
+        return output, output_bias
+
+
 class MLPColumnParallelOp(CustomColumnParallelOp):
 
     def __init__(self, layer):
@@ -422,3 +434,11 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
         return custom_op, custom_op.tp_rank, custom_op.tp_size
 
     return None, get_tp_group().rank_in_group, get_tp_group().world_size
+
+
+def get_replicated_op(disable_tp, prefix,
+                      layer) -> Optional[CustomReplicatedOp]:
+    if disable_tp:
+        return None
+
+    return CustomReplicatedOp(layer)
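Following the docstring's recipe, adding another op is mostly boilerplate around `apply_impl`. A hypothetical sketch (class name and body invented for illustration; `apply_impl` follows the same shape as `CustomReplicatedOp` above):

```python
from vllm_ascend.ops.linear_op import CustomColumnParallelOp


class MyColumnParallelOp(CustomColumnParallelOp):
    # Step 2 (optional): override comm_group here if the op should use a
    # custom communication group instead of the default TP group.

    # Step 3: implement the custom forward logic.
    def apply_impl(self, input_):
        bias = self.bias if not self.skip_bias_add else None
        output = self.quant_method.apply(self.layer, input_, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
```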
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index f2d7176..185b206 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -24,8 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
-                                               RowParallelLinear,
-                                               UnquantizedLinearMethod)
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
 from vllm.model_executor.layers.quantization.base_config import (
@@ -39,6 +38,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
 from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
+from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
 from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
                                oproj_tp_enable)
 
@@ -101,7 +101,7 @@ class AscendQuantConfig(QuantizationConfig):
         if isinstance(layer, LinearBase):
             if self.is_layer_skipped_ascend(prefix,
                                             self.packed_modules_mapping):
-                return UnquantizedLinearMethod()
+                return AscendUnquantizedLinearMethod()
             return AscendLinearMethod(self, prefix,
                                       self.packed_modules_mapping)
         elif isinstance(layer, Attention) and \
diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py
index b8bcc78..30f0811 100644
--- a/vllm_ascend/quantization/w4a8_dynamic.py
+++ b/vllm_ascend/quantization/w4a8_dynamic.py
@@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
 
 
 class AscendW4A8DynamicLinearMethod:
@@ -393,9 +393,10 @@ class AscendW4A8DynamicFusedMoEMethod:
 
         self.update_bias(layer, w13_bias, w2_bias)
 
-        layer.w13_weight.data = torch_npu.npu_format_cast(
-            layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
-        layer.w2_weight.data = torch_npu.npu_format_cast(
-            layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
+        if is_enable_nz():
+            layer.w13_weight.data = torch_npu.npu_format_cast(
+                layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
+            layer.w2_weight.data = torch_npu.npu_format_cast(
+                layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
         layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py
index fb4c5a4..5c7d986 100644
--- a/vllm_ascend/quantization/w8a8.py
+++ b/vllm_ascend/quantization/w8a8.py
@@ -25,7 +25,7 @@ from vllm.forward_context import get_forward_context
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz
 
 
 def quant_per_tensor(in_tensor: torch.Tensor,
@@ -156,8 +156,9 @@ class AscendW8A8LinearMethod:
             requires_grad=False).to(layer.aclnn_input_scale.dtype)
         if self.transpose_weight:
             layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
-                                                      ACL_FORMAT_FRACTAL_NZ)
+        if is_enable_nz():
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
 
@@ -340,7 +341,7 @@ class AscendW8A8FusedMoEMethod:
         # converting ACL_FORMAT_FRACTAL_NZ.
         # npu_quant_grouped_matmul_dequant in eager mode does not accept
         # ACL_FORMAT_FRACTAL_NZ.
-        if not is_310p():
+        if not is_310p() and is_enable_nz():
             layer.w13_weight.data = torch_npu.npu_format_cast(
                 layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
             layer.w2_weight.data = torch_npu.npu_format_cast(
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 1942797..df9c3b2 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -26,7 +26,7 @@ from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
 
 
 class AscendW8A8DynamicLinearMethod:
@@ -101,8 +101,9 @@ class AscendW8A8DynamicLinearMethod:
         if self.transpose_weight:
             layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         # cast quantized weight tensors to NZ format for higher inference speed
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
-                                                      ACL_FORMAT_FRACTAL_NZ)
+        if is_enable_nz():
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
@@ -267,8 +268,9 @@ class AscendW8A8DynamicFusedMoEMethod:
                 1, 2).contiguous()
             layer.w2_weight.data = layer.w2_weight.data.transpose(
                 1, 2).contiguous()
-        torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
-        torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
+        if is_enable_nz():
+            torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
+            torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
         layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
index b027b2f..6517127 100644
--- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
+++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
@@ -29,6 +29,7 @@ from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version,
+                               is_enable_nz,
                                is_hierarchical_communication_enabled)
 
 
@@ -829,7 +830,9 @@ class TorchairAscendW8A8DynamicLinearMethod:
         if self.transpose_weight:
             layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         # cast quantized weight tensors to NZ format (29) for higher inference speed
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
+        if is_enable_nz():
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, 29)
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
@@ -1048,7 +1051,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
                 1, 2).contiguous()
             layer.w2_weight.data = layer.w2_weight.data.transpose(
                 1, 2).contiguous()
-        torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
+        if is_enable_nz():
+            torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
         layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
diff --git a/vllm_ascend/torchair/torchair_sfa.py b/vllm_ascend/torchair/torchair_sfa.py
index 8dc6a68..f4a00e5 100644
--- a/vllm_ascend/torchair/torchair_sfa.py
+++ b/vllm_ascend/torchair/torchair_sfa.py
@@ -24,6 +24,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
+from vllm_ascend.utils import is_enable_nz
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -841,7 +842,10 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
         wd_qkv = wd_qkv.t().contiguous()
         wd_qkv = transdata(wd_qkv,
                            block_size=(16, 32)).unsqueeze(0).contiguous()
-        self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
+        if is_enable_nz():
+            self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
+        else:
+            self.wd_qkv = wd_qkv
 
         kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone()
         kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
@@ -874,7 +878,10 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
             self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
             -1)
         wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
-        self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
+        if is_enable_nz():
+            self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
+        else:
+            self.wu_q = wu_q
 
         qb_deq_scl = self.q_proj.deq_scale.data.clone()
         qb_deq_scl = qb_deq_scl.reshape(
diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py
index f75e7c1..164a620 100644
--- a/vllm_ascend/torchair/utils.py
+++ b/vllm_ascend/torchair/utils.py
@@ -14,6 +14,7 @@ try:
 except ImportError:
     from torchair.ops import NpuStreamSwitch as _npu_stream_switch
     from torchair.ops import npu_wait_tensor as _npu_wait_tensor
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
 
 KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
 KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
@@ -141,6 +142,9 @@ def converting_weight_acl_format(model, format):
         if isinstance(module, FusedMoE):
             if torch_npu.get_npu_format(module.w13_weight.data) == format:
                 return
+            if format == ACL_FORMAT_FRACTAL_NZ \
+                    and not is_enable_nz():
+                return
             module.w13_weight.data = torch_npu.npu_format_cast(
                 module.w13_weight.data, format)
             module.w2_weight.data = torch_npu.npu_format_cast(
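Note the guard's shape: a torchair graph that asks for NZ weights silently keeps ND when the switch is off, while conversions to any other format still run. A sketch of how to observe the no-op, reusing the mocking pattern from the unit tests above (MagicMock stands in for real FusedMoE modules):

```python
from unittest import mock

import torch

from vllm_ascend.torchair import utils
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ

fused_moe = mock.MagicMock()
fused_moe.w13_weight.data = torch.randn(128, 256)
fused_moe.w2_weight.data = torch.randn(256, 128)
model = mock.MagicMock()
model.modules.return_value = [fused_moe]

# As in the tests, substitute FusedMoE so the isinstance check matches the
# MagicMock module above.
with mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
                new=mock.MagicMock):
    # With VLLM_ASCEND_ENABLE_NZ=0 this returns before any npu_format_cast
    # call; with the switch on, both MoE weights are cast to NZ.
    utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
```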
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index e6076b9..9e64d47 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -65,6 +65,10 @@ def is_310p():
     return _IS_310P
 
 
+def is_enable_nz():
+    return envs_ascend.VLLM_ASCEND_ENABLE_NZ
+
+
 def sleep_mode_enabled():
     global _SLEEP_MODE_ENABLED
     if _SLEEP_MODE_ENABLED is None:
@@ -508,6 +512,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
     from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
                                         AscendMergedColumnParallelLinear,
                                         AscendQKVParallelLinear,
+                                        AscendReplicatedLinear,
                                         AscendRowParallelLinear)
     from vllm_ascend.ops.rotary_embedding import (
         AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding,
@@ -526,6 +531,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding,
         "MergedColumnParallelLinear": AscendMergedColumnParallelLinear,
         "QKVParallelLinear": AscendQKVParallelLinear,
+        "ReplicatedLinear": AscendReplicatedLinear,
         "DeepseekScalingRotaryEmbedding": AscendDeepseekScalingRotaryEmbedding,
         "VocabParallelEmbedding": AscendVocabParallelEmbedding,
         "ParallelLMHead": AscendParallelLMHead,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 5cfad82..12a42c1 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -97,6 +97,7 @@ from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
                                   sanity_check_mm_encoder_outputs,
                                   scatter_mm_placeholders)
 
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import (MoECommType,
                                                 set_ascend_forward_context)
@@ -125,7 +126,7 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                AscendSocVersion, ProfileExecuteDuration,
-                               get_ascend_soc_version, is_310p,
+                               get_ascend_soc_version, is_310p, is_enable_nz,
                                lmhead_tp_enable)
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
 
@@ -137,8 +138,6 @@ else:
 
     import torch_npu
 
-import vllm_ascend.envs as envs_ascend
-
 # if true, allow tensor initialization and casting with internal format (e.g., NZ)
 torch.npu.config.allow_internal_format = True
 
@@ -2609,6 +2608,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 runtime_mode=CUDAGraphMode.FULL)
 
     def _convert_torch_format(self, tensor):
+        if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \
+                and not is_enable_nz():
+            return tensor
         tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
         return tensor
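The runner-side guard mirrors the weight-side ones: when the module-level `ACL_FORMAT` resolves to FRACTAL_NZ but the switch is off, KV-cache tensors pass through unchanged. A condensed standalone sketch of that control flow (free function for illustration; in the patch this logic lives on `NPUModelRunner`):

```python
import torch_npu

from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz


def convert_torch_format(tensor, acl_format):
    # Skip the cast entirely when NZ was requested but disabled via env.
    if acl_format == ACL_FORMAT_FRACTAL_NZ and not is_enable_nz():
        return tensor
    return torch_npu.npu_format_cast(tensor, acl_format)
```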