[Feature] Support DeepEP Low Latency (#4767)
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -3,12 +3,16 @@ from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
# TODO: use deep_gemm masked kernel after low latency dispatch
|
||||
# import deep_gemm
|
||||
# from deep_gemm import (
|
||||
# get_col_major_tma_aligned_tensor,
|
||||
# m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
||||
# )
|
||||
try:
|
||||
from deep_gemm import (
|
||||
get_col_major_tma_aligned_tensor,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
||||
)
|
||||
|
||||
use_deep_gemm = True
|
||||
except ImportError:
|
||||
use_deep_gemm = False
|
||||
|
||||
from torch.nn import Module
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
@@ -22,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
post_reorder_triton_kernel,
|
||||
pre_reorder_triton_kernel,
|
||||
run_moe_ep_preproess,
|
||||
silu_and_mul_masked_post_quant_fwd,
|
||||
silu_and_mul_triton_kernel,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||
@@ -809,6 +814,7 @@ class DeepEPMoE(EPMoE):
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
activation: str = "silu",
|
||||
deepep_mode: str = "auto",
|
||||
):
|
||||
super().__init__(
|
||||
num_experts,
|
||||
@@ -827,21 +833,41 @@ class DeepEPMoE(EPMoE):
|
||||
custom_routing_function,
|
||||
activation,
|
||||
)
|
||||
self.deepep_mode = deepep_mode
|
||||
if self.deepep_mode in ["low_latency", "auto"]:
|
||||
assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||
self.w13_weight_fp8 = (
|
||||
self.w13_weight,
|
||||
(
|
||||
self.w13_weight_scale_inv
|
||||
if self.use_block_quant
|
||||
else self.w13_weight_scale
|
||||
),
|
||||
)
|
||||
self.w2_weight_fp8 = (
|
||||
self.w2_weight,
|
||||
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
reorder_topk_ids: torch.Tensor,
|
||||
seg_indptr: torch.Tensor,
|
||||
masked_m: torch.Tensor,
|
||||
expected_m: int,
|
||||
forward_mode: ForwardMode,
|
||||
):
|
||||
# Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
|
||||
if True: # not forward_mode.is_decode():
|
||||
if self.deepep_mode == "normal" or (
|
||||
self.deepep_mode == "auto" and not forward_mode.is_decode()
|
||||
):
|
||||
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
||||
elif self.deepep_mode == "low_latency" or (
|
||||
self.deepep_mode == "auto" and forward_mode.is_decode()
|
||||
):
|
||||
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
|
||||
else:
|
||||
return self.forward_deepgemm_masked(
|
||||
hidden_states, reorder_topk_ids, seg_indptr
|
||||
)
|
||||
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
||||
|
||||
def forward_normal(
|
||||
self,
|
||||
@@ -958,89 +984,66 @@ class DeepEPMoE(EPMoE):
|
||||
|
||||
def forward_deepgemm_masked(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
reorder_topk_ids: torch.Tensor,
|
||||
seg_indptr: torch.Tensor,
|
||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
|
||||
masked_m: torch.Tensor,
|
||||
expected_m: int,
|
||||
):
|
||||
assert self.quant_method is not None
|
||||
assert self.activation == "silu"
|
||||
|
||||
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
||||
max_value = (
|
||||
torch.max(hidden_states)
|
||||
.repeat(self.num_experts_per_partition)
|
||||
.to(torch.float32)
|
||||
)
|
||||
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
|
||||
assert (
|
||||
hidden_states_fp8[0].size(0) % 4 == 0
|
||||
), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"
|
||||
|
||||
# GroupGemm-0
|
||||
num_groups, m, k = hidden_states_fp8[0].size()
|
||||
n = self.w13_weight.size(1)
|
||||
expected_m = min(expected_m, m)
|
||||
gateup_output = torch.empty(
|
||||
hidden_states.shape[0],
|
||||
self.w13_weight.shape[1],
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
(num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
|
||||
)
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
|
||||
)
|
||||
if hidden_states.shape[0] > 0:
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
hidden_states = (
|
||||
hidden_states[0],
|
||||
get_col_major_tma_aligned_tensor(hidden_states[1]),
|
||||
)
|
||||
"""
|
||||
gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
hidden_states, self.w13_weight, out, masked_m, expected_m
|
||||
)
|
||||
"""
|
||||
|
||||
# Act
|
||||
down_input = torch.empty(
|
||||
gateup_output.shape[0],
|
||||
gateup_output.shape[1] // 2,
|
||||
device=gateup_output.device,
|
||||
dtype=(
|
||||
self.fp8_dtype
|
||||
if (self.use_fp8_w8a8 and not self.use_block_quant)
|
||||
else hidden_states.dtype
|
||||
),
|
||||
)
|
||||
if self.w2_input_scale is None and not self.use_block_quant:
|
||||
self.w2_input_scale = torch.ones(
|
||||
self.num_experts_per_partition,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
if self.activation == "silu":
|
||||
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
||||
gateup_output,
|
||||
down_input,
|
||||
(
|
||||
gateup_output.shape[0],
|
||||
gateup_output.shape[1],
|
||||
reorder_topk_ids,
|
||||
self.w2_input_scale,
|
||||
0,
|
||||
self.num_experts_per_partition - 1,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {self.activation=}")
|
||||
gateup_output.shape[2] // 2,
|
||||
),
|
||||
device=gateup_output.device,
|
||||
dtype=self.fp8_dtype,
|
||||
)
|
||||
scale_block_size = 128
|
||||
down_input_scale = torch.empty(
|
||||
(
|
||||
gateup_output.shape[0],
|
||||
gateup_output.shape[1],
|
||||
gateup_output.shape[2] // 2 // scale_block_size,
|
||||
),
|
||||
device=gateup_output.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
silu_and_mul_masked_post_quant_fwd(
|
||||
gateup_output,
|
||||
down_input,
|
||||
down_input_scale,
|
||||
scale_block_size,
|
||||
masked_m,
|
||||
)
|
||||
|
||||
# GroupGemm-1
|
||||
down_output = torch.empty(
|
||||
down_input.shape[0],
|
||||
self.w2_weight.shape[1],
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
n = self.w2_weight.size(1)
|
||||
down_input_fp8 = (
|
||||
down_input,
|
||||
get_col_major_tma_aligned_tensor(down_input_scale),
|
||||
)
|
||||
down_output = torch.empty(
|
||||
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
|
||||
)
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
|
||||
)
|
||||
if down_input.shape[0] > 0:
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
down_input = (
|
||||
down_input[0],
|
||||
get_col_major_tma_aligned_tensor(down_input[1]),
|
||||
)
|
||||
"""
|
||||
down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
down_input, self.w2_weight, out, masked_m, expected_m
|
||||
)
|
||||
"""
|
||||
|
||||
return down_output
|
||||
|
||||
Reference in New Issue
Block a user