DeepEP normal support deepgemm-contiguous (#5626)
Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com> Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Co-authored-by: Xuting Zhou <xutingz@nvidia.com> Co-authored-by: ZhengHSI <zhenghsi@qq.com>
This commit is contained in:
@@ -5,16 +5,23 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||||
from sglang.srt.utils import is_cuda
|
from sglang.srt.utils import is_cuda
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
|
sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
try:
|
||||||
|
from deep_gemm import ceil_div
|
||||||
|
except ImportError:
|
||||||
|
logger.error(f"Failed to import ceil_div from deep_gemm.")
|
||||||
|
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -704,3 +711,334 @@ def grouped_gemm_triton(
|
|||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _fwd_kernel_ep_scatter_1(
    num_recv_tokens_per_expert,
    expert_start_loc,
    m_indices,
    num_experts: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    """Compute per-expert start rows and label each expert's rows in ``m_indices``.

    One program runs per expert (grid axis 0). Every program recomputes the
    full exclusive prefix sum of ``num_recv_tokens_per_expert`` and writes it
    to ``expert_start_loc``; it then fills its own expert's row range of
    ``m_indices`` with the expert id.

    Args:
        num_recv_tokens_per_expert: pointer to ``num_experts`` token counts.
        expert_start_loc: pointer receiving the exclusive prefix sum of the
            counts (first row of each expert).
        m_indices: pointer to the per-row expert-id buffer to fill.
        num_experts: number of valid entries in the count array.
        BLOCK_E: number of rows written per store.
        BLOCK_EXPERT_NUM: width of the vector used for the prefix sum; lanes
            at or beyond ``num_experts`` are masked out.
    """
    expert = tl.program_id(0)

    lanes = tl.arange(0, BLOCK_EXPERT_NUM)
    lane_is_expert = lanes < num_experts
    counts = tl.load(
        num_recv_tokens_per_expert + lanes,
        mask=lane_is_expert,
        other=0,
    )
    # Inclusive cumsum minus the own count gives the exclusive prefix sum.
    starts = tl.cumsum(counts) - counts
    tl.store(expert_start_loc + lanes, starts, mask=lane_is_expert)

    # Read back this expert's start row and its token count.
    first_row = tl.load(expert_start_loc + expert)
    row_count = tl.load(num_recv_tokens_per_expert + expert)

    tile = tl.arange(0, BLOCK_E)
    dst = m_indices + first_row + tile

    # The store is unmasked: each step writes a full BLOCK_E rows, so the
    # per-expert count is presumably a multiple of BLOCK_E -- confirm with
    # the caller's alignment guarantee.
    for row in tl.range(0, row_count, BLOCK_E, num_stages=4):
        tl.store(dst + row, expert)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _fwd_kernel_ep_scatter_2(
    total_token_num,
    expert_start_loc,
    recv_x,
    recv_x_stride0,
    recv_x_stride1,
    recv_x_scale,
    recv_x_scale_stride0,
    recv_x_scale_stride1,
    recv_topk,
    recv_topk_stride0,
    recv_topk_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    output_tensor_scale,
    output_tensor_scale_stride0,
    output_tensor_scale_stride1,
    output_index,
    output_index_stride0,
    output_index_stride1,
    topk_num: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    HIDDEN_SIZE_PAD: tl.constexpr,
    SCALE_HIDDEN_SIZE: tl.constexpr,
    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
):
    """Copy each received token (values and scales) into its experts' row ranges.

    Programs on grid axis 0 stride over the tokens. For every top-k slot whose
    expert id is non-negative, the program takes the next free row of that
    expert with an atomic increment of ``expert_start_loc``, records the row
    in ``output_index`` and copies the token's row and its scale row there.

    Only the ``*_stride0`` arguments are used; the ``*_stride1`` arguments are
    accepted but never read, so elements within a row are addressed with unit
    stride.

    Args:
        total_token_num: number of tokens (rows of ``recv_x``) to process.
        expert_start_loc: per-expert row cursor, advanced atomically.
        recv_x / recv_x_scale: source rows and their scale rows.
        recv_topk: per-token expert ids; negative entries are skipped.
        output_tensor / output_tensor_scale: destination buffers.
        output_index: receives the destination row for each (token, slot).
        topk_num: number of top-k slots per token.
        HIDDEN_SIZE / HIDDEN_SIZE_PAD: row width and the vector width used
            to cover it (lanes beyond the row width are masked).
        SCALE_HIDDEN_SIZE / SCALE_HIDDEN_SIZE_PAD: same, for the scale rows.
    """
    first_token = tl.program_id(0)
    token_step = tl.num_programs(0)

    cols = tl.arange(0, HIDDEN_SIZE_PAD)
    cols_valid = cols < HIDDEN_SIZE

    scale_cols = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
    scale_cols_valid = scale_cols < SCALE_HIDDEN_SIZE

    for token in range(first_token, total_token_num, token_step):
        # Load the token once; it may be written to several expert rows.
        row_values = tl.load(
            recv_x + token * recv_x_stride0 + cols, mask=cols_valid
        )
        row_scales = tl.load(
            recv_x_scale + token * recv_x_scale_stride0 + scale_cols,
            mask=scale_cols_valid,
        )

        for slot in tl.range(0, topk_num, 1, num_stages=4):
            expert = tl.load(recv_topk + token * recv_topk_stride0 + slot)
            if expert >= 0:
                # The value returned by the atomic add is used as this
                # token's destination row inside the expert's range.
                dst_row = tl.atomic_add(expert_start_loc + expert, 1)
                tl.store(
                    output_index + token * output_index_stride0 + slot,
                    dst_row,
                )
                dst_values = output_tensor + dst_row * output_tensor_stride0
                dst_scales = (
                    output_tensor_scale + dst_row * output_tensor_scale_stride0
                )
                tl.store(dst_values + cols, row_values, mask=cols_valid)
                tl.store(
                    dst_scales + scale_cols, row_scales, mask=scale_cols_valid
                )
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
@torch.no_grad()
def ep_scatter(
    recv_x: torch.Tensor,
    recv_x_scale: torch.Tensor,
    recv_topk: torch.Tensor,
    num_recv_tokens_per_expert: torch.Tensor,
    expert_start_loc: torch.Tensor,
    output_tensor: torch.Tensor,
    output_tensor_scale: torch.Tensor,
    m_indices: torch.Tensor,
    output_index: torch.Tensor,
):
    """Group received tokens by expert into contiguous per-expert row ranges.

    Launches two Triton kernels. The first fills ``expert_start_loc`` with the
    starting row of each expert and writes the owning expert id of every row
    into ``m_indices``. The second copies each token (``recv_x`` and
    ``recv_x_scale``) into one row per selected expert and stores the chosen
    destination rows in ``output_index``.

    All results are written in place; nothing is returned.

    Args:
        recv_x: token values; ``shape[1]`` is taken as the hidden size.
        recv_x_scale: scales for ``recv_x`` (one per 128 hidden elements is
            assumed by the launch parameters -- confirm with the quantizer).
        recv_topk: expert ids per token; ``shape[0]`` is the token count and
            ``shape[1]`` the number of top-k slots.
        num_recv_tokens_per_expert: token count per expert; its length is
            the number of experts.
        expert_start_loc: output buffer for per-expert start rows. The second
            kernel advances it atomically, so it does not hold the start rows
            afterwards.
        output_tensor: destination for the grouped token values.
        output_tensor_scale: destination for the grouped scales.
        m_indices: output buffer of per-row expert ids; its length must be a
            multiple of 128.
        output_index: output buffer of destination rows per (token, slot).
    """
    rows_per_store = 128  # per-expert token counts are aligned to 128
    quant_group_size = 128  # hidden elements covered by one scale
    warps = 8
    expert_count = num_recv_tokens_per_expert.shape[0]
    hidden_size = recv_x.shape[1]

    assert m_indices.shape[0] % rows_per_store == 0

    # Kernel 1: one program per expert.
    _fwd_kernel_ep_scatter_1[(expert_count,)](
        num_recv_tokens_per_expert,
        expert_start_loc,
        m_indices,
        num_experts=expert_count,
        num_warps=warps,
        BLOCK_E=rows_per_store,
        BLOCK_EXPERT_NUM=triton.next_power_of_2(expert_count),
    )

    # Kernel 2: programs stride over tokens; the grid is capped at 8192.
    token_count = recv_topk.shape[0]
    token_programs = min(token_count, 1024 * 8)
    scale_width = hidden_size // quant_group_size

    _fwd_kernel_ep_scatter_2[(token_programs,)](
        token_count,
        expert_start_loc,
        recv_x,
        recv_x.stride(0),
        recv_x.stride(1),
        recv_x_scale,
        recv_x_scale.stride(0),
        recv_x_scale.stride(1),
        recv_topk,
        recv_topk.stride(0),
        recv_topk.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        output_tensor_scale,
        output_tensor_scale.stride(0),
        output_tensor_scale.stride(1),
        output_index,
        output_index.stride(0),
        output_index.stride(1),
        topk_num=recv_topk.shape[1],
        num_warps=warps,
        HIDDEN_SIZE=hidden_size,
        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
        SCALE_HIDDEN_SIZE=scale_width,
        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_width),
    )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _fwd_kernel_ep_gather(
    total_token_num,
    input_tensor,
    input_tensor_stride0,
    input_tensor_stride1,
    recv_topk_ids,
    recv_topk_ids_stride0,
    recv_topk_ids_stride1,
    recv_topk_weight,
    recv_topk_weight_stride0,
    recv_topk_weight_stride1,
    input_index,
    input_index_stride0,
    input_index_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    topk_num: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Combine per-expert output rows back into one weighted row per token.

    Grid axis 0 selects a ``BLOCK_D``-wide column tile; programs on grid
    axis 1 stride over the tokens. For each token the kernel sums, in
    float32, ``weight * input_tensor[row]`` over every top-k slot whose expert
    id is non-negative, where ``row`` comes from ``input_index``, and stores
    the sum cast to the element type of ``output_tensor``.

    Loads and stores are unmasked, and only the ``*_stride0`` arguments are
    used (``*_stride1`` are accepted but never read).

    Args:
        total_token_num: number of output rows to produce.
        input_tensor: per-expert rows to read from.
        recv_topk_ids: expert id per (token, slot); negative means skip.
        recv_topk_weight: combine weight per (token, slot).
        input_index: source row in ``input_tensor`` per (token, slot).
        output_tensor: destination, one row per token.
        topk_num: number of top-k slots per token.
        BLOCK_D: number of columns handled by one program.
    """
    col_tile = tl.program_id(0)
    first_token = tl.program_id(1)
    token_step = tl.num_programs(1)

    # Column offsets of this program's tile; identical for every token.
    cols = col_tile * BLOCK_D + tl.arange(0, BLOCK_D)

    for token in range(first_token, total_token_num, token_step):
        acc = tl.zeros([BLOCK_D], dtype=tl.float32)
        for slot in range(0, topk_num):
            expert = tl.load(
                recv_topk_ids + token * recv_topk_ids_stride0 + slot
            )
            if expert >= 0:
                src_row = tl.load(
                    input_index + token * input_index_stride0 + slot
                )
                weight = tl.load(
                    recv_topk_weight + token * recv_topk_weight_stride0 + slot
                )
                values = tl.load(
                    input_tensor + src_row * input_tensor_stride0 + cols
                )
                acc += values.to(tl.float32) * weight

        tl.store(
            output_tensor + token * output_tensor_stride0 + cols,
            acc.to(output_tensor.dtype.element_ty),
        )
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def ep_gather(
    input_tensor: torch.Tensor,
    recv_topk_ids: torch.Tensor,
    recv_topk_weight: torch.Tensor,
    input_index: torch.Tensor,
    output_tensor: torch.Tensor,
):
    """Write the weighted per-token combination of expert rows into ``output_tensor``.

    Host-side launcher for ``_fwd_kernel_ep_gather``. The result is written
    in place and nothing is returned.

    Args:
        input_tensor: per-expert output rows; ``shape[1]`` is the hidden size
            and must be a multiple of 1024.
        recv_topk_ids: expert id per (token, slot); ``shape[1]`` is the number
            of top-k slots.
        recv_topk_weight: combine weight per (token, slot).
        input_index: row of ``input_tensor`` per (token, slot).
        output_tensor: destination; ``shape[0]`` is the number of tokens.
    """
    cols_per_program = 1024  # hidden-dimension tile handled by one program
    warps = 2
    token_count = output_tensor.shape[0]
    hidden_size = input_tensor.shape[1]
    assert hidden_size % cols_per_program == 0

    # Axis 0: column tiles. Axis 1: token programs, capped at 1024; the
    # kernel strides over the remaining tokens.
    launch_grid = (
        triton.cdiv(hidden_size, cols_per_program),
        min(token_count, 1024),
    )
    _fwd_kernel_ep_gather[launch_grid](
        token_count,
        input_tensor,
        input_tensor.stride(0),
        input_tensor.stride(1),
        recv_topk_ids,
        recv_topk_ids.stride(0),
        recv_topk_ids.stride(1),
        recv_topk_weight,
        recv_topk_weight.stride(0),
        recv_topk_weight.stride(1),
        input_index,
        input_index.stride(0),
        input_index.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        topk_num=recv_topk_ids.shape[1],
        num_warps=warps,
        BLOCK_D=cols_per_program,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from
# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """
    Round ``x`` up so that ``x * element_size`` is a multiple of 16 bytes.

    Global memory address of TMA must be 16-byte aligned.
    Since we use column-major layout for the LHS scaling tensor,
    the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.

    Arguments:
        x: original M-axis shape of the LHS scaling tensor.
        element_size: element size of the LHS scaling tensor, in bytes.
            Must divide 16.

    Returns:
        M-axis shape of the LHS scaling tensor after padding.
    """
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    # Plain integer round-up. This keeps the helper usable when the optional
    # `deep_gemm` import (the former source of `ceil_div`) is unavailable.
    return (x + alignment - 1) // alignment * alignment
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _tma_align_input_scale_kernel(
    input_scale_ptr,
    output_ptr,
    m,
    k_div_block_size,
    input_scale_stride_m,
    input_scale_stride_k,
    output_stride_m,
    output_stride_k,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Copy an ``m`` x ``k_div_block_size`` scale matrix into a buffer with other strides.

    Programs on grid axis 0 stride over the rows. Each step copies one full
    row (all ``k_div_block_size`` entries, masked to that width) from the
    input layout to the output layout described by the stride arguments.

    Args:
        input_scale_ptr: source scale matrix.
        output_ptr: destination buffer.
        m: number of rows to copy.
        k_div_block_size: number of valid entries per row.
        input_scale_stride_m / input_scale_stride_k: source strides.
        output_stride_m / output_stride_k: destination strides.
        BLOCK_SIZE_K: vector width used to cover one row.
    """
    first_row = tl.program_id(axis=0)
    row_step = tl.num_programs(0)

    k_lanes = tl.arange(0, BLOCK_SIZE_K)
    k_valid = k_lanes < k_div_block_size

    for row in range(first_row, m, row_step):
        src = (
            input_scale_ptr
            + row * input_scale_stride_m
            + k_lanes * input_scale_stride_k
        )
        scales = tl.load(src, mask=k_valid)

        dst = output_ptr + k_lanes * output_stride_k + row * output_stride_m
        tl.store(dst, scales, mask=k_valid)
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
def tma_align_input_scale(input_scale: torch.Tensor):
    """Return ``input_scale`` re-laid-out with a padded, unit-stride M axis.

    The data is copied into a freshly allocated ``(k, padded_m)`` buffer,
    where ``padded_m`` comes from ``get_tma_aligned_size``. The result is the
    transpose of that buffer sliced to the first ``m`` rows, so it has the
    same logical shape ``(m, k)`` as the input.

    The padding rows of the buffer are never written, so they hold
    uninitialized memory.

    Args:
        input_scale: 2-D scale tensor of shape ``(m, k)``.

    Returns:
        An ``(m, k)`` view over the newly allocated buffer.
    """
    assert input_scale.dim() == 2
    rows, groups = input_scale.shape
    padded_rows = get_tma_aligned_size(rows, input_scale.element_size())
    buffer = torch.empty(
        (groups, padded_rows),
        dtype=input_scale.dtype,
        device=input_scale.device,
    )

    # Programs stride over rows; the grid is capped at 8192.
    launch_grid = (min(rows, 8192),)

    _tma_align_input_scale_kernel[launch_grid](
        input_scale_ptr=input_scale,
        output_ptr=buffer,
        m=rows,
        k_div_block_size=groups,
        input_scale_stride_m=input_scale.stride(0),
        input_scale_stride_k=input_scale.stride(1),
        # The buffer is stored transposed, so its strides are passed swapped.
        output_stride_m=buffer.stride(1),
        output_stride_k=buffer.stride(0),
        BLOCK_SIZE_K=triton.next_power_of_2(groups),
    )
    return buffer.t()[:rows]
|
||||||
|
|||||||
@@ -4,11 +4,19 @@ from typing import Callable, List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from deep_gemm import (
|
from deep_gemm import (
|
||||||
get_col_major_tma_aligned_tensor,
|
get_col_major_tma_aligned_tensor,
|
||||||
|
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
||||||
)
|
)
|
||||||
|
from sgl_kernel import silu_and_mul
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
sglang_per_token_group_quant_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
use_deep_gemm = True
|
use_deep_gemm = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -20,6 +28,8 @@ from sglang.srt.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
|
ep_gather,
|
||||||
|
ep_scatter,
|
||||||
gelu_and_mul_triton_kernel,
|
gelu_and_mul_triton_kernel,
|
||||||
grouped_gemm_triton,
|
grouped_gemm_triton,
|
||||||
post_reorder_triton_kernel,
|
post_reorder_triton_kernel,
|
||||||
@@ -27,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
|||||||
run_moe_ep_preproess,
|
run_moe_ep_preproess,
|
||||||
silu_and_mul_masked_post_quant_fwd,
|
silu_and_mul_masked_post_quant_fwd,
|
||||||
silu_and_mul_triton_kernel,
|
silu_and_mul_triton_kernel,
|
||||||
|
tma_align_input_scale,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
|
||||||
@@ -842,15 +853,23 @@ class DeepEPMoE(EPMoE):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
topk_idx: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
reorder_topk_ids: torch.Tensor,
|
reorder_topk_ids: torch.Tensor,
|
||||||
seg_indptr: torch.Tensor,
|
seg_indptr: torch.Tensor,
|
||||||
masked_m: torch.Tensor,
|
masked_m: torch.Tensor,
|
||||||
expected_m: int,
|
expected_m: int,
|
||||||
|
num_recv_tokens_per_expert: List[int],
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
):
|
):
|
||||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||||
if resolved_deepep_mode == DeepEPMode.normal:
|
if resolved_deepep_mode == DeepEPMode.normal:
|
||||||
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
if _ENABLE_JIT_DEEPGEMM:
|
||||||
|
return self.forward_deepgemm_contiguous(
|
||||||
|
hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
||||||
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
||||||
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
|
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
|
||||||
else:
|
else:
|
||||||
@@ -969,6 +988,106 @@ class DeepEPMoE(EPMoE):
|
|||||||
)
|
)
|
||||||
return down_output
|
return down_output
|
||||||
|
|
||||||
|
def forward_deepgemm_contiguous(
    self,
    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
    topk_idx,
    topk_weights,
    num_recv_tokens_per_expert: List[int],
):
    """Run the expert MLP with DeepGEMM's contiguous grouped FP8 GEMM.

    Steps:
      1. ``ep_scatter`` groups the received tokens by expert.
      2. Grouped GEMM with ``w13_weight_fp8`` produces the gate/up output.
      3. ``silu_and_mul`` applies the activation; the result is
         re-quantized to FP8 in groups of ``scale_block_size``.
      4. Grouped GEMM with ``w2_weight_fp8`` produces the down output.
      5. ``ep_gather`` combines the per-expert rows back into one row per
         token using ``topk_weights``.

    Args:
        hidden_states_fp8: pair ``(values, scales)``. ``values`` is 2-D; the
            scales are presumably one float per 128 hidden elements --
            confirm with the dispatcher that produces them.
        topk_idx: expert id per (token, slot); forwarded to the scatter and
            gather kernels.
        topk_weights: combine weight per (token, slot).
        num_recv_tokens_per_expert: token count per expert, or ``None``.

    Returns:
        A bfloat16 tensor with the shape of the FP8 values. When
        ``num_recv_tokens_per_expert`` is ``None`` or sums to zero or less,
        the FP8 values are returned cast to bfloat16 without any expert
        computation.
    """
    hidden_states_fp8, hidden_states_scale = hidden_states_fp8
    assert self.quant_method is not None
    assert self.activation == "silu"
    if num_recv_tokens_per_expert is None:
        return hidden_states_fp8.bfloat16()
    all_tokens = sum(num_recv_tokens_per_expert)
    if all_tokens <= 0:
        return hidden_states_fp8.bfloat16()

    _, K = hidden_states_fp8.size()
    N = self.w13_weight.size(1)
    scale_block_size = 128
    device = hidden_states_fp8.device

    gather_out = torch.empty_like(
        hidden_states_fp8,
        device=device,
        dtype=torch.bfloat16,
    )

    # Grouped activations: [FP8 values, one float32 scale per
    # `scale_block_size` hidden elements].
    input_tensor = [
        torch.empty(
            (all_tokens, K),
            device=device,
            dtype=hidden_states_fp8.dtype,
        ),
        torch.empty(
            (all_tokens, K // scale_block_size),
            device=device,
            dtype=torch.float32,
        ),
    ]
    m_indices = torch.empty(all_tokens, device=device, dtype=torch.int32)
    output_index = torch.empty_like(topk_idx)

    # Stage the counts in pinned host memory so the copy can be
    # asynchronous, and send them to the same device as the activations
    # rather than whichever CUDA device happens to be current.
    num_recv_tokens_per_expert_gpu = torch.tensor(
        num_recv_tokens_per_expert,
        dtype=torch.int32,
        pin_memory=True,
        device="cpu",
    ).to(device, non_blocking=True)
    expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

    ep_scatter(
        hidden_states_fp8,
        hidden_states_scale,
        topk_idx,
        num_recv_tokens_per_expert_gpu,
        expert_start_loc,
        input_tensor[0],
        input_tensor[1],
        m_indices,
        output_index,
    )

    # Gate/up projection.
    gateup_output = torch.empty(
        (all_tokens, N),
        device=device,
        dtype=torch.bfloat16,
    )
    input_tensor[1] = tma_align_input_scale(input_tensor[1])
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        input_tensor, self.w13_weight_fp8, gateup_output, m_indices
    )

    # Activation; the output has half the width of the gate/up output.
    down_input = torch.empty(
        (all_tokens, N // 2),
        device=gateup_output.device,
        dtype=torch.bfloat16,
    )
    silu_and_mul(gateup_output, down_input)

    # Down projection on the re-quantized activations.
    down_output = torch.empty(
        (all_tokens, K),
        device=device,
        dtype=torch.bfloat16,
    )
    down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
        down_input, scale_block_size
    )
    down_input_scale = tma_align_input_scale(down_input_scale)
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        (down_input_fp8, down_input_scale),
        self.w2_weight_fp8,
        down_output,
        m_indices,
    )

    ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)

    return gather_out
