[fix] PD disaggregation when enable mtp and tp!=dp (#7420)
This commit is contained in:
@@ -310,4 +310,4 @@ def attn_tp_reduce_scatter(
|
||||
|
||||
|
||||
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
|
||||
return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
|
||||
return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
|
||||
|
||||
Reference in New Issue
Block a user