Support TP in attention for two batch overlap (#6634)
This commit is contained in:
@@ -448,6 +448,13 @@ class CommunicateSummableTensorPairFn:
|
||||
):
|
||||
return CommunicateSummableTensorPairFn._gather
|
||||
|
||||
if (
|
||||
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
|
||||
and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
|
||||
and (output_mode == ScatterMode.SCATTERED)
|
||||
):
|
||||
return CommunicateSummableTensorPairFn._scatter
|
||||
|
||||
raise NotImplementedError(
|
||||
f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
|
||||
)
|
||||
@@ -496,3 +503,15 @@ class CommunicateSummableTensorPairFn:
|
||||
local_hidden_states,
|
||||
)
|
||||
return hidden_states, residual
|
||||
|
||||
@staticmethod
|
||||
def _scatter(
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
context: CommunicateContext,
|
||||
):
|
||||
assert residual is None, "not yet handled residual!=None"
|
||||
tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
|
||||
hidden_states = tensor_list[context.attn_tp_rank]
|
||||
return hidden_states, residual
|
||||
|
||||
Reference in New Issue
Block a user