Refactor DeepSeek decoder layer branches (#5205)
This commit is contained in:
@@ -18,7 +18,8 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from enum import IntEnum, auto
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, IntEnum, auto
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -28,6 +29,7 @@ from tqdm import tqdm
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
parallel_state,
|
||||
tensor_model_parallel_all_reduce,
|
||||
@@ -146,7 +148,7 @@ class DeepseekV2MLP(nn.Module):
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, forward_mode: Optional[ForwardMode] = None):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
@@ -999,6 +1001,19 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
class _FFNInputMode(Enum):
|
||||
# The MLP sublayer requires 1/tp_size tokens as input
|
||||
SCATTERED = auto()
|
||||
# The MLP sublayer requires all tokens as input
|
||||
FULL = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DecoderLayerInfo:
|
||||
is_sparse: bool
|
||||
ffn_input_mode: _FFNInputMode
|
||||
|
||||
|
||||
class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -1009,14 +1024,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
is_nextn: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
|
||||
def is_sparse_layer(l: int):
|
||||
return (
|
||||
config.n_routed_experts is not None
|
||||
and l >= config.first_k_dense_replace
|
||||
and l % config.moe_layer_freq == 0
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@@ -1047,13 +1054,17 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
|
||||
if is_nextn or is_sparse_layer(layer_id):
|
||||
self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
|
||||
previous_layer_info = self._compute_info(
|
||||
config, layer_id=layer_id - 1, is_nextn=False
|
||||
)
|
||||
|
||||
if self.info.is_sparse:
|
||||
self.mlp = DeepseekV2MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.is_sparse = True
|
||||
else:
|
||||
self.mlp = DeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
@@ -1062,11 +1073,9 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.is_sparse = False
|
||||
|
||||
self.input_is_scattered = (
|
||||
is_sparse_layer(layer_id - 1)
|
||||
and global_server_args_dict["enable_deepep_moe"]
|
||||
previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
|
||||
)
|
||||
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
|
||||
|
||||
@@ -1075,6 +1084,20 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
|
||||
is_sparse = is_nextn or (
|
||||
config.n_routed_experts is not None
|
||||
and layer_id >= config.first_k_dense_replace
|
||||
and layer_id % config.moe_layer_freq == 0
|
||||
)
|
||||
ffn_input_mode = (
|
||||
_FFNInputMode.SCATTERED
|
||||
if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
|
||||
else _FFNInputMode.FULL
|
||||
)
|
||||
return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -1082,16 +1105,18 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
|
||||
return self.forward_deepep(
|
||||
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
|
||||
return self.forward_ffn_with_scattered_input(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
|
||||
return self.forward_ffn_with_full_input(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
else:
|
||||
return self.forward_normal(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
raise NotImplementedError
|
||||
|
||||
def forward_normal(
|
||||
def forward_ffn_with_full_input(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1158,7 +1183,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def forward_deepep(
|
||||
def forward_ffn_with_scattered_input(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1214,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
|
||||
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
|
||||
|
||||
if self.is_last_layer and self.attn_tp_size != 1:
|
||||
|
||||
Reference in New Issue
Block a user