From 8beb356f0daada03ac99ffbbe75181e4867533c3 Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Thu, 17 Apr 2025 17:11:11 +0800
Subject: [PATCH] Refactor DeepSeek decoder layer branches (#5205)

---
 python/sglang/srt/models/deepseek_v2.py | 70 +++++++++++++++++--------
 1 file changed, 48 insertions(+), 22 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 22831a310..ad9262d2c 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -18,7 +18,8 @@
 
 import logging
 import os
-from enum import IntEnum, auto
+from dataclasses import dataclass
+from enum import Enum, IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -28,6 +29,7 @@
 from tqdm import tqdm
 from transformers import PretrainedConfig
 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
@@ -146,7 +148,7 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -999,6 +1001,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
+class _FFNInputMode(Enum):
+    # The MLP sublayer requires 1/tp_size tokens as input
+    SCATTERED = auto()
+    # The MLP sublayer requires all tokens as input
+    FULL = auto()
+
+
+@dataclass
+class _DecoderLayerInfo:
+    is_sparse: bool
+    ffn_input_mode: _FFNInputMode
+
+
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
@@ -1009,14 +1024,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
-
-        def is_sparse_layer(l: int):
-            return (
-                config.n_routed_experts is not None
-                and l >= config.first_k_dense_replace
-                and l % config.moe_layer_freq == 0
-            )
-
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -1047,13 +1054,17 @@ class DeepseekV2DecoderLayer(nn.Module):
                 prefix=add_prefix("self_attn", prefix),
             )
 
-        if is_nextn or is_sparse_layer(layer_id):
+        self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
+        previous_layer_info = self._compute_info(
+            config, layer_id=layer_id - 1, is_nextn=False
+        )
+
+        if self.info.is_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = True
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
@@ -1062,11 +1073,9 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = False
 
         self.input_is_scattered = (
-            is_sparse_layer(layer_id - 1)
-            and global_server_args_dict["enable_deepep_moe"]
+            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
 
@@ -1075,6 +1084,20 @@
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+    @staticmethod
+    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
+        is_sparse = is_nextn or (
+            config.n_routed_experts is not None
+            and layer_id >= config.first_k_dense_replace
+            and layer_id % config.moe_layer_freq == 0
+        )
+        ffn_input_mode = (
+            _FFNInputMode.SCATTERED
+            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            else _FFNInputMode.FULL
+        )
+        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -1082,16 +1105,18 @@
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
-            return self.forward_deepep(
+        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+            return self.forward_ffn_with_scattered_input(
+                positions, hidden_states, forward_batch, residual
+            )
+        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+            return self.forward_ffn_with_full_input(
                 positions, hidden_states, forward_batch, residual
             )
         else:
-            return self.forward_normal(
-                positions, hidden_states, forward_batch, residual
-            )
+            raise NotImplementedError
 
-    def forward_normal(
+    def forward_ffn_with_full_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
@@ -1158,7 +1183,7 @@
 
         return hidden_states, residual
 
-    def forward_deepep(
+    def forward_ffn_with_scattered_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
@@ -1214,6 +1239,7 @@
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual
            )
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
 
         if self.is_last_layer and self.attn_tp_size != 1: