[Bugfix] Avoid unnecessary reduce-scatter call in prepare_mlp (#9169)
This commit is contained in:
@@ -292,6 +292,10 @@ def _dp_gather_via_all_gather(
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
is_partial: bool,
|
is_partial: bool,
|
||||||
):
|
):
|
||||||
|
if get_attention_tp_size() == 1:
|
||||||
|
get_tp_group().all_gather_into_tensor(global_tokens, local_tokens)
|
||||||
|
return
|
||||||
|
|
||||||
if not is_partial:
|
if not is_partial:
|
||||||
if get_attention_tp_rank() != 0:
|
if get_attention_tp_rank() != 0:
|
||||||
local_tokens.fill_(0)
|
local_tokens.fill_(0)
|
||||||
|
|||||||
Reference in New Issue
Block a user