DeepEP normal support deepgemm-contiguous (#5626)

Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: Xuting Zhou <xutingz@nvidia.com>
Co-authored-by: ZhengHSI <zhenghsi@qq.com>
This commit is contained in:
@@ -4,11 +4,19 @@ from typing import Callable, List, Optional, Tuple
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
|
||||
try:
|
||||
from deep_gemm import (
|
||||
get_col_major_tma_aligned_tensor,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
||||
)
|
||||
from sgl_kernel import silu_and_mul
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
sglang_per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
use_deep_gemm = True
|
||||
except ImportError:
|
||||
@@ -20,6 +28,8 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
ep_gather,
|
||||
ep_scatter,
|
||||
gelu_and_mul_triton_kernel,
|
||||
grouped_gemm_triton,
|
||||
post_reorder_triton_kernel,
|
||||
@@ -27,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
run_moe_ep_preproess,
|
||||
silu_and_mul_masked_post_quant_fwd,
|
||||
silu_and_mul_triton_kernel,
|
||||
tma_align_input_scale,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
|
||||
@@ -842,15 +853,23 @@ class DeepEPMoE(EPMoE):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
reorder_topk_ids: torch.Tensor,
|
||||
seg_indptr: torch.Tensor,
|
||||
masked_m: torch.Tensor,
|
||||
expected_m: int,
|
||||
num_recv_tokens_per_expert: List[int],
|
||||
forward_mode: ForwardMode,
|
||||
):
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||
if resolved_deepep_mode == DeepEPMode.normal:
|
||||
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
||||
if _ENABLE_JIT_DEEPGEMM:
|
||||
return self.forward_deepgemm_contiguous(
|
||||
hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
|
||||
)
|
||||
else:
|
||||
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
||||
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
||||
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
|
||||
else:
|
||||
@@ -969,6 +988,106 @@ class DeepEPMoE(EPMoE):
|
||||
)
|
||||
return down_output
|
||||
|
||||
def forward_deepgemm_contiguous(
    self,
    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    num_recv_tokens_per_expert: List[int],
) -> torch.Tensor:
    """Expert MLP forward using DeepGEMM's contiguous grouped FP8 GEMM.

    Pipeline: scatter received tokens into per-expert contiguous buffers
    (``ep_scatter``), run the grouped gate/up GEMM, apply SiLU-and-mul,
    re-quantize to FP8, run the grouped down GEMM, then gather results
    back to the original token order weighted by ``topk_weights``
    (``ep_gather``).

    Args:
        hidden_states_fp8: Tuple of (FP8 activation tensor of shape (M, K),
            float32 per-group scale tensor). Unpacked on entry.
        topk_idx: Expert indices selected per token.
        topk_weights: Routing weights matching ``topk_idx``.
        num_recv_tokens_per_expert: Host-side token counts per local expert;
            ``None`` means nothing to do.

    Returns:
        bfloat16 tensor with the same shape as the FP8 input activations.
    """
    hidden_states_fp8, hidden_states_scale = hidden_states_fp8
    assert self.quant_method is not None
    # Only the SiLU activation path is implemented below (silu_and_mul).
    assert self.activation == "silu"
    if num_recv_tokens_per_expert is None:
        # No routing info: return the activations cast to bf16 unchanged.
        return hidden_states_fp8.bfloat16()
    all_tokens = sum(num_recv_tokens_per_expert)
    if all_tokens <= 0:
        # No tokens were dispatched to this rank's experts.
        return hidden_states_fp8.bfloat16()
    # M unused below; K is the hidden size, N the gate+up projection width.
    M, K = hidden_states_fp8.size()
    N = self.w13_weight.size(1)
    # FP8 quantization group size along the hidden dimension.
    scale_block_size = 128

    # Final output, restored to the original (pre-scatter) token order.
    gather_out = torch.empty_like(
        hidden_states_fp8,
        device=hidden_states_fp8.device,
        dtype=torch.bfloat16,
    )

    # Buffers that ep_scatter fills: [0] = FP8 activations laid out
    # contiguously per expert, [1] = their per-128-element group scales.
    input_tensor = [
        torch.empty(
            (all_tokens, K),
            device=hidden_states_fp8.device,
            dtype=hidden_states_fp8.dtype,
        ),
        torch.empty(
            (all_tokens, K // 128),
            device=hidden_states_fp8.device,
            dtype=torch.float32,
        ),
    ]
    # Per-row expert id, consumed by the grouped GEMM to select weights.
    m_indices = torch.empty(
        all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
    )
    # Maps each (token, topk) slot to its row in the scattered layout,
    # used later by ep_gather to undo the scatter.
    output_index = torch.empty_like(topk_idx)

    # Pinned host tensor enables a truly asynchronous H2D copy.
    num_recv_tokens_per_expert_gpu = torch.tensor(
        num_recv_tokens_per_expert,
        dtype=torch.int32,
        pin_memory=True,
        device="cpu",
    ).cuda(non_blocking=True)
    expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

    # Scatter tokens (and scales) into expert-contiguous order; fills
    # input_tensor, m_indices, expert_start_loc and output_index in place.
    ep_scatter(
        hidden_states_fp8,
        hidden_states_scale,
        topk_idx,
        num_recv_tokens_per_expert_gpu,
        expert_start_loc,
        input_tensor[0],
        input_tensor[1],
        m_indices,
        output_index,
    )

    gateup_output = torch.empty(
        (all_tokens, N),
        device=hidden_states_fp8.device,
        dtype=torch.bfloat16,
    )
    # Re-align scales before the grouped GEMM — presumably required by
    # DeepGEMM's TMA layout constraints (see tma_align_input_scale).
    input_tensor[1] = tma_align_input_scale(input_tensor[1])
    # Grouped gate/up projection: (all_tokens, K) x w13 -> (all_tokens, N).
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        input_tensor, self.w13_weight_fp8, gateup_output, m_indices
    )
    # SiLU(gate) * up halves the width from N to N // 2.
    down_input = torch.empty(
        (
            all_tokens,
            N // 2,
        ),
        device=gateup_output.device,
        dtype=torch.bfloat16,
    )
    silu_and_mul(gateup_output.view(-1, N), down_input)
    down_output = torch.empty(
        (all_tokens, K),
        device=hidden_states_fp8.device,
        dtype=torch.bfloat16,
    )
    # Re-quantize the activation to FP8 with per-128-group scales for the
    # down projection GEMM.
    down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
        down_input, scale_block_size
    )
    down_input_scale = tma_align_input_scale(down_input_scale)
    # Grouped down projection: (all_tokens, N // 2) x w2 -> (all_tokens, K).
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        (down_input_fp8, down_input_scale),
        self.w2_weight_fp8,
        down_output,
        m_indices,
    )

    # Undo the scatter: combine each token's expert outputs weighted by
    # topk_weights into gather_out (original token order).
    ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)

    return gather_out
|
||||
|
||||
def forward_deepgemm_masked(
|
||||
self,
|
||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
|
||||
|
||||
Reference in New Issue
Block a user