[Feature] Integrate DeepEP into SGLang (#4232)
Co-authored-by: Cheng Wan <cwan39@gatech.edu>
Co-authored-by: Xuting Zhou <xutingz@nvidia.com>
@@ -2,6 +2,13 @@ import logging
from typing import Callable, List, Optional, Tuple

import torch

# TODO: use deep_gemm masked kernel after low latency dispatch
# import deep_gemm
# from deep_gemm import (
#     get_col_major_tma_aligned_tensor,
#     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
# )
from torch.nn import Module

from sglang.srt.custom_op import CustomOp
@@ -25,6 +32,7 @@ from sglang.srt.layers.quantization.base_config import (
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs

_is_cuda = is_cuda()
@@ -39,6 +47,8 @@ logger = logging.getLogger(__name__)

_is_hip = is_hip()

_buffer = None


class GroupedGemmRunner(torch.nn.Module):
    flashinfer_gemm_warpper = None
@@ -773,3 +783,267 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        raise NotImplementedError


class DeepEPMoE(EPMoE):
    """
    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
    """

    _has_printed = False

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
        correction_bias: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        activation: str = "silu",
    ):
        super().__init__(
            num_experts,
            top_k,
            hidden_size,
            intermediate_size,
            params_dtype,
            renormalize,
            use_grouped_topk,
            num_expert_group,
            topk_group,
            quant_config,
            tp_size,
            prefix,
            correction_bias,
            custom_routing_function,
            activation,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
        forward_mode: ForwardMode,
    ):
        # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
        if True:  # not forward_mode.is_decode():
            return self.forward_normal(hidden_states, tokens_per_expert)
        else:
            return self.forward_deepgemm_masked(hidden_states, tokens_per_expert)

    def forward_normal(
        self,
        hidden_states: torch.Tensor,
        tokens_per_expert: torch.Tensor,
    ):
        assert self.quant_method is not None
        assert self.activation == "silu"
        if self.grouped_gemm_runner is None:
            self.grouped_gemm_runner = GroupedGemmRunner(
                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
            )
        seg_indptr_cur_rank = torch.cat(
            [
                torch.zeros(
                    1, device=tokens_per_expert.device, dtype=tokens_per_expert.dtype
                ),
                torch.cumsum(tokens_per_expert, dim=0),
            ]
        )
        reorder_topk_ids = torch.repeat_interleave(tokens_per_expert)
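        # Worked example (illustrative values, not from the commit): tokens_per_expert = [2, 0, 3]
        # gives seg_indptr_cur_rank = [0, 2, 2, 5] (row ranges per local expert) and
        # reorder_topk_ids = [0, 0, 2, 2, 2] (the owning local expert id of each row).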
        if self.activation_scheme == "dynamic" and not self.use_block_quant:
            max_value = (
                torch.max(hidden_states)
                .repeat(self.num_experts_per_partition)
                .to(torch.float32)
            )
            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
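            # Illustration (assuming self.fp8_dtype is torch.float8_e4m3fn, whose finfo max
            # is 448.0): the dynamic per-tensor scale maps the observed activation maximum
            # onto the representable fp8 range, broadcast to every local expert.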
        weight_indices_cur_rank = torch.arange(
            0,
            self.num_experts_per_partition,
            device=hidden_states.device,
            dtype=torch.int64,
        )

        # GroupGemm-0
        gateup_output = torch.empty(
            hidden_states.shape[0],
            self.w13_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        if hidden_states.shape[0] > 0:
            gateup_output = self.grouped_gemm_runner(
                a=hidden_states,
                b=self.w13_weight,
                c=gateup_output,
                batch_size=self.num_experts_per_partition,
                weight_column_major=True,
                seg_indptr=seg_indptr_cur_rank,
                weight_indices=weight_indices_cur_rank,
                use_fp8_w8a8=self.use_fp8_w8a8,
                scale_a=self.w13_input_scale,
                scale_b=(
                    self.w13_weight_scale_inv
                    if self.use_block_quant
                    else self.w13_weight_scale
                ),
                block_shape=self.block_shape,
            )

        # Act
        down_input = torch.empty(
            gateup_output.shape[0],
            gateup_output.shape[1] // 2,
            device=gateup_output.device,
            dtype=(
                self.fp8_dtype
                if (self.use_fp8_w8a8 and not self.use_block_quant)
                else hidden_states.dtype
            ),
        )
        if self.w2_input_scale is None and not self.use_block_quant:
            self.w2_input_scale = torch.ones(
                self.num_experts_per_partition,
                dtype=torch.float32,
                device=hidden_states.device,
            )

        if self.activation == "silu":
            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
                gateup_output,
                down_input,
                gateup_output.shape[1],
                reorder_topk_ids,
                self.w2_input_scale,
                0,
                self.num_experts_per_partition - 1,
                BLOCK_SIZE=512,
            )
        else:
            raise ValueError(f"Unsupported activation: {self.activation=}")
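        # Note (assumed kernel semantics, for readability): each gateup_output row is treated
        # as [gate | up] halves, and down_input receives silu(gate) * up, scaled by the
        # per-expert w2_input_scale when down_input is produced in fp8.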

        # GroupGemm-1
        down_output = torch.empty(
            down_input.shape[0],
            self.w2_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        if down_input.shape[0] > 0:
            down_output = self.grouped_gemm_runner(
                a=down_input,
                b=self.w2_weight,
                c=down_output,
                batch_size=self.num_experts_per_partition,
                weight_column_major=True,
                seg_indptr=seg_indptr_cur_rank,
                weight_indices=weight_indices_cur_rank,
                use_fp8_w8a8=self.use_fp8_w8a8,
                scale_a=self.w2_input_scale,
                scale_b=(
                    self.w2_weight_scale_inv
                    if self.use_block_quant
                    else self.w2_weight_scale
                ),
                block_shape=self.block_shape,
            )
        return down_output

    def forward_deepgemm_masked(
        self,
        hidden_states: torch.Tensor,
        reorder_topk_ids: torch.Tensor,
        seg_indptr: torch.Tensor,
    ):
        assert self.quant_method is not None
        assert self.activation == "silu"

        if self.activation_scheme == "dynamic" and not self.use_block_quant:
            max_value = (
                torch.max(hidden_states)
                .repeat(self.num_experts_per_partition)
                .to(torch.float32)
            )
            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max

        # GroupGemm-0
        gateup_output = torch.empty(
            hidden_states.shape[0],
            self.w13_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        if hidden_states.shape[0] > 0:
            # Transpose earlier so that the testing will not trigger transposing kernels
            hidden_states = (
                hidden_states[0],
                get_col_major_tma_aligned_tensor(hidden_states[1]),
            )
            """
            gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
                hidden_states, self.w13_weight, out, masked_m, expected_m
            )
            """

        # Act
        down_input = torch.empty(
            gateup_output.shape[0],
            gateup_output.shape[1] // 2,
            device=gateup_output.device,
            dtype=(
                self.fp8_dtype
                if (self.use_fp8_w8a8 and not self.use_block_quant)
                else hidden_states.dtype
            ),
        )
        if self.w2_input_scale is None and not self.use_block_quant:
            self.w2_input_scale = torch.ones(
                self.num_experts_per_partition,
                dtype=torch.float32,
                device=hidden_states.device,
            )

        if self.activation == "silu":
            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
                gateup_output,
                down_input,
                gateup_output.shape[1],
                reorder_topk_ids,
                self.w2_input_scale,
                0,
                self.num_experts_per_partition - 1,
                BLOCK_SIZE=512,
            )
        else:
            raise ValueError(f"Unsupported activation: {self.activation=}")

        # GroupGemm-1
        down_output = torch.empty(
            down_input.shape[0],
            self.w2_weight.shape[1],
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        if down_input.shape[0] > 0:
            # Transpose earlier so that the testing will not trigger transposing kernels
            down_input = (
                down_input[0],
                get_col_major_tma_aligned_tensor(down_input[1]),
            )
            """
            down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
                down_input, self.w2_weight, out, masked_m, expected_m
            )
            """

        return down_output
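
For readers of this diff, below is a minimal plain-PyTorch sketch (not part of the commit) of what the forward_normal grouped-GEMM path computes per local expert, ignoring fp8 quantization and the Triton/DeepGEMM kernels. The tensor shapes, the gate/up split convention, and the function name are assumptions made for illustration only.

import torch
import torch.nn.functional as F


def moe_forward_reference(
    hidden_states: torch.Tensor,      # [num_tokens, hidden], rows already sorted by expert
    tokens_per_expert: torch.Tensor,  # [num_local_experts], token count per local expert
    w13_weight: torch.Tensor,         # [num_local_experts, 2 * intermediate, hidden]
    w2_weight: torch.Tensor,          # [num_local_experts, hidden, intermediate]
) -> torch.Tensor:
    # Segment boundaries, analogous to seg_indptr_cur_rank above.
    seg_indptr = torch.cat(
        [
            torch.zeros(1, device=tokens_per_expert.device, dtype=tokens_per_expert.dtype),
            torch.cumsum(tokens_per_expert, dim=0),
        ]
    )
    out = torch.empty(
        hidden_states.shape[0], w2_weight.shape[1],
        device=hidden_states.device, dtype=hidden_states.dtype,
    )
    for e in range(tokens_per_expert.numel()):
        start, end = int(seg_indptr[e]), int(seg_indptr[e + 1])
        if start == end:
            continue  # no tokens routed to this expert on this rank
        x = hidden_states[start:end]
        gateup = x @ w13_weight[e].t()            # GroupGemm-0 for this segment
        gate, up = gateup.chunk(2, dim=-1)
        act = F.silu(gate) * up                   # what silu_and_mul computes
        out[start:end] = act @ w2_weight[e].t()   # GroupGemm-1 for this segment
    return out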