Refactor DeepSeek decoder layer branches (#5205)
This commit is contained in:
@@ -18,7 +18,8 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from enum import IntEnum, auto
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum, IntEnum, auto
|
||||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -28,6 +29,7 @@ from tqdm import tqdm
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
parallel_state,
|
parallel_state,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
@@ -146,7 +148,7 @@ class DeepseekV2MLP(nn.Module):
|
|||||||
)
|
)
|
||||||
self.act_fn = SiluAndMul()
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, forward_mode: Optional[ForwardMode] = None):
|
||||||
gate_up, _ = self.gate_up_proj(x)
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
x = self.act_fn(gate_up)
|
x = self.act_fn(gate_up)
|
||||||
x, _ = self.down_proj(x)
|
x, _ = self.down_proj(x)
|
||||||
@@ -999,6 +1001,19 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class _FFNInputMode(Enum):
|
||||||
|
# The MLP sublayer requires 1/tp_size tokens as input
|
||||||
|
SCATTERED = auto()
|
||||||
|
# The MLP sublayer requires all tokens as input
|
||||||
|
FULL = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _DecoderLayerInfo:
    """Static, per-layer facts that drive branch selection in a decoder layer.

    Instances are produced by ``DeepseekV2DecoderLayer._compute_info`` for the
    layer itself and for the layer preceding it.
    """

    # True when the layer's FFN is the MoE block (``DeepseekV2MoE``) rather
    # than the dense ``DeepseekV2MLP``.
    is_sparse: bool
    # Whether the FFN sublayer consumes a scattered (1/tp_size) slice of the
    # tokens or the full set of tokens.
    ffn_input_mode: _FFNInputMode
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2DecoderLayer(nn.Module):
|
class DeepseekV2DecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -1009,14 +1024,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
is_nextn: bool = False,
|
is_nextn: bool = False,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
def is_sparse_layer(l: int):
|
|
||||||
return (
|
|
||||||
config.n_routed_experts is not None
|
|
||||||
and l >= config.first_k_dense_replace
|
|
||||||
and l % config.moe_layer_freq == 0
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
@@ -1047,13 +1054,17 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
prefix=add_prefix("self_attn", prefix),
|
prefix=add_prefix("self_attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_nextn or is_sparse_layer(layer_id):
|
self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
|
||||||
|
previous_layer_info = self._compute_info(
|
||||||
|
config, layer_id=layer_id - 1, is_nextn=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.info.is_sparse:
|
||||||
self.mlp = DeepseekV2MoE(
|
self.mlp = DeepseekV2MoE(
|
||||||
config=config,
|
config=config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("mlp", prefix),
|
prefix=add_prefix("mlp", prefix),
|
||||||
)
|
)
|
||||||
self.is_sparse = True
|
|
||||||
else:
|
else:
|
||||||
self.mlp = DeepseekV2MLP(
|
self.mlp = DeepseekV2MLP(
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
@@ -1062,11 +1073,9 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("mlp", prefix),
|
prefix=add_prefix("mlp", prefix),
|
||||||
)
|
)
|
||||||
self.is_sparse = False
|
|
||||||
|
|
||||||
self.input_is_scattered = (
|
self.input_is_scattered = (
|
||||||
is_sparse_layer(layer_id - 1)
|
previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
|
||||||
and global_server_args_dict["enable_deepep_moe"]
|
|
||||||
)
|
)
|
||||||
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
|
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
|
||||||
|
|
||||||
@@ -1075,6 +1084,20 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
config.hidden_size, eps=config.rms_norm_eps
|
config.hidden_size, eps=config.rms_norm_eps
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
def _compute_info(
    config: PretrainedConfig, layer_id: int, is_nextn: bool
) -> _DecoderLayerInfo:
    """Derive the sparse/dense flag and FFN input mode for one layer index.

    Args:
        config: Model config; reads ``n_routed_experts``,
            ``first_k_dense_replace`` and ``moe_layer_freq``.
        layer_id: Index of the layer to describe. The constructor also calls
            this with ``layer_id - 1`` to describe the preceding layer.
        is_nextn: When True the layer is treated as sparse regardless of the
            config-based rule.

    Returns:
        A ``_DecoderLayerInfo`` whose ``ffn_input_mode`` is ``SCATTERED`` only
        when the layer is sparse and the ``enable_deepep_moe`` server argument
        is set; otherwise ``FULL``.
    """
    # A layer is sparse (MoE) if forced by is_nextn, or if routed experts are
    # configured and the index passes both the dense-prefix and frequency rule.
    is_sparse = is_nextn or (
        config.n_routed_experts is not None
        and layer_id >= config.first_k_dense_replace
        and layer_id % config.moe_layer_freq == 0
    )
    ffn_input_mode = (
        _FFNInputMode.SCATTERED
        if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
        else _FFNInputMode.FULL
    )
    return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@@ -1082,16 +1105,18 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
|
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
|
||||||
return self.forward_deepep(
|
return self.forward_ffn_with_scattered_input(
|
||||||
|
positions, hidden_states, forward_batch, residual
|
||||||
|
)
|
||||||
|
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
|
||||||
|
return self.forward_ffn_with_full_input(
|
||||||
positions, hidden_states, forward_batch, residual
|
positions, hidden_states, forward_batch, residual
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.forward_normal(
|
raise NotImplementedError
|
||||||
positions, hidden_states, forward_batch, residual
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward_normal(
|
def forward_ffn_with_full_input(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -1158,7 +1183,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
def forward_deepep(
|
def forward_ffn_with_scattered_input(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -1214,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
hidden_states, residual = self.post_attention_layernorm(
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
hidden_states, residual
|
hidden_states, residual
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
|
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
|
||||||
|
|
||||||
if self.is_last_layer and self.attn_tp_size != 1:
|
if self.is_last_layer and self.attn_tp_size != 1:
|
||||||
|
|||||||
Reference in New Issue
Block a user