[Performance] Remove index operation when VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1 (#5936)

### What this PR does / why we need it?
When VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE>1, an index operation is
needed to reorganize the batch so that each rank holds the correct
batch IDs after the reduce-scatter op. When
VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1, no reduce-scatter is performed,
so the index operation is unnecessary and this PR skips it.
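
The idea can be sketched in plain Python, with nested lists standing in for tensors. This is a minimal illustration, not the code from this PR: `build_send_buf`, the chunk values, and the index orders are hypothetical, and it assumes the reorganized order is the identity when the parallel size is 1.

```python
def build_send_buf(chunks, group_indices, otp_size):
    """Flatten per-rank chunks, reordering them only when otp_size != 1.

    chunks: list of chunks, each a list of rows (hypothetical stand-in
    for a tensor of shape [chunk_num, batch_per_chunk, hidden]).
    group_indices: hypothetical chunk order required by the receivers.
    """
    if otp_size != 1:
        # Gather chunks in the reorganized order (the "index operation").
        chunks = [chunks[i] for i in group_indices]
    # Flatten each chunk's rows into one flat buffer, like flatten(1, 2).
    return [[value for row in chunk for value in row] for chunk in chunks]


# Four chunks, one row each, hidden size 2 (illustrative values only).
chunks = [[[0, 0]], [[1, 1]], [[2, 2]], [[3, 3]]]

# Parallel size > 1: the gather changes the chunk order.
reordered = build_send_buf(chunks, [0, 2, 1, 3], otp_size=2)

# Parallel size == 1: the gather is skipped and the order is unchanged.
skipped = build_send_buf(chunks, [0, 1, 2, 3], otp_size=1)
```

Under the identity-order assumption, skipping the gather yields the same buffer as performing it, which is why the branch only saves work.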

Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
This commit is contained in:
Levi
2026-01-19 17:12:13 +08:00
committed by GitHub
parent bc486d9530
commit f0d41199a6

@@ -281,6 +281,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
 super().__init__(layer)
 self.odp_group = get_flashcomm2_odp_group()
 self.odp_size = self.odp_group.world_size
+self.otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
 self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(
     get_tp_group().world_size)
 self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
@@ -338,8 +339,9 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
 batch_size_per_chunk = batch_size // chunk_num
 # Indices of reorganized tensor
 chunked = x.view(chunk_num, batch_size_per_chunk, x.shape[1])
-reorganized_chunks = chunked[self.group_indices]
-send_buf = reorganized_chunks.flatten(1, 2)
+if self.otp_size != 1:
+    chunked = chunked[self.group_indices]
+send_buf = chunked.flatten(1, 2)
 # all-to-all operation parameters
 all2all_tp_size = self.odp_size