DeepEP normal support deepgemm-contiguous (#5626)
Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com> Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Co-authored-by: Xuting Zhou <xutingz@nvidia.com> Co-authored-by: ZhengHSI <zhenghsi@qq.com>
This commit is contained in:
@@ -5,16 +5,23 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# `ceil_div` is an optional dependency from DeepGEMM.  On ImportError the
# module only logs and continues, so callers of `ceil_div` would fail with
# a NameError at call time if deep_gemm is absent.
try:
    from deep_gemm import ceil_div
except ImportError:
    # Plain string literal: the previous f-string had no placeholders.
    logger.error("Failed to import ceil_div from deep_gemm.")
|
||||
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -704,3 +711,334 @@ def grouped_gemm_triton(
|
||||
**config,
|
||||
)
|
||||
return c
|
||||
|
||||
|
||||
@triton.jit
def _fwd_kernel_ep_scatter_1(
    num_recv_tokens_per_expert,
    expert_start_loc,
    m_indices,
    num_experts: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    # One program per expert.  Computes the exclusive prefix sum of the
    # per-expert token counts (segment start offsets, written to
    # `expert_start_loc`) and fills this expert's contiguous segment of
    # `m_indices` with its expert id.
    cur_expert = tl.program_id(0)

    # Exclusive cumsum of token counts -> start offset of each expert's
    # segment.  Every program redundantly recomputes and stores the full
    # cumsum (BLOCK_EXPERT_NUM is num_experts padded to a power of two).
    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
    tokens_per_expert = tl.load(
        num_recv_tokens_per_expert + offset_cumsum,
        mask=offset_cumsum < num_experts,
        other=0,
    )
    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)

    cur_expert_start = tl.load(expert_start_loc + cur_expert)
    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

    m_indices_start_ptr = m_indices + cur_expert_start
    off_expert = tl.arange(0, BLOCK_E)

    # Fill this expert's slice of m_indices, BLOCK_E entries per step.
    # NOTE(review): the store is unmasked, so it assumes each expert's token
    # count is a multiple of BLOCK_E (DeepEP dispatch is called with
    # expert_alignment=128 elsewhere in this change) -- confirm, otherwise
    # the tail write would spill into the next expert's segment.
    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
        tl.store(
            m_indices_start_ptr + start_m + off_expert,
            cur_expert,
        )
|
||||
|
||||
|
||||
@triton.jit
def _fwd_kernel_ep_scatter_2(
    total_token_num,
    expert_start_loc,
    recv_x,
    recv_x_stride0,
    recv_x_stride1,
    recv_x_scale,
    recv_x_scale_stride0,
    recv_x_scale_stride1,
    recv_topk,
    recv_topk_stride0,
    recv_topk_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    output_tensor_scale,
    output_tensor_scale_stride0,
    output_tensor_scale_stride1,
    output_index,
    output_index_stride0,
    output_index_stride1,
    topk_num: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    HIDDEN_SIZE_PAD: tl.constexpr,
    SCALE_HIDDEN_SIZE: tl.constexpr,
    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
):
    # Scatter each received token's row (and its quant-scale row) into
    # `output_tensor` / `output_tensor_scale`, grouped by expert.
    # Destination rows are claimed via tl.atomic_add on `expert_start_loc`,
    # which is consumed as a running per-expert write cursor -- i.e. this
    # kernel MUTATES expert_start_loc past the segment starts written by
    # _fwd_kernel_ep_scatter_1.  The chosen row for each (token, topk) pair
    # is recorded in `output_index` so the gather kernel can invert the map.
    start_token_id = tl.program_id(0)
    grid_num = tl.num_programs(0)

    # Column offsets for the data row and the scale row; tl.arange needs a
    # power-of-two extent, masks restrict to the real widths.
    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
    mask = offset_in < HIDDEN_SIZE

    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
    mask_s = offset_in_s < SCALE_HIDDEN_SIZE

    # Grid-stride loop over tokens: program i handles tokens i, i+grid, ...
    for token_id in range(start_token_id, total_token_num, grid_num):
        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
        to_copy_s = tl.load(
            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
        )

        for topk_index in tl.range(0, topk_num, 1, num_stages=4):
            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
            # Slots with a negative expert id are skipped (presumably top-k
            # choices not routed to a local expert -- confirm with dispatch).
            if expert_id >= 0:
                # Claim the next free row in this expert's segment.
                dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
                tl.store(
                    output_index + token_id * output_index_stride0 + topk_index,
                    dest_token_index,
                )
                output_tensor_ptr = (
                    output_tensor + dest_token_index * output_tensor_stride0
                )
                output_tensor_scale_ptr = (
                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0
                )
                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
                tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
|
||||
|
||||
|
||||
# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
@torch.no_grad()
def ep_scatter(
    recv_x: torch.Tensor,
    recv_x_scale: torch.Tensor,
    recv_topk: torch.Tensor,
    num_recv_tokens_per_expert: torch.Tensor,
    expert_start_loc: torch.Tensor,
    output_tensor: torch.Tensor,
    output_tensor_scale: torch.Tensor,
    m_indices: torch.Tensor,
    output_index: torch.Tensor,
):
    """Scatter received tokens and their quant scales into expert-contiguous
    segments.

    Fills `expert_start_loc`, `output_tensor`, `output_tensor_scale`,
    `m_indices` and `output_index` in place via two Triton launches.
    """
    BLOCK_E = 128  # token count per expert is aligned to 128
    BLOCK_D = 128  # quantization block size along the hidden dim
    warp_count = 8

    expert_count = num_recv_tokens_per_expert.shape[0]
    hidden_dim = recv_x.shape[1]

    # m_indices must be laid out in BLOCK_E-aligned expert segments.
    assert m_indices.shape[0] % BLOCK_E == 0

    # Pass 1 -- one program per expert: compute segment start offsets and
    # stamp each segment of m_indices with its expert id.
    _fwd_kernel_ep_scatter_1[(expert_count,)](
        num_recv_tokens_per_expert,
        expert_start_loc,
        m_indices,
        num_experts=expert_count,
        num_warps=warp_count,
        BLOCK_E=BLOCK_E,
        BLOCK_EXPERT_NUM=triton.next_power_of_2(expert_count),
    )

    # Pass 2 -- grid-stride over tokens: copy each token's row and scale row
    # into its claimed slot, recording the slot in output_index.
    token_grid = min(recv_topk.shape[0], 1024 * 8)
    _fwd_kernel_ep_scatter_2[(token_grid,)](
        recv_topk.shape[0],
        expert_start_loc,
        recv_x,
        recv_x.stride(0),
        recv_x.stride(1),
        recv_x_scale,
        recv_x_scale.stride(0),
        recv_x_scale.stride(1),
        recv_topk,
        recv_topk.stride(0),
        recv_topk.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        output_tensor_scale,
        output_tensor_scale.stride(0),
        output_tensor_scale.stride(1),
        output_index,
        output_index.stride(0),
        output_index.stride(1),
        topk_num=recv_topk.shape[1],
        num_warps=warp_count,
        HIDDEN_SIZE=hidden_dim,
        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_dim),
        SCALE_HIDDEN_SIZE=hidden_dim // BLOCK_D,
        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_dim // BLOCK_D),
    )
