# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""

import logging
import os
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig

from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
    parallel_state,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.dp_attention import (
    dp_gather_partial,
    dp_scatter,
    get_attention_dp_size,
    get_attention_tp_rank,
    get_attention_tp_size,
    tp_all_gather,
    tp_reduce_scatter,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import (
    block_quant_to_tensor_quant,
    input_to_float8,
    normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip

_is_hip = is_hip()
_is_cuda = is_cuda()

if _is_cuda:
    from sgl_kernel import awq_dequantize, bmm_fp8
else:
    from vllm import _custom_ops as ops

if _is_hip:
    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
        decode_attention_fwd_grouped_rope,
    )

expert_distribution_recorder = ExpertDistributionRecorder()

logger = logging.getLogger(__name__)
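
# Parallelism conventions in this file: dense projections use tensor
# parallelism (TP); attention may additionally run under data-parallel
# attention with its own attn-TP rank/size; MoE layers select one of
# FusedMoE (pure TP), EPMoE (expert parallelism), or DeepEPMoE (expert
# parallelism driven by the DeepEP dispatcher) via global_server_args_dict.
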
class DeepseekV2MLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

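
# The router gate is a plain replicated linear layer that produces one logit
# per routed expert. For the "noaux_tc" top-k method (DeepSeek-V3 style), an
# extra per-expert correction bias is kept here; it is consumed by the top-k
# selection rather than being applied to the logits themselves.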
class MoEGate(nn.Module):
    def __init__(
        self,
        config,
        prefix: str = "",
    ):
        super().__init__()
        self.weight = nn.Parameter(
            torch.empty((config.n_routed_experts, config.hidden_size))
        )
        if config.topk_method == "noaux_tc":
            self.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts)
            )
        else:
            self.e_score_correction_bias = None

    def forward(self, hidden_states):
        logits = F.linear(hidden_states, self.weight, None)
        return logits

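
# Shared-experts fusion: when n_share_experts_fusion > 0, the shared expert
# is replicated into the routed expert pool as extra experts (see
# load_weights), the expert count grows by that amount, and top_k grows by
# one so that a fused copy can be selected alongside the routed experts.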
class DeepseekV2MoE(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        self.n_share_experts_fusion = (
            global_server_args_dict["n_share_experts_fusion"]
            if global_server_args_dict["n_share_experts_fusion"] is not None
            else 0
        )
        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )
        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

        MoEImpl = (
            DeepEPMoE
            if global_server_args_dict["enable_deepep_moe"]
            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
        )

        self.experts = MoEImpl(
            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            prefix=add_prefix("experts", prefix),
            **(
                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
                if global_server_args_dict["enable_deepep_moe"]
                else {}
            ),
        )

        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            # Disable tp for shared experts when deepep moe is enabled.
            if not global_server_args_dict["enable_deepep_moe"]:
                self.shared_experts = DeepseekV2MLP(
                    hidden_size=config.hidden_size,
                    intermediate_size=intermediate_size,
                    hidden_act=config.hidden_act,
                    quant_config=quant_config,
                    reduce_results=False,
                    prefix=add_prefix("shared_experts", prefix),
                )
            else:
                self.shared_experts = DeepseekV2MLP(
                    hidden_size=config.hidden_size,
                    intermediate_size=intermediate_size,
                    hidden_act=config.hidden_act,
                    quant_config=quant_config,
                    reduce_results=False,
                    prefix=add_prefix("shared_experts", prefix),
                    tp_rank=0,
                    tp_size=1,
                )

        if global_server_args_dict["enable_deepep_moe"]:
            # TODO: we will support tp < ep in the future
            self.ep_size = get_tensor_model_parallel_world_size()
            self.num_experts = config.n_routed_experts
            self.top_k = config.num_experts_per_tok
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
            self.correction_bias = (
                self.gate.e_score_correction_bias.data
                if self.gate.e_score_correction_bias is not None
                else None
            )
            self.deepep_dispatcher = DeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=config.n_routed_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
                async_finish=True,  # TODO
                return_recv_hook=True,
            )

    def forward(
        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
    ) -> torch.Tensor:
        if not global_server_args_dict["enable_deepep_moe"]:
            return self.forward_normal(hidden_states)
        else:
            return self.forward_deepep(hidden_states, forward_mode)

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
            shared_output = self.shared_experts(hidden_states)
        else:
            shared_output = None
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        final_hidden_states = (
            self.experts(hidden_states=hidden_states, router_logits=router_logits)
            * self.routed_scaling_factor
        )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
    ) -> torch.Tensor:
        shared_output = None
        topk_idx = torch.full(
            (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
        )
        topk_weights = torch.empty(
            (0, self.top_k), dtype=torch.float32, device=hidden_states.device
        )
        if (
            forward_mode is not None
            and not forward_mode.is_idle()
            and hidden_states.shape[0] > 0
        ):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            if self.n_shared_experts is not None:
                shared_output = self.shared_experts(hidden_states)
            topk_weights, topk_idx = select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                top_k=self.top_k,
                use_grouped_topk=True,
                renormalize=self.renormalize,
                topk_group=self.topk_group,
                num_expert_group=self.num_expert_group,
                correction_bias=self.correction_bias,
            )
        if self.ep_size > 1:
            (
                hidden_states,
                topk_idx,
                topk_weights,
                reorder_topk_ids,
                seg_indptr,
                masked_m,
                expected_m,
            ) = self.deepep_dispatcher.dispatch(
                hidden_states,
                topk_idx,
                topk_weights,
                self.num_experts,
                forward_mode=forward_mode,
            )
        final_hidden_states = (
            self.experts(
                hidden_states=hidden_states,
                reorder_topk_ids=reorder_topk_ids,
                seg_indptr=seg_indptr,
                masked_m=masked_m,
                expected_m=expected_m,
                forward_mode=forward_mode,
            )
            * self.routed_scaling_factor
        )
        if self.ep_size > 1:
            final_hidden_states = self.deepep_dispatcher.combine(
                final_hidden_states,
                topk_idx,
                topk_weights,
                forward_mode,
            )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        return final_hidden_states

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """YaRN magnitude scaling: 0.1 * mscale * ln(scale) + 1.0 for scale > 1."""
    import math

    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
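
# Example: yarn_get_mscale(40, 1.0) = 0.1 * 1.0 * ln(40) + 1.0 ≈ 1.369.
# (factor = 40 and mscale_all_dim = 1.0 match DeepSeek-V3's published YaRN
# config; other checkpoints may use different values.) The attention classes
# below multiply their softmax scale by mscale squared.
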
class DeepseekV2Attention(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id=None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.dp_size = get_attention_dp_size()
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_a_proj", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
            reduce_results=reduce_results,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        rope_scaling["rope_type"] = "deepseek_yarn"
        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )

        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # TODO, support head_size 192
        self.attn = RadixAttention(
            self.num_local_heads,
            256,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe
        # Pad q/k/v to head size 256 (see TODO above) before attention.
        q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view(
            -1, self.num_local_heads * 256
        )
        k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view(
            -1, self.num_local_heads * 256
        )
        v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
            -1, self.num_local_heads * 256
        )
        attn_output = self.attn(q, k, v, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, 256)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output

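
# MLA (multi-head latent attention) caches only the compressed KV latent
# (kv_lora_rank + qk_rope_head_dim per token). In the "absorbed" path below,
# kv_b_proj is split offline into w_kc and w_vc (see post_load_weights):
# w_kc is absorbed into the query so attention runs MQA-style over the
# latent, and w_vc maps the latent-space output back to per-head value dims.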
class DeepseekV2AttentionMLA(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: int = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.dp_size = get_attention_dp_size()
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_a_proj", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"
        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )

        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
        else:
            self.rotary_emb.forward = self.rotary_emb.forward_native

        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            prefix=add_prefix("attn_mqa", prefix),
        )

        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            prefix=add_prefix("attn_mha", prefix),
        )

        self.w_kc = None
        self.w_vc = None
        self.w_scale = None

        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.attention_backend = global_server_args_dict["attention_backend"]
        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
        if self.attention_backend == "flashinfer":
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            return (
                not self.flashinfer_mla_disable_ragged
                and forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            )
        elif self.attention_backend == "fa3":
            # Flash Attention: Keep absorbing for all extend/decode
            return False
        else:
            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
            return (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        if self.no_absorb(forward_batch):
            return self.forward_normal(positions, hidden_states, forward_batch)
        else:
            if _is_hip:
                if (
                    self.rocm_fused_decode_mla
                    and forward_batch.forward_mode.is_decode()
                ):
                    return self.forward_absorb_fused_mla_rope(
                        positions, hidden_states, forward_batch
                    )
                else:
                    return self.forward_absorb(positions, hidden_states, forward_batch)
            else:
                return self.forward_absorb(positions, hidden_states, forward_batch)

    def forward_normal(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv[..., : self.qk_nope_head_dim]
        v = kv[..., self.qk_nope_head_dim :]
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, self.kv_lora_rank :] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )
        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output
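
    # Weight absorption: q_nope is projected into the kv_lora_rank latent
    # space with w_kc, attention runs as MQA over the cached latent, and the
    # latent-space output is mapped back with w_vc. When the absorbed weights
    # are fp8, the bmm goes through bmm_fp8 (e4m3fn) or is upcast to bf16
    # (e4m3fnuz, pending a fused kernel).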
    def forward_absorb(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        q_len = hidden_states.shape[0]
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        if self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = input_to_float8(
                q_nope.transpose(0, 1), torch.float8_e4m3fn
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input
        k_pe = k_input[..., self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q_input[..., self.kv_lora_rank :] = q_pe
        k_input[..., self.kv_lora_rank :] = k_pe

        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        if self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = input_to_float8(
                attn_output.transpose(0, 1), torch.float8_e4m3fn
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)
        return output
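
    # ROCm-only decode path: runs the grouped MLA decode kernel with RoPE
    # optionally fused in. SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION (default "1")
    # selects whether RoPE is applied inside the kernel or beforehand in
    # Python.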
    def forward_absorb_fused_mla_rope(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        if self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = input_to_float8(
                q_nope.transpose(0, 1), torch.float8_e4m3fn
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input

        if not enable_rope_fusion:
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            q_input[..., self.kv_lora_rank :] = q_pe
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None
        else:
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])
            q_input[..., self.kv_lora_rank :] = q_pe

        # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
        # Use Fused ROPE with use_rope=OFF.
        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # save current latent cache.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            self.kv_lora_rank,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            k_input[..., self.kv_lora_rank :] = k_pe_output
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        if self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = input_to_float8(
                attn_output.transpose(0, 1), torch.float8_e4m3fn
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)

        output, _ = self.o_proj(attn_output)
        return output

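
# Layer layout: the first first_k_dense_replace layers use a dense MLP; after
# that, every moe_layer_freq-th layer is sparse (MoE). NextN layers are
# always built as MoE.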
class DeepseekV2DecoderLayer(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
    ) -> None:
        def is_sparse_layer(l: int):
            return (
                config.n_routed_experts is not None
                and l >= config.first_k_dense_replace
                and l % config.moe_layer_freq == 0
            )

        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
        self.layer_id = layer_id
        self.dp_size = get_attention_dp_size()
        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()

        if not global_server_args_dict["disable_mla"]:
            self.self_attn = DeepseekV2AttentionMLA(
                config=config,
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                qk_nope_head_dim=config.qk_nope_head_dim,
                qk_rope_head_dim=config.qk_rope_head_dim,
                v_head_dim=config.v_head_dim,
                q_lora_rank=(
                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
                ),
                kv_lora_rank=config.kv_lora_rank,
                rope_theta=rope_theta,
                rope_scaling=rope_scaling,
                max_position_embeddings=max_position_embeddings,
                quant_config=quant_config,
                layer_id=layer_id,
                reduce_results=False,
                prefix=add_prefix("self_attn", prefix),
            )
        else:
            self.self_attn = DeepseekV2Attention(
                config=config,
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                qk_nope_head_dim=config.qk_nope_head_dim,
                qk_rope_head_dim=config.qk_rope_head_dim,
                v_head_dim=config.v_head_dim,
                q_lora_rank=(
                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
                ),
                kv_lora_rank=config.kv_lora_rank,
                rope_theta=rope_theta,
                rope_scaling=rope_scaling,
                max_position_embeddings=max_position_embeddings,
                quant_config=quant_config,
                layer_id=layer_id,
                reduce_results=False,
                prefix=add_prefix("self_attn", prefix),
            )

        if is_nextn or is_sparse_layer(layer_id):
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
            self.is_sparse = True
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
            self.is_sparse = False

        self.input_is_scattered = (
            is_sparse_layer(layer_id - 1)
            and global_server_args_dict["enable_deepep_moe"]
        )
        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
            return self.forward_deepep(
                positions, hidden_states, forward_batch, residual
            )
        else:
            return self.forward_normal(
                positions, hidden_states, forward_batch, residual
            )

    def forward_normal(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)

        assert not (
            self.attn_tp_size != 1 and self.input_is_scattered
        ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Gather
        if get_tensor_model_parallel_world_size() > 1:
            # all gather and all reduce
            if self.dp_size != 1:
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                hidden_states, local_hidden_states = (
                    forward_batch.gathered_buffer,
                    hidden_states,
                )
                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
                dp_scatter(residual, hidden_states, forward_batch)
                hidden_states = self.post_attention_layernorm(hidden_states)
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
        else:
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual
            )

        # Fully Connected
        hidden_states = self.mlp(hidden_states)

        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
        # Scatter
        if self.dp_size != 1:
            # important: forward_batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            hidden_states, global_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            dp_scatter(hidden_states, global_hidden_states, forward_batch)

        return hidden_states, residual
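
    # DeepEP path: when the previous layer was sparse, its output arrives
    # scattered across attn-TP ranks, so it is all-gathered before attention
    # and reduce-scattered afterwards; the MoE then dispatches tokens to
    # expert-parallel ranks itself.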
    def forward_deepep(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)

        if self.attn_tp_size != 1 and self.input_is_scattered:
            hidden_states, local_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            tp_all_gather(
                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
            )

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        if self.attn_tp_size != 1:
            if self.input_is_scattered:
                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
                hidden_states = tensor_list[self.attn_tp_rank]
                tp_reduce_scatter(hidden_states, tensor_list)
                if hidden_states.shape[0] != 0:
                    hidden_states, residual = self.post_attention_layernorm(
                        hidden_states, residual
                    )
            else:
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
                hidden_states = tensor_list[self.attn_tp_rank]
                tp_reduce_scatter(hidden_states, tensor_list)
                residual = hidden_states
                if hidden_states.shape[0] != 0:
                    hidden_states = self.post_attention_layernorm(hidden_states)
        else:
            if hidden_states.shape[0] != 0:
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )

        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

        if self.is_last_layer and self.attn_tp_size != 1:
            hidden_states += residual
            residual = None
            hidden_states, local_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            tp_all_gather(
                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
            )

        return hidden_states, residual

class DeepseekV2Model(nn.Module):
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not global_server_args_dict["enable_dp_attention"],
        )
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
                    config,
                    layer_id,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.dp_size = get_attention_dp_size()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        residual = None
        for i in range(len(self.layers)):
            expert_distribution_recorder.set_current_layer(i)
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions, hidden_states, forward_batch, residual
            )
        if not forward_batch.forward_mode.is_idle():
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

class DeepseekV2ForCausalLM(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
        # Only DeepSeek V3/R1 can use the shared experts fusion optimization now.
        if (
            global_server_args_dict.get("disable_shared_experts_fusion", False)
            or self.config.architectures[0] != "DeepseekV3ForCausalLM"
            or self.config.n_routed_experts != 256
            or self.config.routed_scaling_factor != 2.5
        ):
            self.n_share_experts_fusion = None
            global_server_args_dict["n_share_experts_fusion"] = None
            logger.info(
                "Only DeepSeek V3/R1 can use shared experts fusion optimization. "
                "Shared experts fusion optimization is disabled."
            )
        elif self.n_share_experts_fusion is None:
            global_server_args_dict["n_share_experts_fusion"] = self.tp_size
            self.n_share_experts_fusion = self.tp_size
            logger.info(
                "Shared experts fusion optimization is enabled by default for "
                f"DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. "
                "You can tune it by setting --n_share_experts_fusion or disable it "
                "by setting --disable_shared_experts_fusion."
            )

        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
        )
        self.logits_processor = LogitsProcessor(config)
        self.dp_size = get_attention_dp_size()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )
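
    # After weights are loaded, kv_b_proj is split into w_kc / w_vc for the
    # absorbed MLA path. Quantized checkpoints are first brought to a usable
    # form: AWQ weights are dequantized, block-wise fp8 is requantized to a
    # per-tensor scale (bmm_fp8 only supports per-tensor scale), and int8 is
    # dequantized to bf16.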
    def post_load_weights(self):
        # Perform post-processing after loading weights
        if not global_server_args_dict["disable_mla"]:
            for layer_id in range(self.config.num_hidden_layers):
                self_attn = self.model.layers[layer_id].self_attn
                if hasattr(self_attn.kv_b_proj, "qweight"):
                    # AWQ compatible
                    if _is_cuda:
                        w = awq_dequantize(
                            self_attn.kv_b_proj.qweight,
                            self_attn.kv_b_proj.scales,
                            self_attn.kv_b_proj.qzeros,
                        ).T
                    else:
                        w = ops.awq_dequantize(
                            self_attn.kv_b_proj.qweight,
                            self_attn.kv_b_proj.scales,
                            self_attn.kv_b_proj.qzeros,
                            0,
                            0,
                            0,
                        ).T
                else:
                    w = self_attn.kv_b_proj.weight
                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
                # This may affect the accuracy of the fp8 model.
                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
                    torch.float8_e4m3fn,
                    torch.float8_e4m3fnuz,
                ):
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        if _is_hip:
                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                                weight=w,
                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                                input_scale=None,
                            )
                        else:
                            weight = w
                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
                        w, scale = block_quant_to_tensor_quant(
                            weight, weight_scale, weight_block_size
                        )
                        self_attn.w_scale = scale
                if w.dtype == torch.int8:
                    if hasattr(self.quant_config, "weight_block_size"):
                        # block-wise int8 needs it
                        weight_block_size = self.quant_config.weight_block_size
                        if weight_block_size is not None:
                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                            weight = w
                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
                            w = int8_block_dequant(
                                weight, weight_scale, weight_block_size
                            ).to(torch.bfloat16)
                    else:
                        # channel-wise int8 needs it
                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                            torch.bfloat16
                        )
                w_kc, w_vc = w.unflatten(
                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                if (
                    hasattr(self_attn.kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                    if _is_hip:
                        self_attn.w_scale *= 2.0

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
            weights_list = list(weights)
            weights_dict = dict(weights_list)
            suffix_list = [
                "down_proj.weight",
                "down_proj.weight_scale_inv",
                "gate_proj.weight",
                "gate_proj.weight_scale_inv",
                "up_proj.weight",
                "up_proj.weight_scale_inv",
            ]
            names_to_remove = []
            for moe_layer in tqdm(
                range(
                    self.config.first_k_dense_replace,
                    self.config.num_hidden_layers,
                    self.config.moe_layer_freq,
                ),
                desc=f"Cloning {self.n_share_experts_fusion} "
                "replicas of the shared expert into MoE",
            ):
                for num_repeat in range(self.n_share_experts_fusion):
                    for suffix in suffix_list:
                        shared_expert_weight_name = (
                            f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
                        )
                        weights_list.append(
                            (
                                f"model.layers.{moe_layer}."
                                f"mlp.experts."
                                f"{self.config.n_routed_experts + num_repeat}"
                                f".{suffix}",
                                weights_dict[shared_expert_weight_name].clone(),
                            )
                        )
                        names_to_remove += [shared_expert_weight_name]
            weights = [w for w in weights_list if w[0] not in names_to_remove]
2024-07-26 17:10:07 -07:00
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
2025-03-19 23:16:31 +08:00
MoEImpl = (
DeepEPMoE
if global_server_args_dict [ " enable_deepep_moe " ]
else ( EPMoE if global_server_args_dict [ " enable_ep_moe " ] else FusedMoE )
)
2024-12-06 15:05:21 +08:00
expert_params_mapping = MoEImpl . make_expert_params_mapping (
2024-07-26 17:10:07 -07:00
ckpt_gate_proj_name = " gate_proj " ,
ckpt_down_proj_name = " down_proj " ,
ckpt_up_proj_name = " up_proj " ,
2025-04-04 16:59:29 +08:00
num_experts = self . config . n_routed_experts
+ (
self . n_share_experts_fusion
if self . n_share_experts_fusion is not None
else 0
) ,
2024-07-26 17:10:07 -07:00
)
params_dict = dict ( self . named_parameters ( ) )
for name , loaded_weight in weights :
2024-12-26 00:02:14 +08:00
# TODO(HandH1998): Modify it when nextn is supported.
if hasattr ( self . config , " num_nextn_predict_layers " ) :
num_nextn_layers = self . config . num_nextn_predict_layers
if num_nextn_layers > 0 and name . startswith ( " model.layers " ) :
name_list = name . split ( " . " )
if (
len ( name_list ) > = 3
and int ( name_list [ 2 ] ) > = self . config . num_hidden_layers
) :
continue
2024-07-26 17:10:07 -07:00
if " rotary_emb.inv_freq " in name :
continue
for param_name , weight_name , shard_id in stacked_params_mapping :
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name :
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ( " mlp.experts. " in name ) and name not in params_dict :
continue
name = name . replace ( weight_name , param_name )
# Skip loading extra bias for GPTQ models.
if name . endswith ( " .bias " ) and name not in params_dict :
continue
param = params_dict [ name ]
weight_loader = param . weight_loader
weight_loader ( param , loaded_weight , shard_id )
break
else :
for mapping in expert_params_mapping :
param_name , weight_name , expert_id , shard_id = mapping
if weight_name not in name :
continue
name = name . replace ( weight_name , param_name )
param = params_dict [ name ]
weight_loader = param . weight_loader
weight_loader (
param ,
loaded_weight ,
2024-09-01 00:44:29 +10:00
name ,
2024-07-26 17:10:07 -07:00
shard_id = shard_id ,
expert_id = expert_id ,
)
break
else :
# Skip loading extra bias for GPTQ models.
if name . endswith ( " .bias " ) and name not in params_dict :
continue
param = params_dict [ name ]
weight_loader = getattr (
param , " weight_loader " , default_weight_loader
)
weight_loader ( param , loaded_weight )
2025-04-05 06:22:37 +08:00
self . post_load_weights ( )

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    pass


EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]