2025-07-28 11:37:17 -07:00
from __future__ import annotations
2024-12-06 15:05:21 +08:00
import logging
2025-10-07 21:51:41 -07:00
from typing import TYPE_CHECKING , Any , Dict , List , Optional , Union
2024-12-06 15:05:21 +08:00
import torch
2025-03-19 23:16:31 +08:00
2025-10-18 08:45:54 +08:00
from sglang . srt import single_batch_overlap
2025-10-17 18:57:54 -07:00
from sglang . srt . layers import deep_gemm_wrapper
2025-08-14 21:14:53 -07:00
from sglang . srt . layers . moe import (
get_deepep_mode ,
get_moe_a2a_backend ,
get_moe_runner_backend ,
should_use_flashinfer_trtllm_moe ,
)
2024-12-24 01:10:22 +08:00
from sglang . srt . layers . moe . ep_moe . kernels import (
2025-05-08 16:20:32 +08:00
ep_gather ,
ep_scatter ,
2025-04-02 00:23:25 +08:00
silu_and_mul_masked_post_quant_fwd ,
2025-05-08 16:20:32 +08:00
tma_align_input_scale ,
2024-12-06 15:05:21 +08:00
)
2025-08-04 03:10:02 -07:00
from sglang . srt . layers . moe . fused_moe_triton . layer import FlashInferFusedMoE , FusedMoE
2025-10-20 10:11:46 -07:00
from sglang . srt . layers . moe . topk import TopKOutput
2025-08-01 01:20:03 -07:00
from sglang . srt . layers . quantization . base_config import QuantizationConfig
2025-08-14 21:14:53 -07:00
from sglang . srt . layers . quantization . fp8 import Fp8Config
2025-05-30 17:26:30 -07:00
from sglang . srt . layers . quantization . fp8_kernel import (
2025-06-17 06:29:45 +08:00
is_fp8_fnuz ,
2025-06-14 11:41:03 +08:00
sglang_per_token_group_quant_fp8 ,
2025-05-30 17:26:30 -07:00
)
2025-10-15 11:10:53 +08:00
from sglang . srt . layers . quantization . w4afp8 import W4AFp8Config , W4AFp8MoEMethod
2025-10-02 18:04:36 +08:00
from sglang . srt . single_batch_overlap import DownGemmOverlapArgs
2025-10-07 21:51:41 -07:00
from sglang . srt . utils import ceil_div , dispose_tensor , get_bool_env_var , is_hip , is_npu
2025-10-09 16:46:15 -07:00
from sglang . srt . utils . offloader import get_offloader
2025-03-12 00:08:03 -07:00
2025-07-28 11:37:17 -07:00
if TYPE_CHECKING :
2025-08-01 01:20:03 -07:00
from sglang . srt . layers . moe . token_dispatcher import (
2025-07-28 11:37:17 -07:00
DeepEPLLOutput ,
DeepEPNormalOutput ,
DispatchOutput ,
)
2025-04-16 15:26:49 -07:00
# Platform / feature flags, resolved once at import time.
_is_hip = is_hip()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()

# AITER kernels are an explicit opt-in and only available on ROCm (HIP).
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip


if not _is_npu and not _is_hip:
    from sgl_kernel import silu_and_mul

if _use_aiter:
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe

