Refactor communication logic of DeepSeek for extensibility and understandability (#6321)

This commit is contained in:
fzyzcjy
2025-05-20 11:14:48 +08:00
committed by GitHub
parent f0653886a5
commit 1b19df4b2a
2 changed files with 496 additions and 188 deletions

View File

@@ -18,8 +18,7 @@
import logging
import os
from dataclasses import dataclass
from enum import Enum, IntEnum, auto
from enum import IntEnum, auto
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
@@ -29,17 +28,17 @@ from tqdm import tqdm
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import (
LayerCommunicator,
LayerScatterModes,
enable_moe_dense_fully_dp,
)
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
attn_tp_reduce_scatter,
dp_gather_partial,
dp_scatter,
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
@@ -52,9 +51,8 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE, get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
@@ -72,7 +70,7 @@ from sglang.srt.layers.quantization.int8_utils import (
block_dequant as int8_block_dequant,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
@@ -141,6 +139,8 @@ class DeepseekV2MLP(nn.Module):
tp_size: Optional[int] = None,
) -> None:
super().__init__()
self.tp_size = tp_size
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
@@ -167,7 +167,10 @@ class DeepseekV2MLP(nn.Module):
)
self.act_fn = SiluAndMul()
def forward(self, x, forward_batch: Optional[ForwardBatch] = None):
def forward(self, x, forward_batch=None):
if (self.tp_size == 1) and x.shape[0] == 0:
return x
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
@@ -1097,19 +1100,6 @@ class DeepseekV2AttentionMLA(nn.Module):
return output
class _FFNInputMode(Enum):
# The MLP sublayer requires 1/tp_size tokens as input
SCATTERED = auto()
# The MLP sublayer requires all tokens as input
FULL = auto()
@dataclass
class _DecoderLayerInfo:
is_sparse: bool
ffn_input_mode: _FFNInputMode
class DeepseekV2DecoderLayer(nn.Module):
def __init__(
@@ -1123,14 +1113,12 @@ class DeepseekV2DecoderLayer(nn.Module):
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.config = config
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
self.layer_id = layer_id
self.local_dp_size = get_local_attention_dp_size()
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.self_attn = DeepseekV2AttentionMLA(
config=config,
hidden_size=self.hidden_size,
@@ -1152,19 +1140,24 @@ class DeepseekV2DecoderLayer(nn.Module):
alt_stream=alt_stream,
)
self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
previous_layer_info = self._compute_info(
config, layer_id=layer_id - 1, is_nextn=False
self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)
if self.info.is_sparse:
if self.is_layer_sparse:
self.mlp = DeepseekV2MoE(
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
else:
if self._enable_moe_dense_fully_dp():
if enable_moe_dense_fully_dp():
mlp_tp_rank, mlp_tp_size = 0, 1
else:
mlp_tp_rank, mlp_tp_size = None, None
@@ -1178,35 +1171,23 @@ class DeepseekV2DecoderLayer(nn.Module):
tp_size=mlp_tp_size,
)
self.input_is_scattered = (
layer_id > 0
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
)
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
@staticmethod
def _enable_moe_dense_fully_dp():
return global_server_args_dict["moe_dense_tp_size"] == 1
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
)
@staticmethod
def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
is_sparse = is_nextn or (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
and layer_id % config.moe_layer_freq == 0
def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
return is_nextn or (
self.config.n_routed_experts is not None
and layer_id >= self.config.first_k_dense_replace
and layer_id % self.config.moe_layer_freq == 0
)
ffn_input_mode = (
_FFNInputMode.SCATTERED
if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
else _FFNInputMode.FULL
)
return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
def forward(
self,
@@ -1216,114 +1197,10 @@ class DeepseekV2DecoderLayer(nn.Module):
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
return self.forward_ffn_with_scattered_input(
positions, hidden_states, forward_batch, residual, zero_allocator
)
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
return self.forward_ffn_with_full_input(
positions, hidden_states, forward_batch, residual, zero_allocator
)
else:
raise NotImplementedError
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
)
def forward_ffn_with_full_input(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
assert not (
self.attn_tp_size != 1 and self.input_is_scattered
), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
zero_allocator=zero_allocator,
)
# Gather
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
if self.local_dp_size != 1:
if self.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
hidden_states = self.post_attention_layernorm(hidden_states)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
else:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
# Fully Connected
hidden_states = self.mlp(hidden_states, forward_batch)
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
# Scatter
if self.local_dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
return hidden_states, residual
def forward_ffn_with_scattered_input(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
if self.attn_tp_size != 1 and self.input_is_scattered:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
@@ -1331,35 +1208,15 @@ class DeepseekV2DecoderLayer(nn.Module):
zero_allocator=zero_allocator,
)
if self.attn_tp_size != 1:
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
if not self.input_is_scattered:
residual = residual.tensor_split(self.attn_tp_size)[self.attn_tp_rank]
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
)
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
hidden_states = self.mlp(hidden_states, forward_batch)
if not (
self._enable_moe_dense_fully_dp()
and (not self.info.is_sparse)
and hidden_states.shape[0] == 0
):
hidden_states = self.mlp(hidden_states, forward_batch)
if self.is_last_layer and self.attn_tp_size != 1:
hidden_states += residual
residual = None
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
return hidden_states, residual