# sglang/python/sglang/srt/layers/moe/ep_moe/layer.py
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
from sglang.srt import single_batch_overlap
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
silu_and_mul_masked_post_quant_fwd,
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
sglang_per_token_group_quant_fp8,
)
from sglang.srt.layers.quantization.modelopt_quant import (
CUTEDSL_MOE_NVFP4_DISPATCH,
ModelOptNvFp4FusedMoEMethod,
)
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
from sglang.srt.utils.offloader import get_offloader
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
DeepEPLLOutput,
DeepEPNormalOutput,
DispatchOutput,
)
_is_hip = is_hip()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if not (_is_npu or _is_hip):
from sgl_kernel import silu_and_mul
if _use_aiter:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
logger = logging.getLogger(__name__)
class DeepEPMoE(FusedMoE):
"""
    MoE expert-parallel implementation based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main).
    Mooncake EP reuses this class, as the two backends expose the same interface.
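
    A typical call sequence (illustrative sketch; argument names follow the
    methods defined below):

        dispatch_output = moe.dispatch(hidden_states, topk_idx, topk_weights, forward_batch)
        expert_out = moe.moe_impl(dispatch_output)
        combined = moe.combine(expert_out, topk_idx, topk_weights, forward_batch)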
"""
_has_printed = False
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
):
super().__init__(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
layer_id=layer_id,
num_fused_shared_experts=num_fused_shared_experts,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
activation=activation,
routed_scaling_factor=routed_scaling_factor,
)
if isinstance(quant_config, Fp8Config):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.use_fp8_w8a8 = True
self.fp8_dtype = torch.float8_e4m3fn
self.use_w4afp8 = False
elif isinstance(quant_config, W4AFp8Config):
self.use_w4afp8 = True
self.use_fp8_w8a8 = False
self.use_block_quant = False
        else:
            self.use_fp8_w8a8 = False
            self.use_block_quant = False
            self.use_w4afp8 = False
self.deepep_mode = get_deepep_mode()
# TODO: move to the beginning of the file
from sglang.srt.distributed.parallel_state import get_tp_group
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
group=get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=self.num_experts,
num_local_experts=self.num_local_experts,
hidden_size=hidden_size,
params_dtype=params_dtype,
deepep_mode=self.deepep_mode,
async_finish=True, # TODO
return_recv_hook=True,
)
if self.deepep_mode.enable_low_latency() and not _is_npu:
# NPU supports low_latency deepep without deepgemm
assert (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
if _use_aiter:
            # expert_mask has size (self.num_local_experts + 1); the extra slot
            # absorbs invalid expert ids. In the original DeepEP an invalid id
            # is -1, but aiter does not allow -1, so those ids are remapped to
            # the extra slot, which the mask marks as invalid.
            # For instance, with 4 experts on this rank:
            #   self.expert_mask = [1, 1, 1, 1, 0]
            # ids 0-3 are valid and processed, while id == 4 is masked out.
self.expert_mask = torch.zeros(
(self.num_local_experts + 1),
device=torch.cuda.current_device(),
dtype=torch.int,
)
            # mark all real experts valid; leave the invalid-id slot as 0
self.expert_mask[:-1] = 1
elif not _is_npu:
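            # Pre-pack (weight, scale) pairs in the tuple format expected by the
            # DeepGEMM grouped-GEMM wrappers used below.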
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant or self.use_w4afp8
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
(
self.w2_weight_scale_inv
if self.use_block_quant or self.use_w4afp8
else self.w2_weight_scale
),
)
def forward(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
forward_shared_experts=None,
alt_stream=None,
disable_sbo=False,
):
# We have to call SBO inside MoE to be compatible with hooks used in offloading
return single_batch_overlap.execute_sbo(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
# SBO args
experts=self,
forward_shared_experts=forward_shared_experts,
alt_stream=alt_stream,
disable_sbo=disable_sbo,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
):
return self.deepep_dispatcher.dispatch(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
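            # When CuteDSL NVFP4 dispatch is enabled, pass the input global scale
            # so activations can be quantized to NVFP4 during dispatch.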
input_global_scale=(
self.w13_input_scale_quant
if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
and self.quant_method.enable_flashinfer_cutedsl_moe
and CUTEDSL_MOE_NVFP4_DISPATCH
else None
),
)
def moe_impl(
self,
dispatch_output: DispatchOutput,
down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
):
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
if _use_aiter:
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
            # forward_aiter skips token permutation/unpermutation; both are fused inside the aiter kernel.
return self.forward_aiter(dispatch_output)
if _is_npu:
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
return self.forward_npu(dispatch_output)
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
if self.use_w4afp8:
return self.forward_cutlass_w4afp8(dispatch_output)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output)
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
if get_moe_runner_backend().is_flashinfer_cutedsl():
return self.forward_flashinfer_cutedsl(
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_masked(dispatch_output)
else:
raise ValueError(
f"Dispatch output format {dispatch_output.format} is not supported"
)
def combine(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
forward_batch: ForwardBatch,
overlap_args: Optional[Dict[str, Any]] = None,
):
return self.deepep_dispatcher.combine(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
forward_batch=forward_batch,
overlap_args=overlap_args,
)
def forward_aiter(
self,
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
):
hidden_states, topk_idx, topk_weights = (
dispatch_output.hidden_states,
dispatch_output.topk_idx,
dispatch_output.topk_weights,
)
if hidden_states.shape[0] == 0:
return hidden_states
        # In the original DeepEP, idx == -1 marks an invalid entry that is skipped.
        # aiter does not accept -1, so those entries are remapped to
        # idx == num_local_experts, which the expert mask marks as unused in fused_moe.
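        # e.g., with num_local_experts == 4: topk_idx [2, -1, 0] -> [2, 4, 0]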
topk_idx_copy = topk_idx.to(torch.int32)
topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts
return fused_moe(
hidden_states,
self.w13_weight,
self.w2_weight,
topk_weights,
topk_idx_copy,
w1_scale=self.w13_weight_scale_inv,
w2_scale=self.w2_weight_scale_inv,
quant_type=QuantType.per_128x128,
activation=(
ActivationType.Silu
if self.moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
expert_mask=self.expert_mask,
)
def forward_deepgemm_contiguous(
self,
dispatch_output: DeepEPNormalOutput,
):
hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
dispatch_output
)
hidden_states_fp8, hidden_states_scale = hidden_states_fp8
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
if num_recv_tokens_per_expert is None:
return hidden_states_fp8.bfloat16()
all_tokens = sum(num_recv_tokens_per_expert)
if all_tokens <= 0:
return hidden_states_fp8.bfloat16()
M, K = hidden_states_fp8.size()
N = self.w13_weight.size(1)
scale_block_size = 128
w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
w2_weight_fp8 = (
self.w2_weight,
(
self.w2_weight_scale_inv
if self.use_block_quant
else self.w2_weight_scale
),
)
hidden_states_fp8_shape = hidden_states_fp8.shape
hidden_states_fp8_device = hidden_states_fp8.device
hidden_states_fp8_dtype = hidden_states_fp8.dtype
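        # Scatter tokens into expert-contiguous order for the grouped GEMM:
        # input_tensor holds the permuted fp8 activations plus per-group scales,
        # m_indices maps each permuted row to its expert, and output_index
        # records where each (token, expert) pair lands so ep_gather can
        # invert the permutation after the down projection.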
input_tensor = [
torch.empty(
(all_tokens, K),
device=hidden_states_fp8.device,
dtype=hidden_states_fp8.dtype,
),
(
# TODO check whether need `zeros`
torch.zeros(
(ceil_div(K // 128, 4), all_tokens),
device=hidden_states_fp8.device,
dtype=torch.int,
).transpose(0, 1)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else torch.empty(
(all_tokens, K // 128),
device=hidden_states_fp8.device,
dtype=torch.float32,
)
),
]
m_indices = torch.empty(
all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
)
output_index = torch.empty_like(topk_idx)
if get_offloader().forbid_copy_engine_usage:
num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
num_recv_tokens_per_expert
)
else:
num_recv_tokens_per_expert_gpu = torch.tensor(
num_recv_tokens_per_expert,
dtype=torch.int32,
pin_memory=True,
device="cpu",
).cuda(non_blocking=True)
expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
ep_scatter(
hidden_states_fp8,
hidden_states_scale,
topk_idx,
num_recv_tokens_per_expert_gpu,
expert_start_loc,
input_tensor[0],
input_tensor[1],
m_indices,
output_index,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
dispose_tensor(hidden_states_fp8)
gateup_output = torch.empty(
(all_tokens, N),
device=hidden_states_fp8_device,
dtype=torch.bfloat16,
)
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
input_tensor[1] = tma_align_input_scale(input_tensor[1])
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
input_tensor, w13_weight_fp8, gateup_output, m_indices
)
del input_tensor
down_input = torch.empty(
(
all_tokens,
N // 2,
),
device=gateup_output.device,
dtype=torch.bfloat16,
)
silu_and_mul(gateup_output.view(-1, N), down_input)
del gateup_output
down_output = torch.empty(
(all_tokens, K),
device=hidden_states_fp8_device,
dtype=torch.bfloat16,
)
down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
down_input,
scale_block_size,
column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del down_input
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
down_input_scale = tma_align_input_scale(down_input_scale)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
(down_input_fp8, down_input_scale),
w2_weight_fp8,
down_output,
m_indices,
)
del down_input_fp8, down_input_scale
gather_out = torch.empty(
hidden_states_fp8_shape,
device=hidden_states_fp8_device,
dtype=torch.bfloat16,
)
ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
return gather_out
def forward_flashinfer_cutedsl(
self,
dispatch_output: DeepEPLLOutput,
down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
):
hidden_states, _, _, masked_m, _ = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
output = self.quant_method.apply_without_routing_weights(
layer=self,
x=hidden_states,
masked_m=masked_m,
moe_runner_config=self.moe_runner_config,
down_gemm_overlap_args=down_gemm_overlap_args,
)
return output
def forward_cutlass_w4afp8(
self,
dispatch_output: DeepEPNormalOutput,
):
assert self.moe_runner_config.activation == "silu"
assert isinstance(self.quant_method, W4AFp8MoEMethod)
return self.quant_method.apply_deepep_normal(
layer=self,
dispatch_output=dispatch_output,
)
def forward_deepgemm_masked(
self,
dispatch_output: DeepEPLLOutput,
):
hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
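        # Masked layout: each of the num_groups local experts owns m row slots,
        # of which only masked_m[g] are valid; expected_m is a scheduling hint
        # for the masked grouped GEMM (assumption based on the DeepGEMM API).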
# GroupGemm-0
num_groups, m, k = hidden_states_fp8[0].size()
n = self.w13_weight.size(1)
expected_m = min(expected_m, m)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
hidden_states_fp8,
self.w13_weight_fp8,
gateup_output,
masked_m,
expected_m,
)
dispose_tensor(hidden_states_fp8[0])
# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=gateup_output.device,
dtype=self.fp8_dtype,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=gateup_output.device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del gateup_output
# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
(
down_input_scale
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
),
)
down_output = torch.empty(
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
down_input_fp8,
self.w2_weight_fp8,
down_output,
masked_m,
expected_m,
)
return down_output
def forward_npu(
self,
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
):
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
import torch_npu
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
# NOTE: Ascend's Dispatch & Combine does not support FP16
output_dtype = torch.bfloat16
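        # group_list_type == 1: group_list is per-expert token counts rather than
        # cumulative offsets (assumed from the torch_npu grouped-matmul API).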
group_list_type = 1
def _forward_normal(dispatch_output: DeepEPNormalOutput):
if TYPE_CHECKING:
assert isinstance(dispatch_output, DeepEPNormalOutput)
hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
if isinstance(hidden_states, tuple):
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
hidden_states.device
)
if self.w13_weight.dtype != torch.int8:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight.permute(0, 2, 1)],
# per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight.permute(0, 2, 1)],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
else:
if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale.to(output_dtype)],
per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
def _forward_ll(dispatch_output: DeepEPLLOutput):
if TYPE_CHECKING:
assert isinstance(dispatch_output, DeepEPLLOutput)
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
if isinstance(hidden_states, tuple):
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
group_list = group_list.to(torch.int64)
if self.w13_weight.dtype != torch.int8:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight.permute(0, 2, 1)],
# per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight.permute(0, 2, 1)],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
else:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32,
)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=self.w13_weight_scale.to(torch.float32),
activation_scale=per_token_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
return _forward_normal(dispatch_output)
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
return _forward_ll(dispatch_output)
else:
raise ValueError(f"Not Supported DeepEP format {dispatch_output.format}")
def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
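    """Select the MoE layer class for the current A2A backend, runner backend, and quant config.

    Branch summary (mirrors the checks below): DeepEP/Mooncake A2A uses DeepEPMoE;
    the FlashInfer TRTLLM runner with modelopt_fp4 uses FlashInferFP4MoE, and with
    other quant configs FlashInferFusedMoE; everything else falls back to FusedMoE.
    """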
if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
return DeepEPMoE
    # Detect FP4 quantization directly when the TRTLLM runner flag is set,
    # regardless of EP settings.
    if get_moe_runner_backend().is_flashinfer_trtllm():
        # FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
        # If UnquantizedFusedMoEMethod is detected, fall back to FusedMoE instead.
        if quant_config is None:
            return FusedMoE
        try:
            # Check the quantization config directly (quant_config is non-None here).
            if quant_config.get_name() == "modelopt_fp4":
                from sglang.srt.layers.moe.fused_moe_triton.layer import (
                    FlashInferFP4MoE,
                )

                return FlashInferFP4MoE
        except Exception:
            pass
if should_use_flashinfer_trtllm_moe() and quant_config is not None:
# FIXME: FlashInferFusedMoE only supports fp8 quant now
return FlashInferFusedMoE
if get_moe_runner_backend().is_flashinfer_cutlass():
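        # FlashInfer CUTLASS currently runs through the standard FusedMoE layer.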
return FusedMoE
return FusedMoE
def copy_list_to_gpu_no_ce(arr: List[int]):
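    """Copy a small host int list to the GPU without using the copy engine.

    Used when get_offloader().forbid_copy_engine_usage is set (see
    forward_deepgemm_contiguous); the transfer is performed by the
    sgl_kernel copy_to_gpu_no_ce kernel rather than an async memcpy.
    """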
from sgl_kernel.elementwise import copy_to_gpu_no_ce
tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
return tensor_gpu