From f0d41199a6fdf6c795d63b1f811a0d712834f89b Mon Sep 17 00:00:00 2001 From: Levi <54832289+Levi-JQ@users.noreply.github.com> Date: Mon, 19 Jan 2026 17:12:13 +0800 Subject: [PATCH] [Performance] Remove index operation when VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1 (#5936) ### What this PR does / why we need it? When VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE>1 is enabled, we need an index operation to reorganize the batch, because we need to ensure the correct batch-id for each rank after the reduce-scatter op when VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE>1. But we do not need it when VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1, which does not need reduce-scatter. Signed-off-by: Levi-JQ Co-authored-by: Levi-JQ --- vllm_ascend/ops/linear_op.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 4a90ea90..adeaa26a 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -281,6 +281,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp): super().__init__(layer) self.odp_group = get_flashcomm2_odp_group() self.odp_size = self.odp_group.world_size + self.otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids( get_tp_group().world_size) self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu() @@ -338,8 +339,9 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp): batch_size_per_chunk = batch_size // chunk_num # Indices of reorganized tensor chunked = x.view(chunk_num, batch_size_per_chunk, x.shape[1]) - reorganized_chunks = chunked[self.group_indices] - send_buf = reorganized_chunks.flatten(1, 2) + if self.otp_size != 1: + chunked = chunked[self.group_indices] + send_buf = chunked.flatten(1, 2) # all-to-all operation parameters all2all_tp_size = self.odp_size