[Bugfix] Avoid unnecessary reduce-scatter call in prepare_mlp (#9169)
This commit is contained in:
@@ -292,6 +292,10 @@ def _dp_gather_via_all_gather(
|
||||
forward_batch: ForwardBatch,
|
||||
is_partial: bool,
|
||||
):
|
||||
if get_attention_tp_size() == 1:
|
||||
get_tp_group().all_gather_into_tensor(global_tokens, local_tokens)
|
||||
return
|
||||
|
||||
if not is_partial:
|
||||
if get_attention_tp_rank() != 0:
|
||||
local_tokens.fill_(0)
|
||||
|
||||
Reference in New Issue
Block a user