From 74903af460b111c1348c0453dfc508011b7cd3c7 Mon Sep 17 00:00:00 2001 From: rjg-lyh <83491835+rjg-lyh@users.noreply.github.com> Date: Thu, 23 Oct 2025 14:45:49 +0800 Subject: [PATCH] [v0.11.0][refactor] refactor SequenceRowParallelOp forward (#3654) ### What this PR does / why we need it? This PR refactors SequenceRowParallelOp forward. To further expand the scope of operators that can be covered in dynamic-judgment scenarios, this PR wraps the entire matmul computation and communication into a single custom masking operator. With this refactor, code such as common operation fusion can be written directly in the SequenceRowParallelOp member function matmul_and_reduce, without registering additional redundant custom masking operators. ### How was this patch tested? CI passed with newly added and existing tests. Signed-off-by: rjg-lyh <1318825571@qq.com> --- tests/ut/ops/test_linear.py | 4 ++++ vllm_ascend/ops/linear.py | 8 +++++++ vllm_ascend/ops/linear_op.py | 17 ++++++++++---- vllm_ascend/ops/register_custom_ops.py | 31 ++++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 4 deletions(-) diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py index 1153bfe..4634a69 100644 --- a/tests/ut/ops/test_linear.py +++ b/tests/ut/ops/test_linear.py @@ -4,6 +4,7 @@ from unittest import mock from unittest.mock import MagicMock, patch import torch +from vllm import config from tests.ut.base import TestBase from vllm_ascend import ascend_config @@ -106,6 +107,9 @@ class TestAscendRowParallelLinear(BaseLinearTest): linear(input_tensor) def test_oproj_tp(self): + + config._current_vllm_config = MagicMock() + ascend_config._ASCEND_CONFIG = MagicMock() ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2 diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py index 81d7d9e..969cc97 100644 --- a/vllm_ascend/ops/linear.py +++ b/vllm_ascend/ops/linear.py @@ -26,6 +26,7 @@ import torch import torch.nn as nn import torch_npu from 
torch.nn.parameter import Parameter +from vllm.config import get_current_vllm_config from vllm.distributed import divide from vllm.model_executor.layers.linear import ( # noqa WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase, @@ -234,6 +235,13 @@ class AscendRowParallelLinear(RowParallelLinear): return_bias: bool = True, disable_tp: bool = False, ): + compilation_config = get_current_vllm_config().compilation_config + # TODO(shaopeng-666): Remove the visual check after the mm model reconstruction is complete. + if prefix in compilation_config.static_forward_context and \ + "visual" not in prefix: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + self.custom_op, self.tp_rank, self.tp_size = get_parallel_op( disable_tp, prefix, self, "row") # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 58a96fd..ff0462e 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -366,14 +366,23 @@ class SequenceRowParallelOp(CustomRowParallelOp): input_parallel, bias=bias_) else: - output_parallel = self.quant_method.apply(self.layer, - input_parallel, - bias=bias_) - output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel) + output = torch.ops.vllm.matmul_and_reduce(input_parallel, + self.prefix) output_bias = self.bias if self.skip_bias_add else None return output, output_bias + def matmul_and_reduce(self, input_parallel: torch.Tensor, + bias_: Optional[Parameter]) -> torch.Tensor: + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self.layer, + input_parallel, + bias=bias_) + from vllm_ascend.ops.register_custom_ops import \ + _maybe_pad_and_reduce_impl + output = _maybe_pad_and_reduce_impl(output_parallel) + return output + def update_attrs(self): super().update_attrs() 
self.input_is_parallel = self.layer.input_is_parallel diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index 5e2bbca..69e220e 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -235,6 +235,31 @@ def _maybe_all_reduce_tensor_model_parallel_impl( return tensor_model_parallel_all_reduce(final_hidden_states) +def _matmul_and_reduce_impl(input_parallel: torch.Tensor, + layer_name: str) -> torch.Tensor: + forward_context = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + assert self.custom_op is not None + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output = self.custom_op.matmul_and_reduce(input_parallel, bias_) + + return output + + +def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor, + layer_name: str) -> torch.Tensor: + forward_context = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + num_tokens = input_parallel.size(0) + if forward_context.sp_enabled: + num_tokens = num_tokens // self.tp_size + output = torch.empty(size=(num_tokens, self.output_size_per_partition), + device=input_parallel.device, + dtype=input_parallel.dtype) + + return output + + direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad", op_func=_maybe_all_gather_and_maybe_unpad_impl, fake_impl=_maybe_all_gather_and_maybe_unpad_fake, @@ -282,3 +307,9 @@ direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel", fake_impl=lambda x: x, mutates_args=[], dispatch_key="PrivateUse1") + +direct_register_custom_op(op_name="matmul_and_reduce", + op_func=_matmul_and_reduce_impl, + fake_impl=_matmul_and_reduce_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1")