diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py
index 1153bfe..4634a69 100644
--- a/tests/ut/ops/test_linear.py
+++ b/tests/ut/ops/test_linear.py
@@ -4,6 +4,7 @@ from unittest import mock
 from unittest.mock import MagicMock, patch
 
 import torch
+from vllm import config
 
 from tests.ut.base import TestBase
 from vllm_ascend import ascend_config
@@ -106,6 +107,9 @@ class TestAscendRowParallelLinear(BaseLinearTest):
         linear(input_tensor)
 
     def test_oproj_tp(self):
+
+        config._current_vllm_config = MagicMock()
+
         ascend_config._ASCEND_CONFIG = MagicMock()
         ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2
 
diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
index 81d7d9e..969cc97 100644
--- a/vllm_ascend/ops/linear.py
+++ b/vllm_ascend/ops/linear.py
@@ -26,6 +26,7 @@ import torch
 import torch.nn as nn
 import torch_npu
 from torch.nn.parameter import Parameter
+from vllm.config import get_current_vllm_config
 from vllm.distributed import divide
 from vllm.model_executor.layers.linear import (  # noqa
     WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
@@ -234,6 +235,13 @@ class AscendRowParallelLinear(RowParallelLinear):
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
+        compilation_config = get_current_vllm_config().compilation_config
+        # TODO(shaopeng-666): Remove the visual check after the mm model reconstruction is complete.
+        if prefix in compilation_config.static_forward_context and \
+            "visual" not in prefix:
+            raise ValueError(f"Duplicate layer name: {prefix}")
+        compilation_config.static_forward_context[prefix] = self
+
         self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
             disable_tp, prefix, self, "row")
         # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 58a96fd..ff0462e 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -366,14 +366,23 @@ class SequenceRowParallelOp(CustomRowParallelOp):
                                              input_parallel,
                                              bias=bias_)
         else:
-            output_parallel = self.quant_method.apply(self.layer,
-                                                      input_parallel,
-                                                      bias=bias_)
-            output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
+            output = torch.ops.vllm.matmul_and_reduce(input_parallel,
+                                                      self.prefix)
 
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
+    def matmul_and_reduce(self, input_parallel: torch.Tensor,
+                          bias_: Optional[Parameter]) -> torch.Tensor:
+        assert self.quant_method is not None
+        output_parallel = self.quant_method.apply(self.layer,
+                                                  input_parallel,
+                                                  bias=bias_)
+        from vllm_ascend.ops.register_custom_ops import \
+            _maybe_pad_and_reduce_impl
+        output = _maybe_pad_and_reduce_impl(output_parallel)
+        return output
+
     def update_attrs(self):
         super().update_attrs()
         self.input_is_parallel = self.layer.input_is_parallel
diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index 5e2bbca..69e220e 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -235,6 +235,31 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
     return tensor_model_parallel_all_reduce(final_hidden_states)
 
 
+def _matmul_and_reduce_impl(input_parallel: torch.Tensor,
+                            layer_name: str) -> torch.Tensor:
+    forward_context = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    assert self.custom_op is not None
+    bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+    output = self.custom_op.matmul_and_reduce(input_parallel, bias_)
+
+    return output
+
+
+def _matmul_and_reduce_impl_fake(input_parallel: torch.Tensor,
+                                 layer_name: str) -> torch.Tensor:
+    forward_context = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    num_tokens = input_parallel.size(0)
+    if forward_context.sp_enabled:
+        num_tokens = num_tokens // self.tp_size
+    output = torch.empty(size=(num_tokens, self.output_size_per_partition),
+                         device=input_parallel.device,
+                         dtype=input_parallel.dtype)
+
+    return output
+
+
 direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
                           op_func=_maybe_all_gather_and_maybe_unpad_impl,
                           fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
@@ -282,3 +307,9 @@ direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
                           fake_impl=lambda x: x,
                           mutates_args=[],
                           dispatch_key="PrivateUse1")
+
+direct_register_custom_op(op_name="matmul_and_reduce",
+                          op_func=_matmul_and_reduce_impl,
+                          fake_impl=_matmul_and_reduce_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")