diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 4a90ea90..adeaa26a 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -281,6 +281,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
         super().__init__(layer)
         self.odp_group = get_flashcomm2_odp_group()
         self.odp_size = self.odp_group.world_size
+        self.otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
         self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(
             get_tp_group().world_size)
         self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
@@ -338,8 +339,9 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
         batch_size_per_chunk = batch_size // chunk_num
         # Indices of reorganized tensor
         chunked = x.view(chunk_num, batch_size_per_chunk, x.shape[1])
-        reorganized_chunks = chunked[self.group_indices]
-        send_buf = reorganized_chunks.flatten(1, 2)
+        if self.otp_size != 1:
+            chunked = chunked[self.group_indices]
+        send_buf = chunked.flatten(1, 2)

         # all-to-all operation parameters
         all2all_tp_size = self.odp_size
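
For context, a minimal standalone sketch of the gated reorganization this patch introduces, using toy shapes on CPU. build_send_buf and the sample group_indices values are hypothetical illustrations; the real indices come from get_flashcomm2_reorgnized_batch_ids and the op runs on NPU inside Flashcomm2OProjRowParallelOp. This is not the vllm-ascend implementation, only a sketch of why the indexing can be skipped when the oproj tensor-parallel size is 1.

    import torch

    def build_send_buf(x: torch.Tensor, chunk_num: int,
                       group_indices: torch.Tensor, otp_size: int) -> torch.Tensor:
        # Split the batch into equal chunks: (chunk_num, batch_per_chunk, hidden).
        batch_per_chunk = x.shape[0] // chunk_num
        chunked = x.view(chunk_num, batch_per_chunk, x.shape[1])
        if otp_size != 1:
            # Regroup chunks so data bound for the same oproj-TP group is
            # contiguous before the all-to-all.
            chunked = chunked[group_indices]
        # Both branches yield the same flat memory order for the collective,
        # even though the tensor ranks differ.
        return chunked.flatten(1, 2)

    # Toy run: batch of 8 rows, hidden size 3, split into 4 chunks.
    x = torch.arange(8 * 3, dtype=torch.float32).view(8, 3)
    indices = torch.tensor([[0, 2], [1, 3]])  # hypothetical regrouping for otp_size=2
    print(build_send_buf(x, 4, indices, otp_size=2).shape)  # torch.Size([2, 4, 3])
    print(build_send_buf(x, 4, indices, otp_size=1).shape)  # torch.Size([4, 6])

The apparent rationale: with otp_size == 1 the chunks already sit in the order the all-to-all expects, so indexing with group_indices is unnecessary work, and the patch skips it rather than building identity indices.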