Performing Vocabulary Parallelism for LM Head across Attention TP Groups (#5558)
Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
This commit is contained in:
@@ -252,12 +252,12 @@ def dp_scatter(
|
||||
)
|
||||
|
||||
|
||||
def attn_tp_reduce_scatter(
    output: torch.Tensor,
    input_list: List[torch.Tensor],
):
    """Reduce-scatter across the attention tensor-parallel (TP) group.

    Each rank contributes `input_list` (one tensor per rank in the group);
    the reduced shard for this rank is written into `output`.

    Args:
        output: Destination tensor receiving this rank's reduced shard.
        input_list: Per-rank input tensors to be reduced and scattered.

    Returns:
        Whatever the underlying group's `reduce_scatter` returns —
        presumably an async work handle or None; confirm against the
        attention TP group implementation.
    """
    # Thin wrapper: delegate directly to the attention TP process group's
    # collective so callers don't need to resolve the group themselves.
    return get_attention_tp_group().reduce_scatter(output, input_list)
|
||||
|
||||
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
    """All-gather across the attention tensor-parallel (TP) group.

    Gathers `input_` from every rank in the attention TP group into
    `output_list` (one tensor slot per rank).

    Args:
        output_list: Pre-allocated destination tensors, one per group rank.
        input_: This rank's tensor to contribute to the gather.

    Returns:
        Whatever the underlying group's `all_gather` returns — presumably
        an async work handle or None; confirm against the attention TP
        group implementation.
    """
    # Thin wrapper: delegate to the attention TP process group's collective.
    return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
||||
Reference in New Issue
Block a user