From f39f566e22b87ee75bd1205f982e4255a882c3a4 Mon Sep 17 00:00:00 2001 From: idouba Date: Fri, 20 Mar 2026 16:49:02 +0800 Subject: [PATCH] Refactor duplicated code into a common method to reduce redundancy (#7210) ### What this PR does / why we need it? 1. Extracting duplicated code into a method. That is, defining _get_input_parallel_ in the parent class _CustomRowParallelOp_, and calling the helper method in the _apply_impl_ method of its 5 child classes: - MLPRowParallelOp - OProjRowParallelOp - Flashcomm2OProjRowParallelOp - MatmulAllreduceRowParallelOp - SequenceRowParallelOp 2. Variable typo fix: use _split_ instead of _splitted_ (the past tense of "split" is "split") ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.16.0 - vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d Signed-off-by: idouba --- vllm_ascend/ops/linear_op.py | 39 +++++++++++------------------------- 1 file changed, 12 insertions(+), 27 deletions(-) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 77c71b40..c8b12615 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -157,6 +157,13 @@ class CustomRowParallelOp(CustomLinearOp): return output return output, output_bias + def get_input_parallel(self, input_: torch.Tensor) -> torch.Tensor: + if self.input_is_parallel: + return input_ + + split_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size) + return split_input[self.tp_rank].contiguous() + class CustomReplicatedOp(CustomLinearOp): def apply_impl(self, input_): @@ -200,11 +207,7 @@ class MLPRowParallelOp(CustomRowParallelOp): return get_mlp_tp_group() def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: - if self.input_is_parallel: - input_parallel = input_ - else: - splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size) - input_parallel = 
splitted_input[self.tp_rank].contiguous() + input_parallel = self.get_input_parallel(input_) assert self.quant_method is not None bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.layer.bias @@ -227,11 +230,7 @@ class OProjRowParallelOp(CustomRowParallelOp): self, input_: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: - if self.input_is_parallel: - input_parallel = input_ - else: - splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size) - input_parallel = splitted_input[self.tp_rank].contiguous() + input_parallel = self.get_input_parallel(input_) # Prepare tensors for all-to-all communication local_batch_size = input_parallel.size(0) @@ -303,12 +302,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp): Output.shape = [(batchsize*seqlength+padsize)/TP, hiddensize] """ # Handle input parallelism - split or use as-is - if self.input_is_parallel: - input_parallel = input_ - else: - tp_rank = self.tp_rank - splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() + input_parallel = self.get_input_parallel(input_) # padding for all-to-all num_padding_tokens = _EXTRA_CTX.pad_size @@ -394,11 +388,7 @@ class MatmulAllreduceRowParallelOp(CustomRowParallelOp): self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group) def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: - if self.input_is_parallel: - input_parallel = input_ - else: - splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size) - input_parallel = splitted_input[self.tp_rank].contiguous() + input_parallel = self.get_input_parallel(input_) """Calculate the output tensor of forward by considering fusing communication and computation.""" bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias @@ -492,12 +482,7 @@ class SequenceRowParallelOp(CustomRowParallelOp): 
Implemented multiple optimization projects for dense models, such as FlashComm and communication-computation fusion. """ - - if self.input_is_parallel: - input_parallel = input_ - else: - splitted_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size) - input_parallel = splitted_input[self.tp_rank].contiguous() + input_parallel = self.get_input_parallel(input_) assert self.quant_method is not None bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias