Support TP in attention for two batch overlap (#6634)
This commit is contained in:
@@ -448,6 +448,13 @@ class CommunicateSummableTensorPairFn:
|
|||||||
):
|
):
|
||||||
return CommunicateSummableTensorPairFn._gather
|
return CommunicateSummableTensorPairFn._gather
|
||||||
|
|
||||||
|
if (
|
||||||
|
(hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
|
||||||
|
and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
|
||||||
|
and (output_mode == ScatterMode.SCATTERED)
|
||||||
|
):
|
||||||
|
return CommunicateSummableTensorPairFn._scatter
|
||||||
|
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
|
f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
|
||||||
)
|
)
|
||||||
@@ -496,3 +503,15 @@ class CommunicateSummableTensorPairFn:
|
|||||||
local_hidden_states,
|
local_hidden_states,
|
||||||
)
|
)
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _scatter(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
context: CommunicateContext,
|
||||||
|
):
|
||||||
|
assert residual is None, "not yet handled residual!=None"
|
||||||
|
tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
|
||||||
|
hidden_states = tensor_list[context.attn_tp_rank]
|
||||||
|
return hidden_states, residual
|
||||||
|
|||||||
@@ -1613,6 +1613,9 @@ class DeepseekV2Model(nn.Module):
|
|||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
|
input_data_scatter_mode=self.layers[
|
||||||
|
normal_num_layers - 1
|
||||||
|
].layer_scatter_modes.layer_output_mode,
|
||||||
zero_allocator=zero_allocator,
|
zero_allocator=zero_allocator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,11 @@ from sglang.srt.distributed import (
|
|||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
from sglang.srt.layers.communicator import (
|
||||||
|
LayerCommunicator,
|
||||||
|
LayerScatterModes,
|
||||||
|
ScatterMode,
|
||||||
|
)
|
||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import (
|
||||||
attn_tp_all_gather,
|
attn_tp_all_gather,
|
||||||
attn_tp_reduce_scatter,
|
attn_tp_reduce_scatter,
|
||||||
@@ -447,6 +451,7 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
hidden_states, residual = model_forward_maybe_tbo(
|
hidden_states, residual = model_forward_maybe_tbo(
|
||||||
layers=self.layers,
|
layers=self.layers,
|
||||||
enable_tbo=True,
|
enable_tbo=True,
|
||||||
|
input_data_scatter_mode=ScatterMode.model_input_output(),
|
||||||
positions=positions,
|
positions=positions,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
|
|||||||
@@ -5,6 +5,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
|
from sglang.srt.layers.communicator import (
|
||||||
|
CommunicateContext,
|
||||||
|
CommunicateSimpleFn,
|
||||||
|
CommunicateSummableTensorPairFn,
|
||||||
|
ScatterMode,
|
||||||
|
)
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||||
from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
|
from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
|
||||||
@@ -355,6 +361,7 @@ def model_forward_maybe_tbo(
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
input_data_scatter_mode: ScatterMode,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
zero_allocator: Optional[BumpAllocator] = None,
|
zero_allocator: Optional[BumpAllocator] = None,
|
||||||
):
|
):
|
||||||
@@ -365,20 +372,32 @@ def model_forward_maybe_tbo(
|
|||||||
residual=residual,
|
residual=residual,
|
||||||
**(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
|
**(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
|
||||||
)
|
)
|
||||||
|
layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
|
||||||
operations_strategy = OperationsStrategy.init_new_tbo(
|
operations_strategy = OperationsStrategy.init_new_tbo(
|
||||||
layers, forward_batch.global_forward_mode
|
layers, forward_batch.global_forward_mode
|
||||||
)
|
)
|
||||||
if enable_tbo:
|
if enable_tbo:
|
||||||
return _model_forward_tbo(inputs, operations_strategy)
|
return _model_forward_tbo(
|
||||||
|
inputs=inputs,
|
||||||
|
operations_strategy=operations_strategy,
|
||||||
|
input_data_scatter_mode=input_data_scatter_mode,
|
||||||
|
layer_input_scatter_mode=layer_input_scatter_mode,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return _model_forward_non_tbo(inputs, operations_strategy)
|
return _model_forward_non_tbo(inputs, operations_strategy)
|
||||||
|
|
||||||
|
|
||||||
def _model_forward_tbo(inputs, operations_strategy: OperationsStrategy):
|
def _model_forward_tbo(
|
||||||
# The attn_tp_size!=1 case is not yet extracted to master
|
inputs,
|
||||||
assert get_attention_tp_size() == 1
|
operations_strategy: OperationsStrategy,
|
||||||
|
input_data_scatter_mode: ScatterMode,
|
||||||
inputs_arr = _model_forward_tbo_split_inputs(**inputs)
|
layer_input_scatter_mode: ScatterMode,
|
||||||
|
):
|
||||||
|
inputs_arr = _model_forward_tbo_split_inputs(
|
||||||
|
**inputs,
|
||||||
|
input_data_scatter_mode=input_data_scatter_mode,
|
||||||
|
layer_input_scatter_mode=layer_input_scatter_mode,
|
||||||
|
)
|
||||||
del inputs
|
del inputs
|
||||||
|
|
||||||
with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
|
with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
|
||||||
@@ -401,7 +420,57 @@ def _model_forward_tbo_split_inputs(
|
|||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
zero_allocator: Optional[BumpAllocator] = None,
|
zero_allocator: Optional[BumpAllocator],
|
||||||
|
input_data_scatter_mode: ScatterMode,
|
||||||
|
layer_input_scatter_mode: ScatterMode,
|
||||||
|
) -> List[Dict]:
|
||||||
|
tbo_splitter_scatter_mode = ScatterMode.TP_ATTN_FULL
|
||||||
|
context = CommunicateContext.init_new()
|
||||||
|
|
||||||
|
hidden_states, residual = CommunicateSummableTensorPairFn.execute(
|
||||||
|
hidden_states_input_mode=input_data_scatter_mode,
|
||||||
|
residual_input_mode=input_data_scatter_mode,
|
||||||
|
output_mode=tbo_splitter_scatter_mode,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
residual=residual,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs_arr = _model_forward_tbo_split_inputs_raw(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
residual=residual,
|
||||||
|
positions=positions,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
zero_allocator=zero_allocator,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _post_transform(hidden_states, residual, forward_batch, **kwargs):
|
||||||
|
hidden_states, residual = CommunicateSummableTensorPairFn.execute(
|
||||||
|
hidden_states_input_mode=tbo_splitter_scatter_mode,
|
||||||
|
residual_input_mode=tbo_splitter_scatter_mode,
|
||||||
|
output_mode=layer_input_scatter_mode,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
residual=residual,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
return dict(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
residual=residual,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [_post_transform(**inputs) for inputs in inputs_arr]
|
||||||
|
|
||||||
|
|
||||||
|
def _model_forward_tbo_split_inputs_raw(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
zero_allocator: Optional[BumpAllocator],
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
return [
|
return [
|
||||||
dict(
|
dict(
|
||||||
|
|||||||
Reference in New Issue
Block a user