Support TP in attention for two batch overlap (#6634)
@@ -448,6 +448,13 @@ class CommunicateSummableTensorPairFn:
         ):
             return CommunicateSummableTensorPairFn._gather
 
+        if (
+            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (output_mode == ScatterMode.SCATTERED)
+        ):
+            return CommunicateSummableTensorPairFn._scatter
+
         raise NotImplementedError(
             f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
         )
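The hunk above adds one branch to `execute`'s mode-dispatch chain: when both inputs are replicated across all attention-TP ranks (`TP_ATTN_FULL`) and the output should be `SCATTERED`, no collective is needed and the new `_scatter` handler just keeps the local slice. A minimal, self-contained sketch of that dispatch pattern (the two enum values mirror the diff; everything else is a simplified stand-in, not sglang's actual API):

```python
import enum
from typing import Callable


class ScatterMode(enum.Enum):
    # Mirrors the two modes referenced in the hunk above.
    SCATTERED = enum.auto()     # each attention-TP rank holds only its token slice
    TP_ATTN_FULL = enum.auto()  # every attention-TP rank holds the full batch


def _gather(x: str) -> str:
    return f"all_gather({x})"   # needs a collective across attention-TP ranks


def _scatter(x: str) -> str:
    return f"local_slice({x})"  # purely local: keep this rank's chunk, no collective


def get_handler(input_mode: ScatterMode, output_mode: ScatterMode) -> Callable[[str], str]:
    """Stand-in for the if-chain dispatch in CommunicateSummableTensorPairFn.execute."""
    if (input_mode == ScatterMode.SCATTERED) and (output_mode == ScatterMode.TP_ATTN_FULL):
        return _gather
    if (input_mode == ScatterMode.TP_ATTN_FULL) and (output_mode == ScatterMode.SCATTERED):
        return _scatter  # the combination this PR adds
    raise NotImplementedError(f"{input_mode=} {output_mode=}")


print(get_handler(ScatterMode.TP_ATTN_FULL, ScatterMode.SCATTERED)("hidden_states"))
```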
@@ -496,3 +503,15 @@ class CommunicateSummableTensorPairFn:
             local_hidden_states,
         )
         return hidden_states, residual
+
+    @staticmethod
+    def _scatter(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: CommunicateContext,
+    ):
+        assert residual is None, "not yet handled residual!=None"
+        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
+        hidden_states = tensor_list[context.attn_tp_rank]
+        return hidden_states, residual
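`_scatter` relies on `torch.Tensor.tensor_split`, which partitions a tensor into `n` nearly equal contiguous chunks along dim 0 (the first `len % n` chunks get one extra row), so every attention-TP rank keeps a deterministic slice that a later all-gather can reassemble. A standalone sketch of that behavior (the sizes here are illustrative, not from the PR):

```python
import torch

# Illustrative stand-ins for context.attn_tp_size / context.attn_tp_rank.
attn_tp_size = 4
hidden_states = torch.randn(10, 8)  # 10 tokens, hidden dim 8

for attn_tp_rank in range(attn_tp_size):
    # Same logic as CommunicateSummableTensorPairFn._scatter:
    local = list(hidden_states.tensor_split(attn_tp_size))[attn_tp_rank]
    # tensor_split(10, 4) -> chunk sizes [3, 3, 2, 2]
    print(attn_tp_rank, local.shape)

# Concatenating the per-rank chunks recovers the full tensor, which is
# what an all-gather across attention-TP ranks does later.
chunks = hidden_states.tensor_split(attn_tp_size)
assert torch.equal(torch.cat(list(chunks), dim=0), hidden_states)
```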
@@ -1613,6 +1613,9 @@ class DeepseekV2Model(nn.Module):
                 forward_batch=forward_batch,
                 hidden_states=hidden_states,
                 residual=residual,
+                input_data_scatter_mode=self.layers[
+                    normal_num_layers - 1
+                ].layer_scatter_modes.layer_output_mode,
                 zero_allocator=zero_allocator,
             )
 
@@ -32,7 +32,11 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.communicator import (
+    LayerCommunicator,
+    LayerScatterModes,
+    ScatterMode,
+)
 from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     attn_tp_reduce_scatter,
@@ -447,6 +451,7 @@ class Qwen2MoeModel(nn.Module):
             hidden_states, residual = model_forward_maybe_tbo(
                 layers=self.layers,
                 enable_tbo=True,
+                input_data_scatter_mode=ScatterMode.model_input_output(),
                 positions=positions,
                 forward_batch=forward_batch,
                 hidden_states=hidden_states,
@@ -5,6 +5,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
 import torch
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.communicator import (
+    CommunicateContext,
+    CommunicateSimpleFn,
+    CommunicateSummableTensorPairFn,
+    ScatterMode,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
@@ -355,6 +361,7 @@ def model_forward_maybe_tbo(
     positions: torch.Tensor,
     forward_batch: ForwardBatch,
     hidden_states: torch.Tensor,
+    input_data_scatter_mode: ScatterMode,
     residual: Optional[torch.Tensor],
     zero_allocator: Optional[BumpAllocator] = None,
 ):
@@ -365,20 +372,32 @@
         residual=residual,
         **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
     )
+    layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
     operations_strategy = OperationsStrategy.init_new_tbo(
         layers, forward_batch.global_forward_mode
     )
     if enable_tbo:
-        return _model_forward_tbo(inputs, operations_strategy)
+        return _model_forward_tbo(
+            inputs=inputs,
+            operations_strategy=operations_strategy,
+            input_data_scatter_mode=input_data_scatter_mode,
+            layer_input_scatter_mode=layer_input_scatter_mode,
+        )
     else:
         return _model_forward_non_tbo(inputs, operations_strategy)
 
 
-def _model_forward_tbo(inputs, operations_strategy: OperationsStrategy):
-    # The attn_tp_size!=1 case is not yet extracted to master
-    assert get_attention_tp_size() == 1
-
-    inputs_arr = _model_forward_tbo_split_inputs(**inputs)
+def _model_forward_tbo(
+    inputs,
+    operations_strategy: OperationsStrategy,
+    input_data_scatter_mode: ScatterMode,
+    layer_input_scatter_mode: ScatterMode,
+):
+    inputs_arr = _model_forward_tbo_split_inputs(
+        **inputs,
+        input_data_scatter_mode=input_data_scatter_mode,
+        layer_input_scatter_mode=layer_input_scatter_mode,
+    )
     del inputs
 
     with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
@@ -401,7 +420,57 @@ def _model_forward_tbo_split_inputs(
     residual: torch.Tensor,
     positions: torch.Tensor,
     forward_batch: ForwardBatch,
-    zero_allocator: Optional[BumpAllocator] = None,
+    zero_allocator: Optional[BumpAllocator],
+    input_data_scatter_mode: ScatterMode,
+    layer_input_scatter_mode: ScatterMode,
 ) -> List[Dict]:
+    tbo_splitter_scatter_mode = ScatterMode.TP_ATTN_FULL
+    context = CommunicateContext.init_new()
+
+    hidden_states, residual = CommunicateSummableTensorPairFn.execute(
+        hidden_states_input_mode=input_data_scatter_mode,
+        residual_input_mode=input_data_scatter_mode,
+        output_mode=tbo_splitter_scatter_mode,
+        hidden_states=hidden_states,
+        residual=residual,
+        forward_batch=forward_batch,
+        context=context,
+    )
+
+    inputs_arr = _model_forward_tbo_split_inputs_raw(
+        hidden_states=hidden_states,
+        residual=residual,
+        positions=positions,
+        forward_batch=forward_batch,
+        zero_allocator=zero_allocator,
+    )
+
+    def _post_transform(hidden_states, residual, forward_batch, **kwargs):
+        hidden_states, residual = CommunicateSummableTensorPairFn.execute(
+            hidden_states_input_mode=tbo_splitter_scatter_mode,
+            residual_input_mode=tbo_splitter_scatter_mode,
+            output_mode=layer_input_scatter_mode,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            context=context,
+        )
+        return dict(
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            **kwargs,
+        )
+
+    return [_post_transform(**inputs) for inputs in inputs_arr]
+
+
+def _model_forward_tbo_split_inputs_raw(
+    hidden_states: torch.Tensor,
+    residual: torch.Tensor,
+    positions: torch.Tensor,
+    forward_batch: ForwardBatch,
+    zero_allocator: Optional[BumpAllocator],
+) -> List[Dict]:
     return [
         dict(
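Taken together, the rewritten `_model_forward_tbo_split_inputs` is a normalize/split/denormalize pipeline: bring whatever scatter mode the caller hands in up to `TP_ATTN_FULL`, split that full batch into the two TBO micro-batches, then convert each half to the mode the first layer expects. A shape-only sketch of that flow (plain Python, with string "modes" standing in for `ScatterMode`, no real collectives, and a midpoint split where the real code splits at the forward batch's TBO split index):

```python
from typing import Callable, Dict, List


def split_inputs_sketch(
    tokens: List[int],
    input_data_scatter_mode: str,
    layer_input_scatter_mode: str,
    convert: Callable[[List[int], str, str], List[int]],
) -> List[Dict]:
    # 1. Normalize: bring the caller's layout up to the full per-rank batch,
    #    mirroring the first CommunicateSummableTensorPairFn.execute call.
    full = convert(tokens, input_data_scatter_mode, "TP_ATTN_FULL")

    # 2. Split into the two TBO micro-batches (midpoint is a simplification).
    mid = len(full) // 2
    halves = [full[:mid], full[mid:]]

    # 3. Denormalize: convert each half to the first layer's expected input
    #    mode, mirroring _post_transform.
    return [
        dict(tokens=convert(half, "TP_ATTN_FULL", layer_input_scatter_mode))
        for half in halves
    ]


# Identity "conversion" just to show the call shape.
print(split_inputs_sketch([1, 2, 3, 4], "SCATTERED", "SCATTERED", lambda t, i, o: t))
```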