[Feature] Integrate DeepEP into SGLang (#4232)

Co-authored-by: Cheng Wan <cwan39@gatech.edu>
Co-authored-by: Xuting Zhou <xutingz@nvidia.com>
This commit is contained in:
Jinyan Chen
2025-03-19 23:16:31 +08:00
committed by GitHub
parent f9c53cbb42
commit f44db16c8e
12 changed files with 1228 additions and 35 deletions

View File

@@ -2,6 +2,13 @@ import logging
from typing import Callable, List, Optional, Tuple
import torch
# TODO: use deep_gemm masked kernel after low latency dispatch
# import deep_gemm
# from deep_gemm import (
# get_col_major_tma_aligned_tensor,
# m_grouped_gemm_fp8_fp8_bf16_nt_masked,
# )
from torch.nn import Module
from sglang.srt.custom_op import CustomOp
@@ -25,6 +32,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
_is_cuda = is_cuda()
@@ -39,6 +47,8 @@ logger = logging.getLogger(__name__)
_is_hip = is_hip()
_buffer = None
class GroupedGemmRunner(torch.nn.Module):
flashinfer_gemm_warpper = None
@@ -773,3 +783,267 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
raise NotImplementedError
class DeepEPMoE(EPMoE):
"""
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
"""
_has_printed = False
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
):
super().__init__(
num_experts,
top_k,
hidden_size,
intermediate_size,
params_dtype,
renormalize,
use_grouped_topk,
num_expert_group,
topk_group,
quant_config,
tp_size,
prefix,
correction_bias,
custom_routing_function,
activation,
)
def forward(
self,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
forward_mode: ForwardMode,
):
# Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
if True: # not forward_mode.is_decode():
return self.forward_normal(hidden_states, tokens_per_expert)
else:
return self.forward_deepgemm_masked(hidden_states, tokens_per_expert)
def forward_normal(
self,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
):
assert self.quant_method is not None
assert self.activation == "silu"
if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
)
seg_indptr_cur_rank = torch.cat(
[
torch.zeros(
1, device=tokens_per_expert.device, dtype=tokens_per_expert.dtype
),
torch.cumsum(tokens_per_expert, dim=0),
]
)
reorder_topk_ids = torch.repeat_interleave(tokens_per_expert)
if self.activation_scheme == "dynamic" and not self.use_block_quant:
max_value = (
torch.max(hidden_states)
.repeat(self.num_experts_per_partition)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
weight_indices_cur_rank = torch.arange(
0,
self.num_experts_per_partition,
device=hidden_states.device,
dtype=torch.int64,
)
# GroupGemm-0
gateup_output = torch.empty(
hidden_states.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if hidden_states.shape[0] > 0:
gateup_output = self.grouped_gemm_runner(
a=hidden_states,
b=self.w13_weight,
c=gateup_output,
batch_size=self.num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w13_input_scale,
scale_b=(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
block_shape=self.block_shape,
)
# Act
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states.dtype
),
)
if self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
dtype=torch.float32,
device=hidden_states.device,
)
if self.activation == "silu":
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
0,
self.num_experts_per_partition - 1,
BLOCK_SIZE=512,
)
else:
raise ValueError(f"Unsupported activation: {self.activation=}")
# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if down_input.shape[0] > 0:
down_output = self.grouped_gemm_runner(
a=down_input,
b=self.w2_weight,
c=down_output,
batch_size=self.num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w2_input_scale,
scale_b=(
self.w2_weight_scale_inv
if self.use_block_quant
else self.w2_weight_scale
),
block_shape=self.block_shape,
)
return down_output
def forward_deepgemm_masked(
self,
hidden_states: torch.Tensor,
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
):
assert self.quant_method is not None
assert self.activation == "silu"
if self.activation_scheme == "dynamic" and not self.use_block_quant:
max_value = (
torch.max(hidden_states)
.repeat(self.num_experts_per_partition)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
# GroupGemm-0
gateup_output = torch.empty(
hidden_states.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if hidden_states.shape[0] > 0:
# Transpose earlier so that the testing will not trigger transposing kernels
hidden_states = (
hidden_states[0],
get_col_major_tma_aligned_tensor(hidden_states[1]),
)
"""
gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
hidden_states, self.w13_weight, out, masked_m, expected_m
)
"""
# Act
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states.dtype
),
)
if self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
dtype=torch.float32,
device=hidden_states.device,
)
if self.activation == "silu":
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
0,
self.num_experts_per_partition - 1,
BLOCK_SIZE=512,
)
else:
raise ValueError(f"Unsupported activation: {self.activation=}")
# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if down_input.shape[0] > 0:
# Transpose earlier so that the testing will not trigger transposing kernels
down_input = (
down_input[0],
get_col_major_tma_aligned_tensor(down_input[1]),
)
"""
down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
down_input, self.w2_weight, out, masked_m, expected_m
)
"""
return down_output