|
||||||
|
|
||||||
def forward_deepgemm_masked(
|
def forward_deepgemm_masked(
|
||||||
self,
|
self,
|
||||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
|
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
|||||||
@@ -1,14 +1,19 @@
|
|||||||
|
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||||
from sglang.srt.utils import DeepEPMode
|
from sglang.srt.utils import DeepEPMode
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from deep_ep import Buffer
|
from deep_ep import Buffer
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
sglang_per_token_group_quant_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
use_deepep = True
|
use_deepep = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
use_deepep = False
|
use_deepep = False
|
||||||
|
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -78,7 +83,6 @@ class DeepEPBuffer:
|
|||||||
),
|
),
|
||||||
num_rdma_bytes,
|
num_rdma_bytes,
|
||||||
)
|
)
|
||||||
|
|
||||||
cls._buffer = Buffer(
|
cls._buffer = Buffer(
|
||||||
group,
|
group,
|
||||||
num_nvl_bytes,
|
num_nvl_bytes,
|
||||||
@@ -181,44 +185,74 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
):
|
):
|
||||||
topk_idx = topk_idx.to(torch.int64)
|
topk_idx = topk_idx.to(torch.int64)
|
||||||
|
if _ENABLE_JIT_DEEPGEMM:
|
||||||
|
# TODO: hard-coded 128 block quant, use fp8 communication
|
||||||
|
hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
|
||||||
previous_event = Buffer.capture() if self.async_finish else None
|
previous_event = Buffer.capture() if self.async_finish else None
|
||||||
return hidden_states, topk_idx, topk_weights, previous_event
|
return hidden_states, topk_idx, topk_weights, previous_event
|
||||||
|
|
||||||
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
|
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
|
||||||
(
|
if _ENABLE_JIT_DEEPGEMM:
|
||||||
hidden_states,
|
(
|
||||||
topk_idx,
|
hidden_states,
|
||||||
topk_weights,
|
topk_idx,
|
||||||
event,
|
topk_weights,
|
||||||
) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
|
num_recv_tokens_per_expert_list,
|
||||||
event.current_stream_wait() if self.async_finish else ()
|
event,
|
||||||
if hidden_states.shape[0] > 0:
|
) = self._dispatch_core(
|
||||||
reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
|
hidden_states, topk_idx, topk_weights, previous_event
|
||||||
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
)
|
||||||
|
event.current_stream_wait() if self.async_finish else ()
|
||||||
|
return (
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
None,
|
||||||
|
num_recv_tokens_per_expert_list,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
reorder_topk_ids = torch.empty(
|
(
|
||||||
(0,), device=hidden_states.device, dtype=torch.int64
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
num_recv_tokens_per_expert_list,
|
||||||
|
event,
|
||||||
|
) = self._dispatch_core(
|
||||||
|
hidden_states, topk_idx, topk_weights, previous_event
|
||||||
)
|
)
|
||||||
seg_indptr = torch.zeros(
|
event.current_stream_wait() if self.async_finish else ()
|
||||||
(self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
if hidden_states.shape[0] > 0:
|
||||||
|
reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
|
||||||
|
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
reorder_topk_ids = torch.empty(
|
||||||
|
(0,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
seg_indptr = torch.zeros(
|
||||||
|
(self.num_experts + 1,),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
|
||||||
|
masked_m = expected_m = None
|
||||||
|
return (
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
reorder_topk_ids,
|
||||||
|
None,
|
||||||
|
seg_indptr,
|
||||||
|
masked_m,
|
||||||
|
expected_m,
|
||||||
)
|
)
|
||||||
|
|
||||||
masked_m = expected_m = None
|
|
||||||
|
|
||||||
return (
|
|
||||||
hidden_states,
|
|
||||||
topk_idx,
|
|
||||||
topk_weights,
|
|
||||||
reorder_topk_ids,
|
|
||||||
seg_indptr,
|
|
||||||
masked_m,
|
|
||||||
expected_m,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _dispatch_core(
|
def _dispatch_core(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
previous_event,
|
previous_event,
|
||||||
@@ -246,7 +280,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
recv_x,
|
recv_x,
|
||||||
recv_topk_idx,
|
recv_topk_idx,
|
||||||
recv_topk_weights,
|
recv_topk_weights,
|
||||||
_, # num_recv_tokens_per_expert_list
|
num_recv_tokens_per_expert_list,
|
||||||
self.handle,
|
self.handle,
|
||||||
event,
|
event,
|
||||||
) = buffer.dispatch(
|
) = buffer.dispatch(
|
||||||
@@ -260,12 +294,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
previous_event=previous_event,
|
previous_event=previous_event,
|
||||||
async_finish=self.async_finish,
|
async_finish=self.async_finish,
|
||||||
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
|
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
|
||||||
|
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
recv_x,
|
recv_x,
|
||||||
recv_topk_idx,
|
recv_topk_idx,
|
||||||
recv_topk_weights,
|
recv_topk_weights,
|
||||||
|
num_recv_tokens_per_expert_list,
|
||||||
event,
|
event,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -314,29 +350,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
):
|
):
|
||||||
if hidden_states.shape[0] > 0:
|
if _ENABLE_JIT_DEEPGEMM:
|
||||||
num_tokens = self.src2dst.shape[0] // self.router_topk
|
output = hidden_states
|
||||||
output = torch.empty(
|
|
||||||
(num_tokens, hidden_states.shape[1]),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
|
||||||
hidden_states,
|
|
||||||
output,
|
|
||||||
self.src2dst,
|
|
||||||
topk_idx,
|
|
||||||
topk_weights,
|
|
||||||
self.router_topk,
|
|
||||||
hidden_states.shape[1],
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
output = torch.zeros(
|
if hidden_states.shape[0] > 0:
|
||||||
(0, hidden_states.shape[1]),
|
num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||||
device=hidden_states.device,
|
output = torch.empty(
|
||||||
dtype=hidden_states.dtype,
|
(num_tokens, hidden_states.shape[1]),
|
||||||
)
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
||||||
|
hidden_states,
|
||||||
|
output,
|
||||||
|
self.src2dst,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
self.router_topk,
|
||||||
|
hidden_states.shape[1],
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output = torch.zeros(
|
||||||
|
(0, hidden_states.shape[1]),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
previous_event = Buffer.capture() if self.async_finish else None
|
previous_event = Buffer.capture() if self.async_finish else None
|
||||||
return output, previous_event
|
return output, previous_event
|
||||||
|
|
||||||
@@ -360,6 +399,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
|
|
||||||
def _get_buffer(self):
|
def _get_buffer(self):
|
||||||
DeepEPBuffer.set_dispatch_mode_as_normal()
|
DeepEPBuffer.set_dispatch_mode_as_normal()
|
||||||
|
|
||||||
return DeepEPBuffer.get_deepep_buffer(
|
return DeepEPBuffer.get_deepep_buffer(
|
||||||
self.group,
|
self.group,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
@@ -426,6 +466,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
topk_idx,
|
topk_idx,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
reorder_topk_ids,
|
reorder_topk_ids,
|
||||||
|
None,
|
||||||
seg_indptr,
|
seg_indptr,
|
||||||
masked_m,
|
masked_m,
|
||||||
expected_m,
|
expected_m,
|
||||||
@@ -570,7 +611,8 @@ class DeepEPDispatcher:
|
|||||||
|
|
||||||
def dispatch(self, *args, **kwargs) -> Tuple:
|
def dispatch(self, *args, **kwargs) -> Tuple:
|
||||||
self.dispatch_a(*args, **kwargs)
|
self.dispatch_a(*args, **kwargs)
|
||||||
return self.dispatch_b()
|
ret = self.dispatch_b()
|
||||||
|
return ret
|
||||||
|
|
||||||
def dispatch_a(
|
def dispatch_a(
|
||||||
self,
|
self,
|
||||||
@@ -593,7 +635,8 @@ class DeepEPDispatcher:
|
|||||||
|
|
||||||
def combine(self, *args, **kwargs) -> Tuple:
|
def combine(self, *args, **kwargs) -> Tuple:
|
||||||
self.combine_a(*args, **kwargs)
|
self.combine_a(*args, **kwargs)
|
||||||
return self.combine_b()
|
ret = self.combine_b()
|
||||||
|
return ret
|
||||||
|
|
||||||
def combine_a(
|
def combine_a(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -28,6 +28,11 @@ if is_cuda():
|
|||||||
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
|
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
|
||||||
_ENABLE_JIT_DEEPGEMM = True
|
_ENABLE_JIT_DEEPGEMM = True
|
||||||
|
|
||||||
|
|
||||||
|
def get_enable_jit_deepgemm():
    """Return the current value of the module-level ``_ENABLE_JIT_DEEPGEMM`` flag."""
    jit_deepgemm_enabled = _ENABLE_JIT_DEEPGEMM
    return jit_deepgemm_enabled
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
|
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
|
||||||
|
|||||||
@@ -308,8 +308,8 @@ def sglang_per_token_group_quant_fp8(
|
|||||||
device=x.device,
|
device=x.device,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
|
if x.shape[0] > 0:
|
||||||
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
||||||
|
|
||||||
return x_q, x_s
|
return x_q, x_s
|
||||||
|
|
||||||
|
|||||||
@@ -357,6 +357,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
topk_idx,
|
topk_idx,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
reorder_topk_ids,
|
reorder_topk_ids,
|
||||||
|
num_recv_tokens_per_expert,
|
||||||
seg_indptr,
|
seg_indptr,
|
||||||
masked_m,
|
masked_m,
|
||||||
expected_m,
|
expected_m,
|
||||||
@@ -368,10 +369,13 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
)
|
)
|
||||||
final_hidden_states = self.experts(
|
final_hidden_states = self.experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
|
topk_idx=topk_idx,
|
||||||
|
topk_weights=topk_weights,
|
||||||
reorder_topk_ids=reorder_topk_ids,
|
reorder_topk_ids=reorder_topk_ids,
|
||||||
seg_indptr=seg_indptr,
|
seg_indptr=seg_indptr,
|
||||||
masked_m=masked_m,
|
masked_m=masked_m,
|
||||||
expected_m=expected_m,
|
expected_m=expected_m,
|
||||||
|
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
||||||
forward_mode=forward_mode,
|
forward_mode=forward_mode,
|
||||||
)
|
)
|
||||||
if self.ep_size > 1:
|
if self.ep_size > 1:
|
||||||
|
|||||||
Reference in New Issue
Block a user