logger = logging.getLogger(__name__)
2025-10-07 21:51:41 -07:00
class DeepEPMoE(FusedMoE):
    """
    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
    Mooncake EP shares the same class, as they expose the same interface.
    """

    _has_printed = False

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        activation: str = "silu",
        routed_scaling_factor: Optional[float] = None,
    ):
        super().__init__(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            layer_id=layer_id,
            num_fused_shared_experts=num_fused_shared_experts,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            activation=activation,
            routed_scaling_factor=routed_scaling_factor,
        )

        # Derive quantization flags from the concrete quant-config type.
        if isinstance(quant_config, Fp8Config):
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.use_fp8_w8a8 = True
            self.fp8_dtype = torch.float8_e4m3fn
            self.use_w4afp8 = False
        elif isinstance(quant_config, W4AFp8Config):
            self.use_w4afp8 = True
            self.use_fp8_w8a8 = False
            self.use_block_quant = False
        else:
            self.use_fp8_w8a8 = False
            self.use_block_quant = False
            # Fix: `use_w4afp8` was left unset on this branch, so
            # `run_moe_core` (which reads `self.use_w4afp8`) would raise
            # AttributeError for configs that are neither Fp8 nor W4AFp8.
            self.use_w4afp8 = False

        self.deepep_mode = get_deepep_mode()

        if self.deepep_mode.enable_low_latency() and not _is_npu:
            # NPU supports low_latency deepep without deepgemm
            assert (
                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
        if _use_aiter:
            # expert_mask is of size (self.num_local_experts + 1),
            # the extra 1 is for invalid rank_id (in original deepep, the invalid
            # rank_id is -1, but aiter does not allow -1, we use a mask to make
            # those ids invalid)
            # for instance, if we have 4 experts on this rank, we would have a
            # expert_mask like:
            #     self.expert_mask = [1, 1, 1, 1, 0]
            # idx from 0-3 is valid and will be processed, while idx == 4 will be
            # masked out
            self.expert_mask = torch.zeros(
                (self.num_local_experts + 1),
                device=torch.cuda.current_device(),
                dtype=torch.int,
            )
            # the last one is invalid rank_id
            self.expert_mask[:-1] = 1
        elif not _is_npu:
            # Pre-pack (weight, scale) pairs for the deepgemm paths; block-quant
            # and w4afp8 keep their scales under the `*_scale_inv` attributes.
            self.w13_weight_fp8 = (
                self.w13_weight,
                (
                    self.w13_weight_scale_inv
                    if self.use_block_quant or self.use_w4afp8
                    else self.w13_weight_scale
                ),
            )
            self.w2_weight_fp8 = (
                self.w2_weight,
                (
                    self.w2_weight_scale_inv
                    if self.use_block_quant or self.use_w4afp8
                    else self.w2_weight_scale
                ),
            )
2025-03-19 23:16:31 +08:00
def forward (
self ,
hidden_states : torch . Tensor ,
2025-10-20 10:11:46 -07:00
topk_output : TopKOutput ,
2025-10-18 08:45:54 +08:00
forward_shared_experts = None ,
alt_stream = None ,
2025-10-19 16:10:44 +08:00
disable_sbo = False ,
2025-03-19 23:16:31 +08:00
) :
2025-10-20 10:11:46 -07:00
2025-10-18 08:45:54 +08:00
# We have to call SBO inside MoE to be compatible with hooks used in offloading
return single_batch_overlap . execute_sbo (
hidden_states = hidden_states ,
2025-10-20 10:11:46 -07:00
topk_output = topk_output ,
2025-10-18 08:45:54 +08:00
# SBO args
experts = self ,
forward_shared_experts = forward_shared_experts ,
alt_stream = alt_stream ,
2025-10-19 16:10:44 +08:00
disable_sbo = disable_sbo ,
2025-07-28 11:37:17 -07:00
)
def dispatch (
self ,
hidden_states : torch . Tensor ,
2025-10-20 10:11:46 -07:00
topk_output : TopKOutput ,
2025-07-28 11:37:17 -07:00
) :
2025-10-20 10:11:46 -07:00
return self . dispatcher . dispatch (
2025-07-28 11:37:17 -07:00
hidden_states = hidden_states ,
2025-10-20 10:11:46 -07:00
topk_output = topk_output ,
2025-07-28 11:37:17 -07:00
)
2025-10-20 10:11:46 -07:00
    def run_moe_core(
        self,
        dispatch_output: DispatchOutput,
        down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
    ):
        """Run the expert computation for an already-dispatched batch.

        Backend selection order:
          1. AITER (ROCm) and NPU handle every DeepEP format themselves.
          2. DeepEP-normal format: w4afp8 cutlass, else contiguous deepgemm.
          3. DeepEP low-latency format: flashinfer-cutedsl, else masked deepgemm.

        `down_gemm_overlap_args` is only consumed by the flashinfer-cutedsl
        path; other paths ignore it.
        """
        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker

        if _use_aiter:
            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
            # in forward_aiter, we skip token permutation and unpermutation,
            # which have been fused inside aiter kernel
            return self.forward_aiter(dispatch_output)
        if _is_npu:
            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
            return self.forward_npu(dispatch_output)
        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
            if self.use_w4afp8:
                return self.forward_cutlass_w4afp8(dispatch_output)
            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_contiguous(dispatch_output)
        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
            if get_moe_runner_backend().is_flashinfer_cutedsl():
                return self.forward_flashinfer_cutedsl(
                    dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
                )
            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_masked(dispatch_output)
        else:
            raise ValueError(
                f"Dispatch output format {dispatch_output.format} is not supported"
            )
2025-07-28 11:37:17 -07:00
def combine (
self ,
hidden_states : torch . Tensor ,
2025-10-20 10:11:46 -07:00
topk_ids : torch . Tensor ,
2025-07-28 11:37:17 -07:00
topk_weights : torch . Tensor ,
2025-10-02 18:04:36 +08:00
overlap_args : Optional [ Dict [ str , Any ] ] = None ,
2025-07-28 11:37:17 -07:00
) :
2025-10-20 10:11:46 -07:00
return self . dispatcher . combine (
2025-07-28 11:37:17 -07:00
hidden_states = hidden_states ,
2025-10-20 10:11:46 -07:00
topk_ids = topk_ids ,
2025-07-28 11:37:17 -07:00
topk_weights = topk_weights ,
2025-10-02 18:04:36 +08:00
overlap_args = overlap_args ,
2025-07-28 11:37:17 -07:00
)
2025-06-24 17:05:47 +08:00
    def forward_aiter(
        self,
        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
    ):
        """Run experts with the ROCm AITER fused-MoE kernel.

        Token permutation/unpermutation is fused inside the aiter kernel, so
        the dispatch output tensors are consumed directly.
        """
        hidden_states, topk_ids, topk_weights = (
            dispatch_output.hidden_states,
            dispatch_output.topk_ids,
            dispatch_output.topk_weights,
        )
        # Nothing routed to this rank: return the (empty) tensor unchanged.
        if hidden_states.shape[0] == 0:
            return hidden_states
        # in original deepep, idx == -1 meaning invalid and will not be processed.
        # aiter does not accept -1, we use a expert mask to make these idx invalid
        # (idx == num_local_experts) meaning not used in aiter fused_moe
        # NOTE(review): this relies on `.to(torch.int32)` producing a copy; if
        # topk_ids already arrived as int32, `.to` would alias it and the
        # in-place masking below would mutate the dispatcher's tensor — confirm.
        topk_ids_copy = topk_ids.to(torch.int32)
        topk_ids_copy[topk_ids_copy == -1] = self.num_local_experts

        return fused_moe(
            hidden_states,
            self.w13_weight,
            self.w2_weight,
            topk_weights,
            topk_ids_copy,
            w1_scale=self.w13_weight_scale_inv,
            w2_scale=self.w2_weight_scale_inv,
            quant_type=QuantType.per_128x128,
            activation=(
                ActivationType.Silu
                if self.moe_runner_config.activation == "silu"
                else ActivationType.Gelu
            ),
            expert_mask=self.expert_mask,
        )
2025-05-08 16:20:32 +08:00
    def forward_deepgemm_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
    ):
        """DeepEP-normal path: scatter tokens per expert, run two contiguous
        grouped fp8 GEMMs (gate_up -> silu_and_mul -> down), then gather the
        weighted expert outputs back into token order.
        """
        (
            hidden_states,
            hidden_states_scale,
            topk_ids,
            topk_weights,
            num_recv_tokens_per_expert,
        ) = dispatch_output
        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"
        # No per-expert token counts / no tokens at all: nothing to compute.
        if num_recv_tokens_per_expert is None:
            return hidden_states.bfloat16()
        all_tokens = sum(num_recv_tokens_per_expert)
        if all_tokens <= 0:
            return hidden_states.bfloat16()
        M, K = hidden_states.size()
        N = self.w13_weight.size(1)
        scale_block_size = 128

        w13_weight_fp8 = (
            self.w13_weight,
            (
                self.w13_weight_scale_inv
                if self.use_block_quant
                else self.w13_weight_scale
            ),
        )
        w2_weight_fp8 = (
            self.w2_weight,
            (
                self.w2_weight_scale_inv
                if self.use_block_quant
                else self.w2_weight_scale
            ),
        )

        # Capture shape/device/dtype before `hidden_states` is disposed below.
        hidden_states_shape = hidden_states.shape
        hidden_states_device = hidden_states.device
        hidden_states_dtype = hidden_states.dtype

        # input_tensor[0]: per-expert-contiguous activations;
        # input_tensor[1]: their group-quant scales (layout depends on UE8M0).
        input_tensor = [
            torch.empty(
                (all_tokens, K),
                device=hidden_states.device,
                dtype=hidden_states.dtype,
            ),
            (
                # TODO check whether need `zeros`
                torch.zeros(
                    (ceil_div(K // 128, 4), all_tokens),
                    device=hidden_states.device,
                    dtype=torch.int,
                ).transpose(0, 1)
                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                else torch.empty(
                    (all_tokens, K // 128),
                    device=hidden_states.device,
                    dtype=torch.float32,
                )
            ),
        ]
        # Expert index of each scattered row, consumed by the grouped GEMM.
        m_indices = torch.empty(
            all_tokens, device=hidden_states.device, dtype=torch.int32
        )
        # Maps (token, k) slots back to scattered rows for the final gather.
        output_index = torch.empty_like(topk_ids)

        if get_offloader().forbid_copy_engine_usage:
            # Offloader owns the copy engine; use the kernel-based H2D copy.
            num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
                num_recv_tokens_per_expert
            )
        else:
            num_recv_tokens_per_expert_gpu = torch.tensor(
                num_recv_tokens_per_expert,
                dtype=torch.int32,
                pin_memory=True,
                device="cpu",
            ).cuda(non_blocking=True)
        expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

        ep_scatter(
            hidden_states,
            hidden_states_scale,
            topk_ids,
            num_recv_tokens_per_expert_gpu,
            expert_start_loc,
            input_tensor[0],
            input_tensor[1],
            m_indices,
            output_index,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        # Free the dispatched activations early to cap peak memory.
        dispose_tensor(hidden_states)

        gateup_output = torch.empty(
            (all_tokens, N),
            device=hidden_states_device,
            dtype=torch.bfloat16,
        )
        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            input_tensor[1] = tma_align_input_scale(input_tensor[1])
        # GroupGemm-0: gate_up projection.
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
            input_tensor, w13_weight_fp8, gateup_output, m_indices
        )
        del input_tensor
        down_input = torch.empty(
            (
                all_tokens,
                N // 2,
            ),
            device=gateup_output.device,
            dtype=torch.bfloat16,
        )
        silu_and_mul(gateup_output.view(-1, N), down_input)
        del gateup_output
        down_output = torch.empty(
            (all_tokens, K),
            device=hidden_states_device,
            dtype=torch.bfloat16,
        )
        # Re-quantize the activation to fp8 per 128-wide group for GroupGemm-1.
        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
            down_input,
            scale_block_size,
            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        del down_input
        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            down_input_scale = tma_align_input_scale(down_input_scale)
        # GroupGemm-1: down projection.
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
            (down_input_fp8, down_input_scale),
            w2_weight_fp8,
            down_output,
            m_indices,
        )
        del down_input_fp8, down_input_scale

        # Gather weighted expert outputs back to the original token layout.
        gather_out = torch.empty(
            hidden_states_shape,
            device=hidden_states_device,
            dtype=torch.bfloat16,
        )
        ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)

        return gather_out