|
||||
|
||||
|
||||
@triton.jit
def _fwd_kernel_ep_gather(
    total_token_num,
    input_tensor,
    input_tensor_stride0,
    input_tensor_stride1,
    recv_topk_ids,
    recv_topk_ids_stride0,
    recv_topk_ids_stride1,
    recv_topk_weight,
    recv_topk_weight_stride0,
    recv_topk_weight_stride1,
    input_index,
    input_index_stride0,
    input_index_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    topk_num: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # Inverse of the scatter: for every token, gather the expert-output rows
    # located via `input_index` (written by _fwd_kernel_ep_scatter_2) and
    # reduce them into a single row as a weighted sum with recv_topk_weight.
    # Grid: axis 0 tiles the hidden dim in BLOCK_D chunks, axis 1 strides
    # over tokens.
    cur_block = tl.program_id(0)
    start_cur_token = tl.program_id(1)
    grid_num = tl.num_programs(1)

    for cur_token in range(start_cur_token, total_token_num, grid_num):
        off_d = tl.arange(0, BLOCK_D)
        # Accumulate in fp32 for accuracy; cast once on the final store.
        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
        for topk_index in range(0, topk_num):
            expert_id = tl.load(
                recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
            )
            # Negative expert ids mark inactive top-k slots and contribute 0.
            if expert_id >= 0:
                source_token_index = tl.load(
                    input_index + cur_token * input_index_stride0 + topk_index
                )
                acc_weight = tl.load(
                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
                )
                # Unmasked load: assumes the hidden dim is an exact multiple
                # of BLOCK_D (asserted in ep_gather).
                tmp = tl.load(
                    input_tensor
                    + source_token_index * input_tensor_stride0
                    + cur_block * BLOCK_D
                    + off_d
                )
                accumulator += tmp.to(tl.float32) * acc_weight

        tl.store(
            output_tensor
            + cur_token * output_tensor_stride0
            + cur_block * BLOCK_D
            + off_d,
            accumulator.to(output_tensor.dtype.element_ty),
        )
|
||||
|
||||
|
||||
@torch.no_grad()
def ep_gather(
    input_tensor: torch.Tensor,
    recv_topk_ids: torch.Tensor,
    recv_topk_weight: torch.Tensor,
    input_index: torch.Tensor,
    output_tensor: torch.Tensor,
):
    """Combine scattered expert outputs back into one row per token.

    For each token, the rows of `input_tensor` selected through
    `input_index` are weighted by `recv_topk_weight` and summed into
    `output_tensor` (written in place).
    """
    BLOCK_D = 1024  # hidden-dim tile processed by one program on grid axis 0
    warp_count = 2

    token_count = output_tensor.shape[0]
    hidden_dim = input_tensor.shape[1]
    # The gather kernel loads unmasked hidden-dim tiles.
    assert hidden_dim % BLOCK_D == 0

    launch_grid = (triton.cdiv(hidden_dim, BLOCK_D), min(token_count, 1024))
    _fwd_kernel_ep_gather[launch_grid](
        token_count,
        input_tensor,
        input_tensor.stride(0),
        input_tensor.stride(1),
        recv_topk_ids,
        recv_topk_ids.stride(0),
        recv_topk_ids.stride(1),
        recv_topk_weight,
        recv_topk_weight.stride(0),
        recv_topk_weight.stride(1),
        input_index,
        input_index.stride(0),
        input_index.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        topk_num=recv_topk_ids.shape[1],
        num_warps=warp_count,
        BLOCK_D=BLOCK_D,
    )
|
||||
|
||||
|
||||
# copy from
# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """
    Global memory address of TMA must be 16-byte aligned.
    Since we use column-major layout for the LHS scaling tensor,
    the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.

    Arguments:
        x: original M-axis shape of the LHS scaling tensor.
        element_size: element size (in bytes) of the LHS scaling tensor.

    Returns:
        M-axis shape of the LHS scaling tensor after padding.

    Raises:
        AssertionError: if `element_size` does not evenly divide 16 bytes.
    """
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    # Ceil-division inlined instead of calling deep_gemm.ceil_div: the module
    # only logs when that import fails, which would otherwise surface here as
    # a NameError at call time.
    return (x + alignment - 1) // alignment * alignment
|
||||
|
||||
|
||||
@triton.jit
def _tma_align_input_scale_kernel(
    input_scale_ptr,
    output_ptr,
    m,
    k_div_block_size,
    input_scale_stride_m,
    input_scale_stride_k,
    output_stride_m,
    output_stride_k,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Row-by-row copy of an (m, k_div_block_size) scale matrix into a
    # destination whose strides the caller chooses.  tma_align_input_scale
    # passes the strides of a transposed buffer, so the copy lands in the
    # column-major, M-padded layout required for TMA-aligned DeepGEMM scales.
    pid_m = tl.program_id(axis=0)
    grid_m = tl.num_programs(0)
    # BLOCK_SIZE_K is k_div_block_size padded to a power of two for tl.arange.
    k_offsets = tl.arange(0, BLOCK_SIZE_K)

    # Grid-stride over rows: program i copies rows i, i+grid_m, ...
    for m_base in range(pid_m, m, grid_m):
        input_offset = (
            input_scale_ptr
            + m_base * input_scale_stride_m
            + k_offsets * input_scale_stride_k
        )
        input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size)

        output_offset = (
            output_ptr + k_offsets * output_stride_k + m_base * output_stride_m
        )
        tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size)
|
||||
|
||||
|
||||
# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
def tma_align_input_scale(input_scale: torch.Tensor):
    """Return a column-major copy of `input_scale` whose M axis is padded to
    the 16-byte TMA alignment, as DeepGEMM requires for LHS scales.
    """
    assert input_scale.dim() == 2
    rows, k_blocks = input_scale.shape
    padded_rows = get_tma_aligned_size(rows, input_scale.element_size())

    # Allocated transposed so that the final .t() view is column-major.
    aligned = torch.empty(
        (k_blocks, padded_rows), dtype=input_scale.dtype, device=input_scale.device
    )

    _tma_align_input_scale_kernel[(min(rows, 8192),)](
        input_scale_ptr=input_scale,
        output_ptr=aligned,
        m=rows,
        k_div_block_size=k_blocks,
        # The output strides are deliberately swapped: the kernel writes
        # through the transposed buffer's layout.
        input_scale_stride_m=input_scale.stride(0),
        input_scale_stride_k=input_scale.stride(1),
        output_stride_m=aligned.stride(1),
        output_stride_k=aligned.stride(0),
        BLOCK_SIZE_K=triton.next_power_of_2(k_blocks),
    )
    # View back as (padded_rows, k_blocks) and drop the padding tail.
    return aligned.t()[:rows]
|
||||
|
||||
@@ -4,11 +4,19 @@ from typing import Callable, List, Optional, Tuple
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
|
||||
try:
|
||||
from deep_gemm import (
|
||||
get_col_major_tma_aligned_tensor,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
||||
)
|
||||
from sgl_kernel import silu_and_mul
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
sglang_per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
use_deep_gemm = True
|
||||
except ImportError:
|
||||
@@ -20,6 +28,8 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
ep_gather,
|
||||
ep_scatter,
|
||||
gelu_and_mul_triton_kernel,
|
||||
grouped_gemm_triton,
|
||||
post_reorder_triton_kernel,
|
||||
@@ -27,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
run_moe_ep_preproess,
|
||||
silu_and_mul_masked_post_quant_fwd,
|
||||
silu_and_mul_triton_kernel,
|
||||
tma_align_input_scale,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
|
||||
@@ -842,15 +853,23 @@ class DeepEPMoE(EPMoE):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
reorder_topk_ids: torch.Tensor,
|
||||
seg_indptr: torch.Tensor,
|
||||
masked_m: torch.Tensor,
|
||||
expected_m: int,
|
||||
num_recv_tokens_per_expert: List[int],
|
||||
forward_mode: ForwardMode,
|
||||
):
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||
if resolved_deepep_mode == DeepEPMode.normal:
|
||||
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
||||
if _ENABLE_JIT_DEEPGEMM:
|
||||
return self.forward_deepgemm_contiguous(
|
||||
hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
|
||||
)
|
||||
else:
|
||||
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
||||
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
||||
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
|
||||
else:
|
||||
@@ -969,6 +988,106 @@ class DeepEPMoE(EPMoE):
|
||||
)
|
||||
return down_output
|
||||
|
||||
def forward_deepgemm_contiguous(
    self,
    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
    topk_idx,
    topk_weights,
    num_recv_tokens_per_expert: List[int],
):
    """DeepEP-normal MoE forward using DeepGEMM grouped-contiguous FP8 GEMMs.

    `hidden_states_fp8` is an (fp8 activations, per-group scales) pair.
    Tokens are scattered into expert-contiguous segments, run through the
    gate/up grouped GEMM, silu_and_mul, re-quantization, and the down
    grouped GEMM, then gathered back as a weighted per-token sum.  Returns
    a bf16 tensor with the same rows as the fp8 activations.
    """
    hidden_states_fp8, hidden_states_scale = hidden_states_fp8
    assert self.quant_method is not None
    # Only silu is supported: the down-projection input comes from silu_and_mul.
    assert self.activation == "silu"
    # Bypass paths: no per-expert counts, or zero tokens received -- return
    # the activations cast to bf16 unchanged.
    if num_recv_tokens_per_expert is None:
        return hidden_states_fp8.bfloat16()
    all_tokens = sum(num_recv_tokens_per_expert)
    if all_tokens <= 0:
        return hidden_states_fp8.bfloat16()
    M, K = hidden_states_fp8.size()
    N = self.w13_weight.size(1)
    scale_block_size = 128  # quant group width along the hidden dim

    # Final per-token output (same shape as the input activations, bf16).
    gather_out = torch.empty_like(
        hidden_states_fp8,
        device=hidden_states_fp8.device,
        dtype=torch.bfloat16,
    )

    # Expert-contiguous activations [0] and their per-128-group scales [1].
    input_tensor = [
        torch.empty(
            (all_tokens, K),
            device=hidden_states_fp8.device,
            dtype=hidden_states_fp8.dtype,
        ),
        torch.empty(
            (all_tokens, K // 128),
            device=hidden_states_fp8.device,
            dtype=torch.float32,
        ),
    ]
    # m_indices[i]: expert owning row i of the expert-contiguous layout.
    m_indices = torch.empty(
        all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
    )
    # output_index[token, k]: the scattered row for that top-k choice,
    # consumed by ep_gather to invert the scatter.
    output_index = torch.empty_like(topk_idx)

    # Host counts -> pinned CPU tensor -> async H2D copy.
    num_recv_tokens_per_expert_gpu = torch.tensor(
        num_recv_tokens_per_expert,
        dtype=torch.int32,
        pin_memory=True,
        device="cpu",
    ).cuda(non_blocking=True)
    # Filled by ep_scatter (segment starts, then used as write cursors).
    expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

    ep_scatter(
        hidden_states_fp8,
        hidden_states_scale,
        topk_idx,
        num_recv_tokens_per_expert_gpu,
        expert_start_loc,
        input_tensor[0],
        input_tensor[1],
        m_indices,
        output_index,
    )

    # Grouped GEMM 1: activations @ w13 -> gate/up projections (bf16 out).
    gateup_output = torch.empty(
        (all_tokens, N),
        device=hidden_states_fp8.device,
        dtype=torch.bfloat16,
    )
    # DeepGEMM requires TMA-aligned (column-major, padded-M) LHS scales.
    input_tensor[1] = tma_align_input_scale(input_tensor[1])
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        input_tensor, self.w13_weight_fp8, gateup_output, m_indices
    )
    # silu(gate) * up halves the width: N -> N // 2.
    down_input = torch.empty(
        (
            all_tokens,
            N // 2,
        ),
        device=gateup_output.device,
        dtype=torch.bfloat16,
    )
    silu_and_mul(gateup_output.view(-1, N), down_input)
    # Grouped GEMM 2: activated values @ w2 -> back to hidden size K.
    down_output = torch.empty(
        (all_tokens, K),
        device=hidden_states_fp8.device,
        dtype=torch.bfloat16,
    )
    # Re-quantize the bf16 activation to fp8 with 128-wide groups.
    down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
        down_input, scale_block_size
    )
    down_input_scale = tma_align_input_scale(down_input_scale)
    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        (down_input_fp8, down_input_scale),
        self.w2_weight_fp8,
        down_output,
        m_indices,
    )

    # Invert the scatter: weighted per-token sum of expert outputs.
    ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)

    return gather_out
|
||||
|
||||
def forward_deepgemm_masked(
|
||||
self,
|
||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
from sglang.srt.utils import DeepEPMode
|
||||
|
||||
try:
|
||||
from deep_ep import Buffer
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
sglang_per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
use_deepep = True
|
||||
except ImportError:
|
||||
use_deepep = False
|
||||
|
||||
from enum import IntEnum, auto
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -78,7 +83,6 @@ class DeepEPBuffer:
|
||||
),
|
||||
num_rdma_bytes,
|
||||
)
|
||||
|
||||
cls._buffer = Buffer(
|
||||
group,
|
||||
num_nvl_bytes,
|
||||
@@ -181,44 +185,74 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
topk_weights: torch.Tensor,
|
||||
):
|
||||
topk_idx = topk_idx.to(torch.int64)
|
||||
if _ENABLE_JIT_DEEPGEMM:
|
||||
# TODO hard code 128 block quant,use fp8 communication
|
||||
hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
|
||||
previous_event = Buffer.capture() if self.async_finish else None
|
||||
return hidden_states, topk_idx, topk_weights, previous_event
|
||||
|
||||
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
|
||||
(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
event,
|
||||
) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
|
||||
event.current_stream_wait() if self.async_finish else ()
|
||||
if hidden_states.shape[0] > 0:
|
||||
reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
|
||||
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
||||
if _ENABLE_JIT_DEEPGEMM:
|
||||
(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
num_recv_tokens_per_expert_list,
|
||||
event,
|
||||
) = self._dispatch_core(
|
||||
hidden_states, topk_idx, topk_weights, previous_event
|
||||
)
|
||||
event.current_stream_wait() if self.async_finish else ()
|
||||
return (
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
None,
|
||||
num_recv_tokens_per_expert_list,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
reorder_topk_ids = torch.empty(
|
||||
(0,), device=hidden_states.device, dtype=torch.int64
|
||||
(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
num_recv_tokens_per_expert_list,
|
||||
event,
|
||||
) = self._dispatch_core(
|
||||
hidden_states, topk_idx, topk_weights, previous_event
|
||||
)
|
||||
seg_indptr = torch.zeros(
|
||||
(self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
||||
event.current_stream_wait() if self.async_finish else ()
|
||||
if hidden_states.shape[0] > 0:
|
||||
reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
|
||||
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
reorder_topk_ids = torch.empty(
|
||||
(0,), device=hidden_states.device, dtype=torch.int64
|
||||
)
|
||||
seg_indptr = torch.zeros(
|
||||
(self.num_experts + 1,),
|
||||
device=hidden_states.device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
masked_m = expected_m = None
|
||||
return (
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
reorder_topk_ids,
|
||||
None,
|
||||
seg_indptr,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
|
||||
masked_m = expected_m = None
|
||||
|
||||
return (
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
reorder_topk_ids,
|
||||
seg_indptr,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
|
||||
def _dispatch_core(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
previous_event,
|
||||
@@ -246,7 +280,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
_, # num_recv_tokens_per_expert_list
|
||||
num_recv_tokens_per_expert_list,
|
||||
self.handle,
|
||||
event,
|
||||
) = buffer.dispatch(
|
||||
@@ -260,12 +294,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
previous_event=previous_event,
|
||||
async_finish=self.async_finish,
|
||||
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
|
||||
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
|
||||
)
|
||||
|
||||
return (
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
num_recv_tokens_per_expert_list,
|
||||
event,
|
||||
)
|
||||
|
||||
@@ -314,29 +350,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
):
|
||||
if hidden_states.shape[0] > 0:
|
||||
num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||
output = torch.empty(
|
||||
(num_tokens, hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
||||
hidden_states,
|
||||
output,
|
||||
self.src2dst,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
self.router_topk,
|
||||
hidden_states.shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
if _ENABLE_JIT_DEEPGEMM:
|
||||
output = hidden_states
|
||||
else:
|
||||
output = torch.zeros(
|
||||
(0, hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
if hidden_states.shape[0] > 0:
|
||||
num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||
output = torch.empty(
|
||||
(num_tokens, hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
||||
hidden_states,
|
||||
output,
|
||||
self.src2dst,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
self.router_topk,
|
||||
hidden_states.shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
else:
|
||||
output = torch.zeros(
|
||||
(0, hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
previous_event = Buffer.capture() if self.async_finish else None
|
||||
return output, previous_event
|
||||
|
||||
@@ -360,6 +399,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
|
||||
def _get_buffer(self):
|
||||
DeepEPBuffer.set_dispatch_mode_as_normal()
|
||||
|
||||
return DeepEPBuffer.get_deepep_buffer(
|
||||
self.group,
|
||||
self.hidden_size,
|
||||
@@ -426,6 +466,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
reorder_topk_ids,
|
||||
None,
|
||||
seg_indptr,
|
||||
masked_m,
|
||||
expected_m,
|
||||
@@ -570,7 +611,8 @@ class DeepEPDispatcher:
|
||||
|
||||
def dispatch(self, *args, **kwargs) -> Tuple:
|
||||
self.dispatch_a(*args, **kwargs)
|
||||
return self.dispatch_b()
|
||||
ret = self.dispatch_b()
|
||||
return ret
|
||||
|
||||
def dispatch_a(
|
||||
self,
|
||||
@@ -593,7 +635,8 @@ class DeepEPDispatcher:
|
||||
|
||||
def combine(self, *args, **kwargs) -> Tuple:
|
||||
self.combine_a(*args, **kwargs)
|
||||
return self.combine_b()
|
||||
ret = self.combine_b()
|
||||
return ret
|
||||
|
||||
def combine_a(
|
||||
self,
|
||||
|
||||
@@ -28,6 +28,11 @@ if is_cuda():
|
||||
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
|
||||
_ENABLE_JIT_DEEPGEMM = True
|
||||
|
||||
|
||||
def get_enable_jit_deepgemm():
    """Return the module-level ``_ENABLE_JIT_DEEPGEMM`` flag."""
    return _ENABLE_JIT_DEEPGEMM
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
|
||||
|
||||
@@ -308,8 +308,8 @@ def sglang_per_token_group_quant_fp8(
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
||||
if x.shape[0] > 0:
|
||||
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
@@ -357,6 +357,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
reorder_topk_ids,
|
||||
num_recv_tokens_per_expert,
|
||||
seg_indptr,
|
||||
masked_m,
|
||||
expected_m,
|
||||
@@ -368,10 +369,13 @@ class DeepseekV2MoE(nn.Module):
|
||||
)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
reorder_topk_ids=reorder_topk_ids,
|
||||
seg_indptr=seg_indptr,
|
||||
masked_m=masked_m,
|
||||
expected_m=expected_m,
|
||||
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
||||
forward_mode=forward_mode,
|
||||
)
|
||||
if self.ep_size > 1:
|
||||
|
||||
Reference in New Issue
Block a user