DeepEP normal support deepgemm-contiguous (#5626)

Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: Xuting Zhou <xutingz@nvidia.com>
Co-authored-by: ZhengHSI <zhenghsi@qq.com>
This commit is contained in:
lukec
2025-05-08 16:20:32 +08:00
committed by GitHub
parent a05bd83a94
commit acc816d8a2
6 changed files with 568 additions and 59 deletions

View File

@@ -4,11 +4,19 @@ from typing import Callable, List, Optional, Tuple
import torch
from torch.nn import Module
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
try:
from deep_gemm import (
get_col_major_tma_aligned_tensor,
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
)
from sgl_kernel import silu_and_mul
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
use_deep_gemm = True
except ImportError:
@@ -20,6 +28,8 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
gelu_and_mul_triton_kernel,
grouped_gemm_triton,
post_reorder_triton_kernel,
@@ -27,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
run_moe_ep_preproess,
silu_and_mul_masked_post_quant_fwd,
silu_and_mul_triton_kernel,
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
@@ -842,15 +853,23 @@ class DeepEPMoE(EPMoE):
def forward(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
num_recv_tokens_per_expert: List[int],
forward_mode: ForwardMode,
):
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
if resolved_deepep_mode == DeepEPMode.normal:
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
if _ENABLE_JIT_DEEPGEMM:
return self.forward_deepgemm_contiguous(
hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
)
else:
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
elif resolved_deepep_mode == DeepEPMode.low_latency:
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
else:
@@ -969,6 +988,106 @@ class DeepEPMoE(EPMoE):
)
return down_output
def forward_deepgemm_contiguous(
    self,
    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
    topk_idx,
    topk_weights,
    num_recv_tokens_per_expert: List[int],
):
    """Run the expert MLP with DeepGEMM's contiguous grouped FP8 GEMM.

    Pipeline: ``ep_scatter`` -> grouped GEMM with ``w13_weight_fp8`` ->
    ``silu_and_mul`` -> per-token-group FP8 quantization -> grouped GEMM
    with ``w2_weight_fp8`` -> ``ep_gather``.

    Args:
        hidden_states_fp8: ``(values, scales)`` pair. ``values`` must be
            2-D ``(M, K)``; presumably FP8 activations with one float32
            scale per 128-wide group -- TODO confirm against the dispatcher.
        topk_idx: Expert-selection tensor forwarded to ``ep_scatter`` and
            ``ep_gather``; ``output_index`` is allocated with its shape.
        topk_weights: Routing weights forwarded to ``ep_gather``.
        num_recv_tokens_per_expert: Per-expert token counts as a Python
            list, or ``None``.

    Returns:
        A bfloat16 tensor with the shape of ``values``. When there is
        nothing to compute (``num_recv_tokens_per_expert`` is ``None`` or
        sums to <= 0) the input values are simply cast to bfloat16.
    """
    hidden_states_fp8, hidden_states_scale = hidden_states_fp8
    assert self.quant_method is not None
    assert self.activation == "silu"

    # Nothing was routed to this rank: skip every kernel launch.
    if num_recv_tokens_per_expert is None:
        return hidden_states_fp8.bfloat16()
    all_tokens = sum(num_recv_tokens_per_expert)
    if all_tokens <= 0:
        return hidden_states_fp8.bfloat16()

    device = hidden_states_fp8.device
    # Two-name unpacking also rejects inputs that are not 2-D.
    _, K = hidden_states_fp8.size()
    N = self.w13_weight.size(1)
    # Single source of truth for the quantization group width; it sizes the
    # scattered scale buffer and drives the re-quantization further down.
    scale_block_size = 128

    gather_out = torch.empty_like(hidden_states_fp8, dtype=torch.bfloat16)
    scattered_fp8 = torch.empty(
        (all_tokens, K), device=device, dtype=hidden_states_fp8.dtype
    )
    scattered_scale = torch.empty(
        (all_tokens, K // scale_block_size), device=device, dtype=torch.float32
    )
    m_indices = torch.empty(all_tokens, device=device, dtype=torch.int32)
    output_index = torch.empty_like(topk_idx)

    # Stage the counts in pinned host memory so the host-to-device copy can
    # be issued asynchronously, and target the activations' own device
    # rather than whichever CUDA device happens to be current.
    num_recv_tokens_per_expert_gpu = torch.tensor(
        num_recv_tokens_per_expert,
        dtype=torch.int32,
        pin_memory=True,
        device="cpu",
    ).to(device, non_blocking=True)
    expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

    # NOTE(review): ep_scatter is expected to fill the scattered buffers,
    # m_indices, output_index and expert_start_loc -- confirm in ep_moe.kernels.
    ep_scatter(
        hidden_states_fp8,
        hidden_states_scale,
        topk_idx,
        num_recv_tokens_per_expert_gpu,
        expert_start_loc,
        scattered_fp8,
        scattered_scale,
        m_indices,
        output_index,
    )

    # First grouped GEMM, writing an (all_tokens, N) bfloat16 buffer.
    gateup_output = torch.empty(
        (all_tokens, N), device=device, dtype=torch.bfloat16
    )
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        (scattered_fp8, tma_align_input_scale(scattered_scale)),
        self.w13_weight_fp8,
        gateup_output,
        m_indices,
    )

    # The activation output buffer is half as wide as its input (N // 2).
    down_input = torch.empty(
        (all_tokens, N // 2), device=gateup_output.device, dtype=torch.bfloat16
    )
    silu_and_mul(gateup_output, down_input)

    # Re-quantize to FP8 for the second grouped GEMM.
    down_output = torch.empty(
        (all_tokens, K), device=device, dtype=torch.bfloat16
    )
    down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
        down_input, scale_block_size
    )
    down_input_scale = tma_align_input_scale(down_input_scale)
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        (down_input_fp8, down_input_scale),
        self.w2_weight_fp8,
        down_output,
        m_indices,
    )

    # NOTE(review): ep_gather is expected to write the final result into
    # gather_out using topk_weights and output_index -- confirm in kernels.
    ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
    return gather_out
def forward_deepgemm_masked(
self,
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],