[Performance] Remove index operation when VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1 (#5936)
### What this PR does / why we need it? When VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE>1 is enabled, an index operation is needed to reorganize the batch, because we must ensure that each rank receives the correct batch-id after the reduce-scatter op. The index operation is not needed when VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1, which does not require reduce-scatter. Signed-off-by: Levi-JQ <yujinqi2@huawei.com> Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
This commit is contained in:
@@ -281,6 +281,7 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
|
|||||||
super().__init__(layer)
|
super().__init__(layer)
|
||||||
self.odp_group = get_flashcomm2_odp_group()
|
self.odp_group = get_flashcomm2_odp_group()
|
||||||
self.odp_size = self.odp_group.world_size
|
self.odp_size = self.odp_group.world_size
|
||||||
|
self.otp_size = get_ascend_config().flashcomm2_oproj_tensor_parallel_size
|
||||||
self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(
|
self.reorgnized_batch_ids = get_flashcomm2_reorgnized_batch_ids(
|
||||||
get_tp_group().world_size)
|
get_tp_group().world_size)
|
||||||
self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
|
self.group_indices = torch.tensor(self.reorgnized_batch_ids).npu()
|
||||||
@@ -338,8 +339,9 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
|
|||||||
batch_size_per_chunk = batch_size // chunk_num
|
batch_size_per_chunk = batch_size // chunk_num
|
||||||
# Indices of reorganized tensor
|
# Indices of reorganized tensor
|
||||||
chunked = x.view(chunk_num, batch_size_per_chunk, x.shape[1])
|
chunked = x.view(chunk_num, batch_size_per_chunk, x.shape[1])
|
||||||
reorganized_chunks = chunked[self.group_indices]
|
if self.otp_size != 1:
|
||||||
send_buf = reorganized_chunks.flatten(1, 2)
|
chunked = chunked[self.group_indices]
|
||||||
|
send_buf = chunked.flatten(1, 2)
|
||||||
|
|
||||||
# all-to-all operation parameters
|
# all-to-all operation parameters
|
||||||
all2all_tp_size = self.odp_size
|
all2all_tp_size = self.odp_size
|
||||||
|
|||||||
Reference in New Issue
Block a user