Support TP in attention for two batch overlap (#6634)

This commit is contained in:
fzyzcjy
2025-05-27 11:28:12 +08:00
committed by GitHub
parent ebd1ed49d4
commit 32cd707002
4 changed files with 104 additions and 8 deletions

View File

@@ -448,6 +448,13 @@ class CommunicateSummableTensorPairFn:
):
return CommunicateSummableTensorPairFn._gather
if (
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
and (output_mode == ScatterMode.SCATTERED)
):
return CommunicateSummableTensorPairFn._scatter
raise NotImplementedError(
f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
)
@@ -496,3 +503,15 @@ class CommunicateSummableTensorPairFn:
local_hidden_states,
)
return hidden_states, residual
@staticmethod
def _scatter(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    forward_batch: ForwardBatch,
    context: CommunicateContext,
):
    """Keep only this attention-TP rank's slice of ``hidden_states``.

    ``hidden_states`` is split into ``context.attn_tp_size`` pieces with
    ``Tensor.tensor_split`` (which splits along dim 0 by default) and the
    piece at index ``context.attn_tp_rank`` is returned. No collective
    communication is issued here; the result is a slice of the input.

    Args:
        hidden_states: Tensor to be split across the attention-TP ranks.
        residual: Must be ``None``; a non-None residual is not supported yet.
        forward_batch: Unused here; accepted so the signature matches the
            other communicate functions returned by the dispatcher.
        context: Supplies ``attn_tp_size`` and ``attn_tp_rank``.

    Returns:
        ``(local_hidden_states, residual)`` where ``residual`` is ``None``.

    Raises:
        NotImplementedError: If ``residual`` is not ``None``.
    """
    # Raise explicitly instead of asserting: an ``assert`` is removed under
    # ``python -O``, which would let an unsplit residual be returned next to
    # a split ``hidden_states``.
    if residual is not None:
        raise NotImplementedError("not yet handled residual!=None")

    # ``tensor_split`` already returns an indexable tuple, so no list copy
    # is needed before picking this rank's chunk.
    chunks = hidden_states.tensor_split(context.attn_tp_size)
    local_hidden_states = chunks[context.attn_tp_rank]
    return local_hidden_states, residual