[v0.11.0][refactor] refactor SequenceRowParallelOp forward (#3654)

### What this PR does / why we need it?
This PR refactors the SequenceRowParallelOp forward pass. To further
expand the set of operators covered in dynamic-judgment scenarios, it
wraps the entire matmul computation and its communication into a single
custom operator. With this refactor, fusions of common operations can be
written directly in the SequenceRowParallelOp member function
matmul_and_reduce, without having to register additional redundant
custom operators.

### How was this patch tested?
CI passed with new added/existing test.

Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2025-10-23 14:45:49 +08:00
committed by GitHub
parent 54bd531db8
commit 74903af460
4 changed files with 56 additions and 4 deletions

View File

@@ -4,6 +4,7 @@ from unittest import mock
from unittest.mock import MagicMock, patch
import torch
from vllm import config
from tests.ut.base import TestBase
from vllm_ascend import ascend_config
@@ -106,6 +107,9 @@ class TestAscendRowParallelLinear(BaseLinearTest):
linear(input_tensor)
def test_oproj_tp(self):
config._current_vllm_config = MagicMock()
ascend_config._ASCEND_CONFIG = MagicMock()
ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2

View File

@@ -26,6 +26,7 @@ import torch
import torch.nn as nn
import torch_npu
from torch.nn.parameter import Parameter
from vllm.config import get_current_vllm_config
from vllm.distributed import divide
from vllm.model_executor.layers.linear import ( # noqa
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
@@ -234,6 +235,13 @@ class AscendRowParallelLinear(RowParallelLinear):
return_bias: bool = True,
disable_tp: bool = False,
):
compilation_config = get_current_vllm_config().compilation_config
# TODO(shaopeng-666): Remove the visual check after the mm model reconstruction is complete.
if prefix in compilation_config.static_forward_context and \
"visual" not in prefix:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
disable_tp, prefix, self, "row")
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group

View File

@@ -366,14 +366,23 @@ class SequenceRowParallelOp(CustomRowParallelOp):
input_parallel,
bias=bias_)
else:
output_parallel = self.quant_method.apply(self.layer,
input_parallel,
bias=bias_)
output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
output = torch.ops.vllm.matmul_and_reduce(input_parallel,
self.prefix)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def matmul_and_reduce(self, input_parallel: torch.Tensor,
                      bias_: Optional[Parameter]) -> torch.Tensor:
    """Run this rank's matmul shard, then reduce the partial result.

    Args:
        input_parallel: The (already parallel-sharded) input activations.
        bias_: Bias to add during the matmul, or None to skip it.

    Returns:
        The reduced output tensor.
    """
    assert self.quant_method is not None
    # Partial matmul against this rank's weight shard via the quant method.
    partial_output = self.quant_method.apply(self.layer,
                                             input_parallel,
                                             bias=bias_)
    # Deferred import — presumably avoids a circular import at module
    # load time (TODO confirm against module layout).
    from vllm_ascend.ops.register_custom_ops import \
        _maybe_pad_and_reduce_impl
    return _maybe_pad_and_reduce_impl(partial_output)
def update_attrs(self):
    """Refresh cached attributes from the wrapped layer, extending the base refresh."""
    super().update_attrs()
    # Mirror the layer's input_is_parallel flag onto this op wrapper.
    self.input_is_parallel = self.layer.input_is_parallel

View File

@@ -235,6 +235,31 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
return tensor_model_parallel_all_reduce(final_hidden_states)
def _matmul_and_reduce_impl(input_parallel: torch.Tensor,
                            layer_name: str) -> torch.Tensor:
    """Real implementation of the matmul_and_reduce custom op.

    Looks the layer up by name in the current forward context and
    delegates to its custom op's matmul_and_reduce.
    """
    ctx = get_forward_context()
    layer = ctx.no_compile_layers[layer_name]
    assert layer.custom_op is not None
    # Bias is applied on rank 0 only, and never when skip_bias_add is set.
    if layer.tp_rank > 0 or layer.skip_bias_add:
        bias_ = None
    else:
        bias_ = layer.bias
    return layer.custom_op.matmul_and_reduce(input_parallel, bias_)
def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
                                 layer_name: str) -> torch.Tensor:
    """Fake (shape-only) implementation of the matmul_and_reduce custom op.

    Returns an uninitialized tensor with the real op's output shape so
    the compiler can trace through without executing the kernel.
    """
    ctx = get_forward_context()
    layer = ctx.no_compile_layers[layer_name]
    # When sequence parallelism is enabled, the real op's output holds
    # only this rank's 1/tp_size share of the tokens.
    token_count = input_parallel.size(0)
    if ctx.sp_enabled:
        token_count //= layer.tp_size
    return torch.empty((token_count, layer.output_size_per_partition),
                       device=input_parallel.device,
                       dtype=input_parallel.dtype)
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
op_func=_maybe_all_gather_and_maybe_unpad_impl,
fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
@@ -282,3 +307,9 @@ direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
fake_impl=lambda x: x,
mutates_args=[],
dispatch_key="PrivateUse1")
# Register matmul_and_reduce as a custom op so the whole matmul +
# communication sequence is treated as one opaque operation. The fake
# impl supplies output shapes for tracing; PrivateUse1 is the
# out-of-tree backend dispatch key (used here for the NPU backend —
# NOTE(review): confirm against torch_npu's key registration).
direct_register_custom_op(op_name="matmul_and_reduce",
op_func=_matmul_and_reduce_impl,
fake_impl=_matmul_and_reduce_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")