Performing Vocabulary Parallelism for LM Head across Attention TP Groups (#5558)

Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
This commit is contained in:
Cheng Wan
2025-05-12 02:36:29 -04:00
committed by GitHub
parent 9f2c9568f0
commit 25c83fff6a
8 changed files with 71 additions and 23 deletions

View File

@@ -252,12 +252,12 @@ def dp_scatter(
)
def attn_tp_reduce_scatter(
    output: torch.Tensor,
    input_list: List[torch.Tensor],
):
    """Reduce-scatter across the attention tensor-parallel (TP) group.

    Thin wrapper that forwards to the attention TP group's communicator.

    Args:
        output: Tensor that receives this rank's reduced shard.
        input_list: One input tensor per rank in the attention TP group,
            to be element-wise reduced and scattered.

    Returns:
        Whatever ``reduce_scatter`` on the group object returns
        (defined outside this file -- presumably ``None`` or a
        communication handle; confirm against the group implementation).
    """
    # NOTE: the stale pre-rename header line (`def tp_reduce_scatter(`)
    # was diff residue and has been removed; only the renamed function
    # remains.
    return get_attention_tp_group().reduce_scatter(output, input_list)
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
    """All-gather across the attention tensor-parallel (TP) group.

    Thin wrapper that forwards to the attention TP group's communicator.

    Args:
        output_list: One output tensor per rank in the attention TP group;
            filled with the gathered shards.
        input_: This rank's contribution to the gather.

    Returns:
        Whatever ``all_gather`` on the group object returns (defined
        outside this file -- presumably ``None`` or a communication
        handle; confirm against the group implementation).
    """
    # NOTE: the stale pre-rename header line (`def tp_all_gather(`)
    # was diff residue and has been removed; only the renamed function
    # remains.
    return get_attention_tp_group().all_gather(input_, tensor_list=output_list)