diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index b0c05a2..0623185 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -9,12 +9,6 @@ from tests.ut.base import PytestBase from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod -def mock_maybe_chunk_residual(x, residual): - if x.size(0) != residual.size(0): - return residual[:4] - return residual - - def mock_rms_norm(x, weight, eps): return x + 1, None @@ -36,8 +30,6 @@ class TestAscendRMSNorm(PytestBase): @pytest.fixture(autouse=True) def context(self, mocker: MockerFixture): - mocker.patch("torch.ops.vllm.maybe_chunk_residual", - side_effect=mock_maybe_chunk_residual) mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm) mocker.patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) @@ -66,21 +58,6 @@ class TestAscendRMSNorm(PytestBase): assert torch.allclose(x_out, x_out_expected) - # Test case for flashcomm_v1 scenario - def test_forward_oot_with_flashcomm_v1(self): - layer = RMSNorm(hidden_size=512, eps=1e-05) - x = torch.randn(4, 512, dtype=torch.bfloat16) - residual = torch.randn(16, 512, dtype=torch.bfloat16) - - x_out, residual_out = layer.forward_oot(x, residual) - - x_out_expected = 2 * x - residual_out_expected = 2 * residual[:4] - - assert residual_out.size(0) == 4 - assert torch.allclose(x_out, x_out_expected) - assert torch.allclose(residual_out, residual_out_expected) - # Test case for addrmsnorm + w8a8 quant fusion def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture): mock_is_310p = mocker.patch("vllm_ascend.utils.is_310p") diff --git a/tests/ut/ops/test_vocab_parallel_embedding.py b/tests/ut/ops/test_vocab_parallel_embedding.py index d137985..37ea1af 100644 --- a/tests/ut/ops/test_vocab_parallel_embedding.py +++ b/tests/ut/ops/test_vocab_parallel_embedding.py @@ -104,9 +104,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase): input_ = torch.tensor([1, 2, 3]) - with patch( - 
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce", - side_effect=lambda x: x) as mock_reduce_tp1: + with patch("torch.ops.vllm.maybe_pad_and_reduce", + side_effect=lambda x: x) as mock_reduce_tp1: output = layer.forward(input_) # Should just pass through without masking @@ -123,9 +122,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase): input_ = torch.tensor([15, 35]) # one org vocab, one added vocab - with patch( - "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce", - side_effect=lambda x: x) as mock_reduce_tp: + with patch("torch.ops.vllm.maybe_pad_and_reduce", + side_effect=lambda x: x) as mock_reduce_tp: # Call the forward method output = layer.forward(input_) @@ -150,9 +148,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase): return_value=mock_output.clone()) # Patch tensor_model_parallel_all_reduce to mock its behavior - with patch( - "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce", - side_effect=lambda x: x): + with patch("torch.ops.vllm.maybe_pad_and_reduce", + side_effect=lambda x: x): # Call the forward method output = layer.forward(input_) # Check that invalid positions (0, 2, 4) were zeroed out @@ -176,9 +173,8 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase): for input_, expected_shape in test_cases: with self.subTest(input=input_): - with patch( - "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce", - side_effect=lambda x: x): + with patch("torch.ops.vllm.maybe_pad_and_reduce", + side_effect=lambda x: x): # Call the forward method output = layer.forward(input_) self.assertEqual(output.shape, expected_shape) diff --git a/tests/ut/torchair/models/test_torchair_deepseek_v2.py b/tests/ut/torchair/models/test_torchair_deepseek_v2.py index 0231ab9..2dfcfb9 100644 --- a/tests/ut/torchair/models/test_torchair_deepseek_v2.py +++ b/tests/ut/torchair/models/test_torchair_deepseek_v2.py @@ -276,10 +276,7 @@ def 
test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed, @patch("torch_npu.npu_add_rms_norm") @patch("torch_npu.npu_rms_norm") @patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) -@patch("torch.ops.vllm.maybe_chunk_residual", - side_effect=lambda x, residual: residual) -def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual, - mock_maybe_wait_prefetch_done, +def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done, mock_rms_norm, mock_add_norm, mock_distributed, base_config, vllm_config, mock_forward_context): diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 3dfca53..fbe281f 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -64,7 +64,6 @@ class AscendRMSNorm(RMSNorm): import torch_npu if residual is not None: - residual = torch.ops.vllm.maybe_chunk_residual(x, residual) assert x.size(0) == residual.size(0) x, residual = _addrmsnorm_forward_oot( self, x, residual, self.next_need_quant_fusion_linear) diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index 51399cc..0861940 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -34,8 +34,7 @@ from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig from vllm.model_executor.utils import set_weight_attrs -from vllm_ascend.ops.linear_op import (get_column_parallel_op, - get_row_parallel_op) +from vllm_ascend.ops.linear_op import get_parallel_op # TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group @@ -100,8 +99,8 @@ class AscendQKVParallelLinear(QKVParallelLinear): return_bias: bool = True, disable_tp: bool = False, ): - self.custom_op, _, tp_size = get_column_parallel_op( - disable_tp, prefix, self) + self.custom_op, _, tp_size = get_parallel_op(disable_tp, prefix, self, + "column") # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports 
custom comm group self.hidden_size = hidden_size self.head_size = head_size @@ -173,8 +172,8 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear): return_bias: bool = True, disable_tp: bool = False, ): - self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op( - disable_tp, prefix, self) + self.custom_op, self.tp_rank, self.tp_size = get_parallel_op( + disable_tp, prefix, self, "column") # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group self.output_sizes = output_sizes assert all(output_size % self.tp_size == 0 @@ -222,8 +221,8 @@ class AscendRowParallelLinear(RowParallelLinear): return_bias: bool = True, disable_tp: bool = False, ): - self.custom_op, self.tp_rank, self.tp_size = get_row_parallel_op( - disable_tp, prefix, self) + self.custom_op, self.tp_rank, self.tp_size = get_parallel_op( + disable_tp, prefix, self, "row") # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group # Divide the weight matrix along the first dimension. 
self.input_size_per_partition = divide(input_size, self.tp_size) @@ -304,8 +303,8 @@ class AscendColumnParallelLinear(ColumnParallelLinear): return_bias: bool = True, disable_tp: bool = False, ): - self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op( - disable_tp, prefix, self) + self.custom_op, self.tp_rank, self.tp_size = get_parallel_op( + disable_tp, prefix, self, "column") # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group self.input_size_per_partition = input_size self.output_size_per_partition = divide(output_size, self.tp_size) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 819af72..9ceeb29 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -20,13 +20,12 @@ Current class inheritance structure: CustomTensorParallelOp ├── CustomColumnParallelOp │ ├── MLPColumnParallelOp -│ ├── DenseOptimMergedColumnParallelOp -│ └── DenseOptimQKVParallelOp +│ └── SequenceColumnParallelOp └── CustomRowParallelOp ├── MLPRowParallelOp ├── OProjRowParallelOp ├── MatmulAllreduceRowParallelOp - └── DenseOptimRowParallelOp + └── SequenceRowParallelOp How to extend a new linear op? Taking column parallel op as an example: 1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp @@ -36,7 +35,7 @@ How to extend a new linear op? Taking column parallel op as an example: Row parallel op follows a similar approach - inherit from RowColumnParallelOp and register the new class in get_row_parallel_op. 
""" -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch import torch.distributed as dist @@ -153,69 +152,6 @@ class MLPColumnParallelOp(CustomColumnParallelOp): return output, output_bias -class SequenceMergedColumnParallelOp(CustomColumnParallelOp): - - def apply_impl( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: - """Linear layer with column parallelism. - - Implemented multiple optimization projects for dense models, such as FlashComm and - communication-computation fusion. - """ - - bias = self.bias if not self.skip_bias_add else None - - # Matrix multiply. - assert self.quant_method is not None - - input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True) - output_parallel = self.quant_method.apply(self.layer, input_, bias) - - if self.gather_output: - # All-gather across the partitions. - output = self.comm_group.all_gather(output_parallel) - else: - output = output_parallel - output_bias = self.bias if self.skip_bias_add else None - return output, output_bias - - -class SequenceQKVParallelOp(CustomColumnParallelOp): - - def __init__(self, layer, prefix): - super().__init__(layer) - self.prefix = prefix - - def apply_impl( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: - """Linear layer with column parallelism. - - Implemented multiple optimization projects for dense models, such as FlashComm and - communication-computation fusion. - """ - - bias = self.bias if not self.skip_bias_add else None - - # Matrix multiply. - assert self.quant_method is not None - - layer_num = self.prefix.split('.')[2] - - input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - input_, layer_num != '0') - output_parallel = self.quant_method.apply(self.layer, input_, bias) - - if self.gather_output: - # All-gather across the partitions. 
- output = self.comm_group.all_gather(output_parallel) - else: - output = output_parallel - output_bias = self.bias if self.skip_bias_add else None - return output, output_bias - - class MLPRowParallelOp(CustomRowParallelOp): def __init__(self, layer): @@ -364,11 +300,35 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp): self.weight_t = self.layer.weight.t() -class SequenceRowParallelOp(CustomRowParallelOp): +class SequenceColumnParallelOp(CustomColumnParallelOp): - def __init__(self, layer, prefix): - super().__init__(layer) - self.prefix = prefix + def apply_impl( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + """Linear layer with column parallelism. + + Implemented multiple optimization projects for dense models, such as FlashComm and + communication-computation fusion. + """ + + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + + input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True) + output_parallel = self.quant_method.apply(self.layer, input_, bias) + + if self.gather_output: + # All-gather across the partitions. 
+ output = self.comm_group.all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class SequenceRowParallelOp(CustomRowParallelOp): def apply_impl( self, input_: torch.Tensor @@ -408,50 +368,55 @@ class SequenceRowParallelOp(CustomRowParallelOp): self.reduce_results = self.layer.reduce_results -def get_column_parallel_op( - disable_tp, prefix, layer -) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp, - SequenceQKVParallelOp]], int, int]: +def _get_column_parallel_op( + prefix, layer +) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]: + if mlp_tp_enable() and "gate_up_proj" in prefix: + return MLPColumnParallelOp(layer) + if enable_sp(): + if "shared_expert" in prefix: + return None + if "gate_up_proj" in prefix: + return SequenceColumnParallelOp(layer) + if "in_proj" in prefix: + return SequenceColumnParallelOp(layer) + if "qkv_proj" in prefix or "conv1d" in prefix: + return SequenceColumnParallelOp(layer) + + return None + + +def _get_row_parallel_op( + prefix, layer +) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp, + MatmulAllreduceRowParallelOp, SequenceRowParallelOp]]: + if "down_proj" in prefix and mlp_tp_enable(): + return MLPRowParallelOp(layer) + if "o_proj" in prefix and oproj_tp_enable(): + return OProjRowParallelOp(layer) + if matmul_allreduce_enable(): + return MatmulAllreduceRowParallelOp(layer) + if enable_sp(): + if "shared_expert" in prefix: + return None + if "o_proj" in prefix or "out_proj" in prefix or "down_proj" in prefix: + return SequenceRowParallelOp(layer) + + return None + + +def get_parallel_op(disable_tp, prefix, layer, direct): if disable_tp: return None, 0, 1 - - custom_op: Optional[Union[ - MLPColumnParallelOp, - SequenceMergedColumnParallelOp, - SequenceQKVParallelOp, - ]] = None - if "gate_up_proj" in prefix and mlp_tp_enable(): - custom_op = MLPColumnParallelOp(layer) - elif 
"gate_up_proj" in prefix and enable_sp(): - custom_op = SequenceMergedColumnParallelOp(layer) - elif enable_sp(): - custom_op = SequenceQKVParallelOp(layer, prefix) - - if custom_op is not None: - return custom_op, custom_op.tp_rank, custom_op.tp_size - - return None, get_tp_group().rank_in_group, get_tp_group().world_size - - -def get_row_parallel_op( - disable_tp, prefix, layer -) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp, - MatmulAllreduceRowParallelOp, - SequenceRowParallelOp]], int, int]: - if disable_tp: - return None, 0, 1 - - custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp, + custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp, + MLPRowParallelOp, OProjRowParallelOp, MatmulAllreduceRowParallelOp, SequenceRowParallelOp]] = None - if "down_proj" in prefix and mlp_tp_enable(): - custom_op = MLPRowParallelOp(layer) - elif "o_proj" in prefix and oproj_tp_enable(): - custom_op = OProjRowParallelOp(layer) - elif matmul_allreduce_enable(): - custom_op = MatmulAllreduceRowParallelOp(layer) - elif enable_sp(): - custom_op = SequenceRowParallelOp(layer, prefix) + if direct == "row": + custom_op = _get_row_parallel_op(prefix, layer) + + if direct == "column": + custom_op = _get_column_parallel_op(prefix, layer) if custom_op is not None: return custom_op, custom_op.tp_rank, custom_op.tp_size diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index 7e9cdde..a66fcd3 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -1,9 +1,7 @@ import torch import torch.nn.functional as F import torch_npu -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather, +from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter) from vllm.forward_context import get_forward_context @@ -15,27 +13,6 
@@ from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.utils import npu_stream_switch, prefetch_stream -def _maybe_chunk_residual_impl(x: torch.Tensor, - residual: torch.Tensor) -> torch.Tensor: - try: - forward_context = get_forward_context() - except AssertionError: - return residual - - if x.size(0) != residual.size(0): - sp_enabled = forward_context.sp_enabled - assert sp_enabled is True, ("Currently, this situation only occurs " - "when sp is enabled") - pad_size = forward_context.pad_size - if pad_size > 0: - residual = F.pad(residual, (0, 0, 0, pad_size)) - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - residual = torch.chunk(residual, tp_size, dim=0)[tp_rank] - - return residual - - def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor, label: bool) -> torch.Tensor: try: @@ -187,12 +164,6 @@ def _maybe_all_reduce_tensor_model_parallel_impl( return tensor_model_parallel_all_reduce(final_hidden_states) -direct_register_custom_op(op_name="maybe_chunk_residual", - op_func=_maybe_chunk_residual_impl, - fake_impl=lambda x, residual: residual, - mutates_args=[], - dispatch_key="PrivateUse1") - direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad", op_func=_maybe_all_gather_and_maybe_unpad_impl, fake_impl=lambda x, label: x, diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index 0a7d7ef..69be390 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -20,7 +20,7 @@ from typing import Optional, Tuple import torch from torch import nn from torch.nn.parameter import Parameter -from vllm.distributed import divide, tensor_model_parallel_all_reduce +from vllm.distributed import divide from vllm.distributed.parallel_state import get_tp_group from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config 
import ( @@ -163,7 +163,7 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) # Reduce across all the model parallel GPUs. - output = tensor_model_parallel_all_reduce(output_parallel) + output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel) return output diff --git a/vllm_ascend/torchair/ops/torchair_vocab_parallel_embedding.py b/vllm_ascend/torchair/ops/torchair_vocab_parallel_embedding.py new file mode 100644 index 0000000..f83f2bc --- /dev/null +++ b/vllm_ascend/torchair/ops/torchair_vocab_parallel_embedding.py @@ -0,0 +1,38 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from vllm.distributed import tensor_model_parallel_all_reduce + + +def vocab_embedding_forward(self, input_): + if self.tp_size > 1: + # Build the mask. + masked_input, input_mask = self._get_masked_input_and_mask( + input_, self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index) + else: + masked_input = input_ + # Get the embeddings. + output_parallel = self.quant_method.embedding(self, masked_input.long()) + # Mask the output embedding. 
+ if self.tp_size > 1: + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) + # Reduce across all the model parallel GPUs. + output = tensor_model_parallel_all_reduce(output_parallel) + return output diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py index 668a7e7..f75e7c1 100644 --- a/vllm_ascend/torchair/utils.py +++ b/vllm_ascend/torchair/utils.py @@ -208,11 +208,15 @@ def torchair_ops_patch(): from vllm_ascend.ops.layernorm import AscendRMSNorm from vllm_ascend.ops.rotary_embedding import ( AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) + from vllm_ascend.ops.vocab_parallel_embedding import \ + AscendVocabParallelEmbedding from vllm_ascend.torchair.ops import (torchair_activation, torchair_layernorm) from vllm_ascend.torchair.ops.torchair_rotary_embedding import ( deepseek_rope_init_func, native_rope_deepseek_forward, qwen_rope_init_func, rope_forward) + from vllm_ascend.torchair.ops.torchair_vocab_parallel_embedding import \ + vocab_embedding_forward AscendRotaryEmbedding.__init__ = qwen_rope_init_func # type: ignore[method-assign] AscendRotaryEmbedding.forward_oot = rope_forward # type: ignore[method-assign] @@ -222,3 +226,4 @@ def torchair_ops_patch(): AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot # type: ignore[method-assign] AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot # type: ignore[method-assign] + AscendVocabParallelEmbedding.forward = vocab_embedding_forward # type: ignore[method-assign]