[Bugfix] Avoid unnecessary reduce-scatter call in prepare_mlp (#9169)

This commit is contained in:
Huaixin Chang
2025-08-14 12:04:41 +08:00
committed by GitHub
parent 0fc8bf2cd4
commit 98457c0453

View File

@@ -292,6 +292,10 @@ def _dp_gather_via_all_gather(
forward_batch: ForwardBatch,
is_partial: bool,
):
if get_attention_tp_size() == 1:
get_tp_group().all_gather_into_tensor(global_tokens, local_tokens)
return
if not is_partial:
if get_attention_tp_rank() != 0:
local_tokens.fill_(0)