2025-09-11 22:18:43 -05:00
def forward_flashinfer_cutedsl (
self ,
dispatch_output : DeepEPLLOutput ,
2025-10-02 18:04:36 +08:00
down_gemm_overlap_args : Optional [ DownGemmOverlapArgs ] ,
2025-09-11 22:18:43 -05:00
) :
2025-10-20 10:11:46 -07:00
hidden_states , hidden_states_scale , _ , _ , masked_m , _ = dispatch_output
2025-09-11 22:18:43 -05:00
assert self . quant_method is not None
assert self . moe_runner_config . activation == " silu "
output = self . quant_method . apply_without_routing_weights (
layer = self ,
2025-10-20 10:11:46 -07:00
x = ( hidden_states , hidden_states_scale ) ,
2025-09-11 22:18:43 -05:00
masked_m = masked_m ,
moe_runner_config = self . moe_runner_config ,
2025-10-02 18:04:36 +08:00
down_gemm_overlap_args = down_gemm_overlap_args ,
2025-09-11 22:18:43 -05:00
)
return output
2025-10-15 11:10:53 +08:00
def forward_cutlass_w4afp8 (
self ,
dispatch_output : DeepEPNormalOutput ,
) :
assert self . moe_runner_config . activation == " silu "
assert isinstance ( self . quant_method , W4AFp8MoEMethod )
return self . quant_method . apply_deepep_normal (
layer = self ,
dispatch_output = dispatch_output ,
)
2025-03-19 23:16:31 +08:00
    def forward_deepgemm_masked(
        self,
        dispatch_output: DeepEPLLOutput,
    ):
        """Low-latency path: two masked grouped fp8 GEMMs over per-expert
        groups (gate_up -> fused silu_and_mul + quant -> down).

        `masked_m` gives the valid row count per expert group; `expected_m`
        is a scheduling hint clamped to the group capacity `m`.
        """
        hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"
        assert (
            hidden_states_scale.dtype == torch.float32
        ), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}"

        # GroupGemm-0
        num_groups, m, k = hidden_states.size()
        n = self.w13_weight.size(1)
        expected_m = min(expected_m, m)
        gateup_output = torch.empty(
            (num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            (hidden_states, hidden_states_scale),
            self.w13_weight_fp8,
            gateup_output,
            masked_m,
            expected_m,
        )
        # Free the dispatched activations early to cap peak memory.
        dispose_tensor(hidden_states)

        # Act: fused SiLU-and-mul + per-group fp8 quantization of the result.
        down_input = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2,
            ),
            device=gateup_output.device,
            dtype=self.fp8_dtype,
        )
        scale_block_size = 128
        down_input_scale = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2 // scale_block_size,
            ),
            device=gateup_output.device,
            dtype=torch.float32,
        )
        silu_and_mul_masked_post_quant_fwd(
            gateup_output,
            down_input,
            down_input_scale,
            scale_block_size,
            masked_m,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        del gateup_output

        # GroupGemm-1
        n = self.w2_weight.size(1)
        down_input_fp8 = (
            down_input,
            (
                down_input_scale
                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                # Non-UE8M0 scales need TMA-aligned MN-major layout.
                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
            ),
        )
        down_output = torch.empty(
            (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            down_input_fp8,
            self.w2_weight_fp8,
            down_output,
            masked_m,
            expected_m,
        )
        return down_output
2025-05-17 10:06:03 +08:00
2025-08-09 16:35:00 +08:00
    def forward_npu(
        self,
        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
    ):
        """Ascend NPU path: run the experts with torch_npu grouped matmuls.

        Handles both DeepEP formats via the inner `_forward_normal` /
        `_forward_ll` helpers; each supports an unquantized (non-int8) and an
        int8 weight branch.
        """
        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"

        import torch_npu

        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker

        # NOTE: Ascend's Dispatch & Combine does not support FP16
        output_dtype = torch.bfloat16

        # group_list carries per-group token counts (torch_npu group_list_type=1).
        group_list_type = 1

        def _forward_normal(dispatch_output: DeepEPNormalOutput):
            if TYPE_CHECKING:
                assert isinstance(dispatch_output, DeepEPNormalOutput)
            hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
                dispatch_output
            )

            group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
                hidden_states.device
            )
            if self.w13_weight.dtype != torch.int8:
                # Unquantized weights: plain bf16 grouped matmuls.
                # gmm1: gate_up_proj
                hidden_states = torch_npu.npu_grouped_matmul(
                    x=[hidden_states],
                    weight=[self.w13_weight.permute(0, 2, 1)],
                    # per_token_scale=[hidden_states_scale],
                    split_item=2,
                    group_list_type=group_list_type,
                    group_type=0,
                    group_list=group_list,
                    output_dtype=output_dtype,
                )[0]
                hidden_states = torch_npu.npu_swiglu(hidden_states)
                # gmm2: down_proj
                hidden_states = torch_npu.npu_grouped_matmul(
                    x=[hidden_states],
                    weight=[self.w2_weight.permute(0, 2, 1)],
                    split_item=2,
                    group_list_type=group_list_type,
                    group_type=0,
                    group_list=group_list,
                    output_dtype=output_dtype,
                )[0]
            else:
                # int8 weights: dynamic per-token activation quantization.
                if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
                    # Dispatch delivered unquantized activations; quantize here.
                    hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant(
                        hidden_states
                    )
                # gmm1: gate_up_proj
                hidden_states = torch_npu.npu_grouped_matmul(
                    x=[hidden_states],
                    weight=[self.w13_weight],
                    scale=[self.w13_weight_scale.to(output_dtype)],
                    per_token_scale=[hidden_states_scale],
                    split_item=2,
                    group_list_type=group_list_type,
                    group_type=0,
                    group_list=group_list,
                    output_dtype=output_dtype,
                )[0]
                # act_fn: swiglu
                hidden_states = torch_npu.npu_swiglu(hidden_states)
                hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
                    hidden_states
                )

                # gmm2: down_proj
                hidden_states = torch_npu.npu_grouped_matmul(
                    x=[hidden_states],
                    weight=[self.w2_weight],
                    scale=[self.w2_weight_scale.to(output_dtype)],
                    per_token_scale=[swiglu_out_scale],
                    split_item=2,
                    group_list_type=group_list_type,
                    group_type=0,
                    group_list=group_list,
                    output_dtype=output_dtype,
                )[0]

            return hidden_states

        def _forward_ll(dispatch_output: DeepEPLLOutput):
            if TYPE_CHECKING:
                assert isinstance(dispatch_output, DeepEPLLOutput)
            (
                hidden_states,
                hidden_states_scale,
                topk_ids,
                topk_weights,
                group_list,
                _,
            ) = dispatch_output

            group_list = group_list.to(torch.int64)

            if self.w13_weight.dtype != torch.int8:
                # Unquantized weights: plain bf16 grouped matmuls.
                # gmm1: gate_up_proj
                hidden_states = torch_npu.npu_grouped_matmul(
                    x=[hidden_states],
                    weight=[self.w13_weight.permute(0, 2, 1)],
                    # per_token_scale=[hidden_states_scale],
                    split_item=2,
                    group_list_type=group_list_type,
                    group_type=0,
                    group_list=group_list,
                    output_dtype=output_dtype,
                )[0]
                hidden_states = torch_npu.npu_swiglu(hidden_states)
                # gmm2: down_proj
                hidden_states = torch_npu.npu_grouped_matmul(
                    x=[hidden_states],
                    weight=[self.w2_weight.permute(0, 2, 1)],
                    split_item=2,
                    group_list_type=group_list_type,
                    group_type=0,
                    group_list=group_list,
                    output_dtype=output_dtype,
                )[0]
            else:
                # int8 weights: int32 accumulation, then fused
                # dequant + swiglu + requant before the down projection.
                # gmm1: gate_up_proj
                hidden_states = torch_npu.npu_grouped_matmul(
                    x=[hidden_states],
                    weight=[self.w13_weight],
                    split_item=2,
                    group_list_type=group_list_type,
                    group_type=0,
                    group_list=group_list,
                    output_dtype=torch.int32,
                )[0]
                # act_fn: swiglu
                hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
                    x=hidden_states,
                    weight_scale=self.w13_weight_scale.to(torch.float32),
                    activation_scale=hidden_states_scale,
                    bias=None,
                    quant_scale=None,
                    quant_offset=None,
                    group_index=group_list,
                    activate_left=True,
                    quant_mode=1,
                )

                # gmm2: down_proj
                hidden_states = torch_npu.npu_grouped_matmul(
                    x=[hidden_states],
                    weight=[self.w2_weight],
                    scale=[self.w2_weight_scale.to(output_dtype)],
                    per_token_scale=[swiglu_out_scale],
                    split_item=2,
                    group_list_type=group_list_type,
                    group_type=0,
                    group_list=group_list,
                    output_dtype=output_dtype,
                )[0]

            return hidden_states

        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
            return _forward_normal(dispatch_output)
        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
            return _forward_ll(dispatch_output)
        else:
            raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
2025-08-09 16:35:00 +08:00
2025-05-17 10:06:03 +08:00
2025-09-14 19:16:25 -07:00
def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
    """Select the concrete MoE layer class for the current configuration.

    Selection order: DeepEP/Mooncake a2a backends -> FlashInfer TRTLLM with
    modelopt_fp4 quantization -> FlashInfer TRTLLM fused MoE -> FusedMoE
    (also the final fallback).
    """
    if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
        return DeepEPMoE

    # NEW: Direct FP4 detection (bypasses EP requirements)
    # Check for FP4 quantization with TRTLLM flag, regardless of EP
    if get_moe_runner_backend().is_flashinfer_trtllm():
        # FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
        # If UnquantizedFusedMoEMethod is detected, fall back to FusedMoE instead.
        if quant_config is None:
            return FusedMoE
        try:
            # Check the quantization argument directly
            # (quant_config is known non-None here thanks to the early return).
            if quant_config.get_name() == "modelopt_fp4":
                from sglang.srt.layers.moe.fused_moe_triton.layer import (
                    FlashInferFP4MoE,
                )

                return FlashInferFP4MoE
        except Exception:
            # Fix: previously a bare `except: pass`, which also swallowed
            # KeyboardInterrupt/SystemExit. Keep the best-effort fallback for
            # ordinary errors but record why FP4 detection failed.
            logger.debug(
                "modelopt_fp4 detection failed; falling back", exc_info=True
            )

    if should_use_flashinfer_trtllm_moe() and quant_config is not None:
        # FIXME: FlashInferFusedMoE only supports fp8 quant now
        return FlashInferFusedMoE
    if get_moe_runner_backend().is_flashinfer_cutlass():
        return FusedMoE
    return FusedMoE
2025-09-14 16:14:28 +08:00
def copy_list_to_gpu_no_ce(arr: List[int]):
    """Copy a Python int list to the GPU without using the copy engine.

    The host-to-device transfer goes through an sgl_kernel elementwise
    kernel, so it does not occupy the CUDA copy engine.
    """
    from sgl_kernel.elementwise import copy_to_gpu_no_ce

    host_tensor = torch.tensor(arr, dtype=torch.int32, device="cpu")
    device_tensor = torch.empty_like(host_tensor, device="cuda")
    copy_to_gpu_no_ce(host_tensor, device_tensor)
    return device_tensor