[8/N] MoE Refactor: deprecate EPMoE (#11211)

This commit is contained in:
Cheng Wan
2025-10-07 21:51:41 -07:00
committed by GitHub
parent 7c3f07dbcb
commit 3c06b673af
19 changed files with 526 additions and 1808 deletions

View File

@@ -13,22 +13,18 @@ from sgl_kernel import (
from sglang.srt.layers.moe.ep_moe.kernels import (
post_reorder_triton_kernel_for_cutlass_moe,
pre_reorder_triton_kernel_for_cutlass_moe,
run_cutlass_moe_ep_preproess,
run_moe_ep_preproess,
)
def cutlass_w4a8_moe(
start_expert_id: int,
end_expert_id: int,
total_num_experts: int,
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids_: torch.Tensor,
local_topk_ids: torch.Tensor,
topk_ids: torch.Tensor,
a_strides1: torch.Tensor,
b_strides1: torch.Tensor,
c_strides1: torch.Tensor,
@@ -64,6 +60,7 @@ def cutlass_w4a8_moe(
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts, N // 512, K * 4]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The ids of each token->expert mapping.
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
@@ -83,7 +80,7 @@ def cutlass_w4a8_moe(
Returns:
- torch.Tensor: The fp8 output tensor after applying the MoE layer.
"""
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_q.dtype == torch.int8
assert w2_q.dtype == torch.int8
assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
@@ -96,20 +93,21 @@ def cutlass_w4a8_moe(
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
num_experts = w1_q.size(0)
num_local_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(2) * 2 # w1_q is transposed and packed
n = w2_q.size(2) * 2 # w2_q is transposed and packed
topk = topk_ids_.size(1)
topk = topk_ids.size(1)
if apply_router_weight_on_input:
assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
device = a.device
topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
_, src2dst, _ = run_cutlass_moe_ep_preproess(
local_topk_ids,
num_experts,
_, src2dst, _ = run_moe_ep_preproess(
topk_ids,
num_local_experts,
)
gateup_input = torch.empty(
@@ -122,9 +120,9 @@ def cutlass_w4a8_moe(
a,
gateup_input,
src2dst,
local_topk_ids,
topk_ids,
a1_scale,
total_num_experts,
num_local_experts,
topk,
k,
BLOCK_SIZE=512,
@@ -133,16 +131,16 @@ def cutlass_w4a8_moe(
# NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
# they are kept to allow for a quick switch of the permutation logic
# from the current triton kernel implementation to the cutlass-based one if needed.
a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
get_cutlass_w4a8_moe_mm_data(
local_topk_ids,
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
a_map,
c_map,
num_experts,
num_local_experts,
n,
k,
)
@@ -195,12 +193,11 @@ def cutlass_w4a8_moe(
c2,
output,
src2dst,
local_topk_ids,
topk_ids,
topk_weights,
num_experts,
topk,
num_local_experts,
k,
0,
BLOCK_SIZE=512,
)
return output

View File

@@ -130,28 +130,30 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
@triton.jit
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
expert = tl.program_id(0)
expert_id_minus_1 = tl.program_id(0) - 1
low = 0
high = num_toks - 1
target_location = -1
while low <= high:
mid = (low + high) // 2
if tl.load(reorder_topk_ids + mid) > expert:
if tl.load(reorder_topk_ids + mid) > expert_id_minus_1:
high = mid - 1
else:
low = mid + 1
target_location = mid
tl.store(seg_indptr + expert + 1, target_location + 1)
tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1)
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_local_experts: int):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
seg_indptr = torch.zeros(
num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
compute_seg_indptr_triton_kernel[(num_experts,)](
compute_seg_indptr_triton_kernel[(num_local_experts,)](
reorder_topk_ids, seg_indptr, topk_ids.numel()
)
@@ -164,25 +166,6 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
return reorder_topk_ids, src2dst, seg_indptr
def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(
local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
)
src2dst = torch.empty(
local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
)
BLOCK_SIZE = 512
grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
compute_src2dst_triton_kernel[grid](
reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
)
return reorder_topk_ids, src2dst, seg_indptr
@triton.jit
def pre_reorder_triton_kernel_for_cutlass_moe(
input_ptr,
@@ -190,52 +173,13 @@ def pre_reorder_triton_kernel_for_cutlass_moe(
src2dst_ptr,
topk_ids_ptr,
a1_scales_ptr,
num_experts,
num_local_experts,
topk,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
OutDtype = gateup_input_ptr.dtype.element_ty
src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
src_ptr = input_ptr + src_idx * hidden_size
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id != num_experts:
if a1_scales_ptr is not None:
scale = 1.0 / tl.load(a1_scales_ptr)
else:
scale = 1.0
dst_idx = tl.load(src2dst_ptr + idx)
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
out_data = (in_data * scale).to(OutDtype)
tl.store(dst_ptr + offset, out_data, mask=mask)
@triton.jit
def pre_reorder_triton_kernel(
input_ptr,
gateup_input_ptr,
src2dst_ptr,
topk_ids_ptr,
a1_scales_ptr,
start_expert_id,
end_expert_id,
topk,
hidden_size,
BLOCK_SIZE: tl.constexpr,
use_per_token_if_dynamic: tl.constexpr,
):
OutDtype = gateup_input_ptr.dtype.element_ty
src_idx_int32 = tl.program_id(0)
src_idx = src_idx_int32.to(tl.int64)
src2dst_ptr = src2dst_ptr + src_idx * topk
@@ -244,15 +188,11 @@ def pre_reorder_triton_kernel(
vec = tl.arange(0, BLOCK_SIZE)
if a1_scales_ptr is not None and use_per_token_if_dynamic:
scale = 1.0 / tl.load(a1_scales_ptr + src_idx)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
if expert_id != num_local_experts:
if a1_scales_ptr is not None:
if not use_per_token_if_dynamic:
scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
scale = 1.0 / tl.load(a1_scales_ptr)
else:
scale = 1.0
@@ -267,52 +207,6 @@ def pre_reorder_triton_kernel(
tl.store(dst_ptr + offset, out_data, mask=mask)
@triton.jit
def silu_and_mul_triton_kernel(
gateup_output,
down_input,
hidden_size,
reorder_topk_ids,
scales,
start_expert_id,
end_expert_id,
BLOCK_SIZE: tl.constexpr,
):
InDtype = gateup_output.dtype.element_ty
OutDtype = down_input.dtype.element_ty
half_hidden_size = hidden_size // 2
pid = tl.program_id(0)
expert_id = tl.load(reorder_topk_ids + pid)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
gateup_output_ptr = gateup_output + pid * hidden_size
gate_output_ptr = gateup_output_ptr
up_output_ptr = gateup_output_ptr + half_hidden_size
down_input_ptr = down_input + pid * half_hidden_size
if scales is not None:
scale = tl.load(scales + expert_id - start_expert_id)
scale = (1 / scale).to(InDtype)
else:
scale = 1
for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < half_hidden_size
gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
up_output = tl.load(up_output_ptr + offset, mask=mask)
# silu & mul & quantize
gate_output = gate_output * tl.sigmoid(gate_output)
gate_output = gate_output.to(InDtype)
silu_mul_output = gate_output * up_output * scale
silu_mul_output = silu_mul_output.to(OutDtype)
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@triton.jit
def _silu_and_mul_post_quant_kernel(
@@ -461,70 +355,44 @@ def silu_and_mul_masked_post_quant_fwd(
@triton.jit
def tanh(x):
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def gelu_and_mul_triton_kernel(
gateup_output,
down_input,
def post_reorder_triton_kernel_for_cutlass_moe(
down_output_ptr,
output_ptr,
src2dst_ptr,
topk_ids_ptr,
topk_weights_ptr,
topk,
num_local_experts,
hidden_size,
reorder_topk_ids,
scales,
start_expert_id,
end_expert_id,
BLOCK_SIZE: tl.constexpr,
):
InDtype = gateup_output.dtype.element_ty
OutDtype = down_input.dtype.element_ty
InDtype = down_output_ptr.dtype.element_ty
half_hidden_size = hidden_size // 2
src_idx_int32 = tl.program_id(0)
src_idx = src_idx_int32.to(tl.int64)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
topk_weights_ptr = topk_weights_ptr + src_idx * topk
pid = tl.program_id(0)
expert_id = tl.load(reorder_topk_ids + pid)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
gateup_output_ptr = gateup_output + pid * hidden_size
gate_output_ptr = gateup_output_ptr
up_output_ptr = gateup_output_ptr + half_hidden_size
down_input_ptr = down_input + pid * half_hidden_size
store_ptr = output_ptr + src_idx * hidden_size
if scales is not None:
scale = tl.load(scales + expert_id - start_expert_id)
scale = (1 / scale).to(InDtype)
else:
scale = 1
vec = tl.arange(0, BLOCK_SIZE)
for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < half_hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + vec
mask = offset < hidden_size
gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
up_output = tl.load(up_output_ptr + offset, mask=mask)
# gelu & mul & quantize
# https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
# sqrt(2/pi)
kAlpha = 0.7978845608028654
gate_output = (
0.5
* gate_output
* (
1
+ tanh(
kAlpha
* (
gate_output
+ 0.044715 * gate_output * gate_output * gate_output
)
)
)
)
gate_output = gate_output.to(InDtype)
gelu_mul_output = gate_output * up_output * scale
gelu_mul_output = gelu_mul_output.to(OutDtype)
tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id != num_local_experts:
dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
load_ptr = down_output_ptr + dst_idx * hidden_size
in_data = tl.load(load_ptr + offset, mask=mask)
sum_vec += in_data * weigh_scale
tl.store(store_ptr + offset, sum_vec, mask=mask)
@triton.jit
@@ -534,64 +402,8 @@ def post_reorder_triton_kernel(
src2dst_ptr,
topk_ids_ptr,
topk_weights_ptr,
start_expert_id,
end_expert_id,
topk,
hidden_size,
dst_start,
BLOCK_SIZE: tl.constexpr,
):
InDtype = down_output_ptr.dtype.element_ty
src_idx_int32 = tl.program_id(0)
src_idx = src_idx_int32.to(tl.int64)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
topk_weights_ptr = topk_weights_ptr + src_idx * topk
computed = False
store_ptr = output_ptr + src_idx * hidden_size
vec = tl.arange(0, BLOCK_SIZE)
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + vec
mask = offset < hidden_size
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
computed = True
dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
dst_idx = dst_idx - dst_start
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
load_ptr = down_output_ptr + dst_idx * hidden_size
in_data = tl.load(load_ptr + offset, mask=mask)
sum_vec += in_data * weigh_scale
tl.store(store_ptr + offset, sum_vec, mask=mask)
if computed == False:
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + vec
mask = offset < hidden_size
tl.store(
store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
)
@triton.jit
def post_reorder_triton_kernel_for_cutlass_moe(
down_output_ptr,
output_ptr,
src2dst_ptr,
topk_ids_ptr,
topk_weights_ptr,
num_experts,
topk,
hidden_size,
dst_start,
BLOCK_SIZE: tl.constexpr,
):
InDtype = down_output_ptr.dtype.element_ty
@@ -613,10 +425,9 @@ def post_reorder_triton_kernel_for_cutlass_moe(
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id != num_experts:
if expert_id > 0:
dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
dst_idx = dst_idx - dst_start
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
load_ptr = down_output_ptr + dst_idx * hidden_size
in_data = tl.load(load_ptr + offset, mask=mask)
@@ -624,232 +435,6 @@ def post_reorder_triton_kernel_for_cutlass_moe(
tl.store(store_ptr + offset, sum_vec, mask=mask)
@triton.jit
def compute_m_range(
pid,
batch_size,
seg_indptr,
weight_indices,
m_num_tiles_indptr,
BLOCK_SIZE_M: tl.constexpr,
):
idx = 0
for bs in range(batch_size):
tiles = tl.load(m_num_tiles_indptr + bs)
if pid >= tiles:
idx = bs
idx_start = tl.load(m_num_tiles_indptr + idx)
m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
expert_id = tl.load(weight_indices + idx)
return m_range_start, m_range_end, expert_id
@triton.jit
def grouped_gemm_triton_kernel(
a,
b,
c,
batch_size,
N,
K,
seg_indptr,
weight_indices,
m_num_tiles_indptr,
scale_a,
scale_b,
use_fp8_w8a8: tl.constexpr,
group_n: tl.constexpr,
group_k: tl.constexpr,
a_stride_0: tl.constexpr,
b_stride_0: tl.constexpr,
b_stride_1: tl.constexpr,
as_stride_0: tl.constexpr,
as_stride_1: tl.constexpr,
bs_stride_0: tl.constexpr,
bs_stride_2: tl.constexpr,
bs_stride_1: tl.constexpr,
use_per_token_if_dynamic: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
c_dtype = c.dtype.element_ty
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
total_m_block = tl.load(m_num_tiles_indptr + batch_size)
if pid_m >= total_m_block:
return
m_range_start, m_range_end, expert_id = compute_m_range(
pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
)
if m_range_end - m_range_start == 0:
return
n_range_start = pid_n * BLOCK_SIZE_N
n_range_end = min(n_range_start + BLOCK_SIZE_N, N)
offs_am = tl.arange(0, BLOCK_SIZE_M)
offs_bn = tl.arange(0, BLOCK_SIZE_N)
offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
b_ptr = b + (
(expert_id * b_stride_0)
+ (n_range_start + offs_bn[:, None]) * b_stride_1
+ offs_k[None, :]
)
if group_k > 0 and group_n > 0:
a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
offs_bsn = (n_range_start + offs_bn) // group_n
b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a_tile = tl.load(
a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
)
b_tile = tl.load(
b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
)
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
else:
accumulator = tl.dot(a_tile, b_tile.T, accumulator)
a_ptr += BLOCK_SIZE_K
b_ptr += BLOCK_SIZE_K
if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
if use_per_token_if_dynamic:
scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None]))
else:
scale_a_value = tl.load(scale_a + expert_id)
scale_b_value = tl.load(scale_b + expert_id)
accumulator *= scale_a_value * scale_b_value
c_tile = accumulator.to(c_dtype)
offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
tl.store(c_ptr, c_tile, mask=c_mask)
@triton.jit
def compute_m_num_tiles_indptr(
m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
):
for bs in range(batch_size):
m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
def grouped_gemm_triton(
a: torch.Tensor,
b: torch.Tensor,
c: torch.Tensor,
batch_size: int,
weight_column_major: bool,
seg_indptr: Optional[torch.Tensor] = None,
weight_indices: Optional[torch.Tensor] = None,
use_fp8_w8a8: bool = False,
scale_a: torch.Tensor = None,
scale_b: torch.Tensor = None,
block_shape: Optional[List[int]] = None,
c_dtype=None,
use_per_token_if_dynamic: bool = True,
):
assert weight_column_major == True # TODO: more
if use_fp8_w8a8 and block_shape is None:
assert scale_a is not None and scale_b is not None
if block_shape is not None:
a_original = a
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
a, scale_a = per_token_group_quant_fp8(a, block_k)
assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
dispose_tensor(a_original)
# TODO: adjust config or tune kernel
# Reduce block size to prevent L40 shared memory overflow.
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
}
m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
compute_m_num_tiles_indptr[(1,)](
m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
)
if c is None:
assert c_dtype is not None
c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
grid = lambda META: (
triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
)
if use_fp8_w8a8 and block_shape is None and use_per_token_if_dynamic:
assert (
scale_a.shape[0] == a.shape[0]
), f"scale_a.shape: {scale_a.shape}, a.shape: {a.shape}"
grouped_gemm_triton_kernel[grid](
a,
b,
c,
batch_size,
b.size(1),
b.size(2),
seg_indptr,
weight_indices,
m_num_tiles_indptr,
scale_a,
scale_b,
use_fp8_w8a8,
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
a.stride(0),
b.stride(0),
b.stride(1),
scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0,
scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0,
scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
use_per_token_if_dynamic,
**config,
)
return c
@triton.jit
def _fwd_kernel_ep_scatter_1(
num_recv_tokens_per_expert,
@@ -1234,7 +819,7 @@ def deepgemm_compute_src2dst_triton_kernel(
mask = dst_id < num_toks
src_id = tl.load(reorder_ids + dst_id, mask=mask)
expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
expert_dst_start = tl.load(seg_indptr + expert_id)
expert_dst_start = tl.load(seg_indptr + expert_id, mask=(expert_id >= 0))
expert_dst_offset = dst_id - expert_dst_start
dst_id = expert_id * m_max + expert_dst_offset
tl.store(src2dst + src_id, dst_id, mask=mask)
@@ -1248,10 +833,7 @@ def fill_gateup_input_triton_kernel(
gateup_input_scale_ptr,
src2dst_ptr,
topk_ids_ptr,
start_expert_id,
end_expert_id,
topk,
m_max,
hidden_size,
scale_size,
BLOCK_SIZE: tl.constexpr,
@@ -1267,10 +849,9 @@ def fill_gateup_input_triton_kernel(
vec = tl.arange(0, BLOCK_SIZE)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
if expert_id >= 0:
dst_idx_int32 = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx_int32.to(tl.int64)
dst_idx = dst_idx - start_expert_id * m_max
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + vec
@@ -1287,31 +868,31 @@ def fill_gateup_input_triton_kernel(
def moe_ep_deepgemm_preprocess(
topk_ids: torch.Tensor,
num_experts: int,
num_local_experts: int,
hidden_states: torch.Tensor,
top_k: int,
start_expert_id,
end_expert_id,
block_shape,
output_dtype: torch.dtype = torch.float8_e4m3fn,
):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
seg_indptr = torch.zeros(
num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
masked_m = torch.empty(num_local_experts, device=topk_ids.device, dtype=torch.int32)
compute_seg_indptr_triton_kernel[(num_experts,)](
compute_seg_indptr_triton_kernel[(num_local_experts + 1,)](
reorder_topk_ids, seg_indptr, topk_ids.numel()
)
grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
compute_masked_m_triton_kernel[(num_local_experts,)](seg_indptr, masked_m)
# For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
m_max = (hidden_states.size(0) + 255) // 256 * 256
expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
m_max = (hidden_states.size(0) // 256 + 1) * 256
expected_m = (topk_ids.numel() - 1) // num_local_experts + 1
gateup_input = torch.empty(
(int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
(num_local_experts, m_max, hidden_states.size(1)),
device=hidden_states.device,
dtype=output_dtype,
)
@@ -1330,6 +911,8 @@ def moe_ep_deepgemm_preprocess(
block_shape = [128, 128]
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
# TODO: fuse this with the preprocess
hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
gateup_input_scale = torch.empty(
@@ -1345,18 +928,14 @@ def moe_ep_deepgemm_preprocess(
gateup_input_scale,
src2dst,
topk_ids,
start_expert_id,
end_expert_id,
top_k,
m_max,
hidden_states.size(1),
scale.size(1),
BLOCK_SIZE=1024,
)
return (
m_max,
masked_m[start_expert_id : (end_expert_id + 1)],
masked_m,
expected_m,
src2dst,
gateup_input,

View File

@@ -1,14 +1,10 @@
from __future__ import annotations
import logging
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
import triton
import triton.language as tl
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
@@ -18,13 +14,10 @@ from sglang.srt.layers.moe import (
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
moe_ep_deepgemm_preprocess,
post_reorder_triton_kernel,
silu_and_mul_masked_post_quant_fwd,
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import Fp8Config
@@ -36,19 +29,10 @@ from sglang.srt.layers.quantization.modelopt_quant import (
CUTEDSL_MOE_NVFP4_DISPATCH,
ModelOptNvFp4FusedMoEMethod,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.offloader import get_offloader
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.utils import (
ceil_div,
dispose_tensor,
get_bool_env_var,
get_int_env_var,
is_cuda,
is_hip,
is_npu,
)
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
@@ -72,275 +56,7 @@ if _use_aiter:
logger = logging.getLogger(__name__)
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
temp = x.to(torch.float32).view(torch.int32)
exp = torch.bitwise_right_shift(temp, 23)
mant = torch.bitwise_and(temp, 0x7FFFFF)
is_ru = torch.logical_and(
torch.logical_and((mant > 0), (exp != 0xFE)),
~torch.logical_and((exp == 0), (mant <= 0x400000)),
)
exp = torch.where(is_ru, exp + 1, exp)
new_x = exp.to(torch.uint8).view(torch.int)
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
class EPMoE(FusedMoE):
"""
MoE Expert Parallel Impl
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_clamp_limit: Optional[float] = None,
with_bias: bool = False,
):
super().__init__(
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_fused_shared_experts=num_fused_shared_experts,
layer_id=layer_id,
top_k=top_k,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
gemm1_alpha=gemm1_alpha,
gemm1_clamp_limit=gemm1_clamp_limit,
with_bias=with_bias,
)
self.intermediate_size = intermediate_size
if isinstance(quant_config, Fp8Config):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.use_fp8_w8a8 = True
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
else:
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm(hidden_states, topk_output)
else:
return super().forward(hidden_states, topk_output)
def forward_deepgemm(
self,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
):
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
topk_weights, topk_ids, _ = topk_output
if not self.use_block_quant:
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
scale_block_size = 128
w13_weight_scale_n = 2 * (
(self.intermediate_size + scale_block_size - 1) // scale_block_size
)
w13_weight_scale_k = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w13_weight_scale = (
self.w13_weight_scale.unsqueeze(1)
.repeat_interleave(w13_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w13_weight_scale_k, dim=2)
)
self.w13_weight_fp8 = (
self.w13_weight,
w13_weight_scale,
)
w2_weight_scale_n = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w2_weight_scale_k = (
self.intermediate_size + scale_block_size - 1
) // scale_block_size
w2_weight_scale = (
self.w2_weight_scale.unsqueeze(1)
.repeat_interleave(w2_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w2_weight_scale_k, dim=2)
)
self.w2_weight_fp8 = (
self.w2_weight,
w2_weight_scale,
)
# PreReorder
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
moe_ep_deepgemm_preprocess(
topk_ids,
self.num_experts,
hidden_states,
self.top_k,
self.start_expert_id,
self.end_expert_id,
self.block_shape,
)
)
dispose_tensor(hidden_states)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
b, s_mn, s_k = gateup_input_scale.shape
assert (
s_mn % 4 == 0 and s_k % 4 == 0
), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
# GroupGemm-0
gateup_input_fp8 = (
gateup_input,
(
_cast_to_e8m0_with_rounding_up(gateup_input_scale)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
gateup_input_scale
)
),
)
num_groups, m, k = gateup_input_fp8[0].size()
n = self.w13_weight.size(1)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
gateup_input_fp8,
self.w13_weight_fp8,
gateup_output,
masked_m,
expected_m,
)
del gateup_input
del gateup_input_fp8
# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=hidden_states_device,
dtype=self.fp8_dtype,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=hidden_states_device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del gateup_output
# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
(
down_input_scale
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
),
)
down_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
down_input_fp8,
self.w2_weight_fp8,
down_output,
masked_m,
expected_m,
)
del down_input
del down_input_fp8
# PostReorder
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states_shape[1],
m_max * self.start_expert_id,
BLOCK_SIZE=512,
)
if self.moe_runner_config.routed_scaling_factor is not None:
output *= self.moe_runner_config.routed_scaling_factor
return output
class DeepEPMoE(EPMoE):
class DeepEPMoE(FusedMoE):
"""
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
"""
@@ -374,6 +90,15 @@ class DeepEPMoE(EPMoE):
activation=activation,
routed_scaling_factor=routed_scaling_factor,
)
if isinstance(quant_config, Fp8Config):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.use_fp8_w8a8 = True
self.fp8_dtype = torch.float8_e4m3fn
else:
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.deepep_mode = get_deepep_mode()
# TODO: move to the beginning of the file
@@ -567,7 +292,6 @@ class DeepEPMoE(EPMoE):
N = self.w13_weight.size(1)
scale_block_size = 128
# TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
w13_weight_fp8 = (
self.w13_weight,
(
@@ -988,8 +712,6 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
return FlashInferFusedMoE
if get_moe_runner_backend().is_flashinfer_cutlass():
return FusedMoE
if get_moe_expert_parallel_world_size() > 1:
return EPMoE
return FusedMoE

View File

@@ -156,8 +156,7 @@ class FusedMoE(torch.nn.Module):
self.moe_tp_rank = get_moe_tensor_parallel_rank()
assert num_experts % self.moe_ep_size == 0
self.num_local_experts = num_experts // self.moe_ep_size
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
if self.moe_ep_size > 1:
# TODO(ch-wan): support shared experts fusion
# Create a tensor of size num_experts filled with -1

View File

@@ -0,0 +1,304 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
import torch
from sglang.srt.layers.moe.moe_runner.base import (
MoeQuantInfo,
MoeRunnerConfig,
MoeRunnerCore,
RunnerInput,
RunnerOutput,
register_post_permute,
register_pre_permute,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend
from sglang.srt.utils import dispose_tensor
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
StandardDispatchOutput,
)
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
temp = x.to(torch.float32).view(torch.int32)
exp = torch.bitwise_right_shift(temp, 23)
mant = torch.bitwise_and(temp, 0x7FFFFF)
is_ru = torch.logical_and(
torch.logical_and((mant > 0), (exp != 0xFE)),
~torch.logical_and((exp == 0), (mant <= 0x400000)),
)
exp = torch.where(is_ru, exp + 1, exp)
new_x = exp.to(torch.uint8).view(torch.int)
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
@dataclass
class DeepGemmRunnerInput(RunnerInput):
hidden_states: torch.Tensor
hidden_states_scale: torch.Tensor
masked_m: torch.Tensor
expected_m: int
use_masked_gemm: bool
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.DEEP_GEMM
@dataclass
class DeepGemmRunnerOutput(RunnerOutput):
hidden_states: torch.Tensor
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.DEEP_GEMM
@dataclass
class DeepGemmMoeQuantInfo(MoeQuantInfo):
w13_weight: torch.Tensor
w2_weight: torch.Tensor
use_fp8: bool
w13_scale: Optional[torch.Tensor] = None
w2_scale: Optional[torch.Tensor] = None
block_shape: Optional[List[int]] = None
class DeepGemmRunnerCore(MoeRunnerCore):
def __init__(self, config: MoeRunnerConfig):
super().__init__(config)
assert self.config.activation == "silu"
def run(
self,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
) -> DeepGemmRunnerOutput:
if runner_input.use_masked_gemm:
hidden_states = self._run_masked_gemm(
runner_input,
quant_info,
running_state,
)
else:
hidden_states = self._run_contiguous_gemm(
runner_input,
quant_info,
running_state,
)
return DeepGemmRunnerOutput(hidden_states=hidden_states)
def _run_masked_gemm(
self,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
) -> torch.Tensor:
from sglang.srt.layers.moe.ep_moe.kernels import (
silu_and_mul_masked_post_quant_fwd,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
hidden_states = runner_input.hidden_states
hidden_states_scale = runner_input.hidden_states_scale
masked_m = runner_input.masked_m
expected_m = runner_input.expected_m
w13_weight = quant_info.w13_weight
w2_weight = quant_info.w2_weight
w13_scale = quant_info.w13_scale
w2_scale = quant_info.w2_scale
hidden_states_device = running_state["hidden_states_device"]
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
b, s_mn, s_k = hidden_states_scale.shape
assert (
s_mn % 4 == 0 and s_k % 4 == 0
), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
# GroupGemm-0
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
else:
hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
hidden_states_scale
)
num_groups, m, k = hidden_states.shape
n = w13_weight.size(1)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
(hidden_states, hidden_states_scale),
(w13_weight, w13_scale),
gateup_output,
masked_m,
expected_m,
)
dispose_tensor(hidden_states)
# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=hidden_states_device,
dtype=torch.float8_e4m3fn,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=hidden_states_device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del gateup_output
# GroupGemm-1
n = w2_weight.shape[1]
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
down_input_scale
)
down_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
(down_input, down_input_scale),
(w2_weight, w2_scale),
down_output,
masked_m,
expected_m,
)
del down_input
return down_output
def _run_contiguous_gemm(
self,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
) -> torch.Tensor:
pass
@property
def runner_backend(self) -> MoeRunnerBackend:
return MoeRunnerBackend.DEEP_GEMM
@register_pre_permute("standard", "deep_gemm")
def pre_permute_standard_to_deep_gemm(
dispatch_output: StandardDispatchOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> DeepGemmRunnerInput:
from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
hidden_states, topk_output = dispatch_output
topk_weights, topk_ids, _ = topk_output
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
hidden_states_ref = hidden_states
topk_weights, topk_ids = topk_weights, topk_ids
# PreReorder
masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = (
moe_ep_deepgemm_preprocess(
topk_ids,
runner_config.num_local_experts,
hidden_states,
runner_config.top_k,
quant_info.block_shape,
)
)
dispose_tensor(hidden_states_ref)
running_state["topk_ids"] = topk_ids
running_state["topk_weights"] = topk_weights
running_state["hidden_states_shape"] = hidden_states_shape
running_state["hidden_states_dtype"] = hidden_states_dtype
running_state["hidden_states_device"] = hidden_states_device
running_state["src2dst"] = src2dst
return DeepGemmRunnerInput(
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
masked_m=masked_m,
expected_m=expected_m,
use_masked_gemm=True,
)
@register_post_permute("deep_gemm", "standard")
def post_permute_deep_gemm_to_standard(
runner_output: DeepGemmRunnerOutput,
quant_info: DeepGemmMoeQuantInfo,
runner_config: MoeRunnerConfig,
running_state: dict,
) -> StandardCombineInput:
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
hidden_states_shape = running_state["hidden_states_shape"]
hidden_states_dtype = running_state["hidden_states_dtype"]
hidden_states_device = running_state["hidden_states_device"]
src2dst = running_state["src2dst"]
topk_ids = running_state["topk_ids"]
topk_weights = running_state["topk_weights"]
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
runner_output.hidden_states,
output,
src2dst,
topk_ids,
topk_weights,
runner_config.top_k,
hidden_states_shape[1],
BLOCK_SIZE=512,
)
dispose_tensor(runner_output.hidden_states)
if runner_config.routed_scaling_factor is not None:
output *= runner_config.routed_scaling_factor
return StandardCombineInput(
hidden_states=output,
)

View File

@@ -9,6 +9,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
MoeRunnerConfig,
PermuteMethodPool,
)
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
from sglang.srt.layers.moe.utils import get_moe_a2a_backend
@@ -30,6 +31,8 @@ class MoeRunner:
if runner_backend.is_triton():
self.runner_core = TritonRunnerCore(config)
elif runner_backend.is_deep_gemm():
self.runner_core = DeepGemmRunnerCore(config)
else:
raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")

View File

@@ -44,6 +44,7 @@ class MoeA2ABackend(Enum):
class MoeRunnerBackend(Enum):
AUTO = "auto"
DEEP_GEMM = "deep_gemm"
TRITON = "triton"
TRITON_KERNEL = "triton_kernel"
FLASHINFER_TRTLLM = "flashinfer_trtllm"
@@ -54,6 +55,9 @@ class MoeRunnerBackend(Enum):
def is_auto(self):
return self == MoeRunnerBackend.AUTO
def is_deep_gemm(self):
return self == MoeRunnerBackend.DEEP_GEMM
def is_triton(self):
return self == MoeRunnerBackend.TRITON
@@ -147,7 +151,9 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
def get_moe_runner_backend() -> MoeRunnerBackend:
global MOE_RUNNER_BACKEND
if MOE_RUNNER_BACKEND is None:
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
logger.warning(
"MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected"
)
MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
return MOE_RUNNER_BACKEND

View File

@@ -31,8 +31,8 @@ except ImportError:
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
from sglang.srt.layers.parameter import (
BlockQuantScaleParameter,
ModelWeightParameter,
@@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
from sglang.srt.layers.moe.utils import (
get_moe_a2a_backend,
get_moe_runner_backend,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
moe_runner_backend = get_moe_runner_backend()
if moe_runner_backend.is_auto():
if (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and get_moe_a2a_backend().is_deepep()
):
moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
else:
moe_runner_backend = MoeRunnerBackend.TRITON
if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton():
self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
else:
# TODO(cwan): refactor other backends
pass
def apply(
self,
@@ -1087,22 +1108,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
return StandardCombineInput(hidden_states=output)
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
w13_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
if self.runner.runner_backend.is_deep_gemm():
w13_weight = layer.w13_weight
w2_weight = layer.w2_weight
if self.block_quant:
block_shape = self.quant_config.weight_block_size
w13_scale = layer.w13_weight_scale_inv
w2_scale = layer.w2_weight_scale_inv
else:
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
scale_block_size = 128
block_shape = [scale_block_size, scale_block_size]
w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1
w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1
w13_scale = (
layer.w13_weight_scale.unsqueeze(1)
.repeat_interleave(w13_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w13_scale_k, dim=2)
)
w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1
w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1
w2_scale = (
layer.w2_weight_scale.unsqueeze(1)
.repeat_interleave(w2_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w2_scale_k, dim=2)
)
quant_info = DeepGemmMoeQuantInfo(
w13_weight=w13_weight,
w2_weight=w2_weight,
use_fp8=True,
w13_scale=w13_scale,
w2_scale=w2_scale,
block_shape=block_shape,
)
elif self.runner.runner_backend.is_triton():
quant_info = TritonMoeQuantInfo(
w13_weight=layer.w13_weight,
w2_weight=layer.w2_weight,
use_fp8_w8a8=True,
w13_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv
if self.block_quant
else layer.w2_weight_scale
),
a13_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
else:
raise NotImplementedError(
"Unsupported runner backend: %s" % self.runner.runner_backend
)
return self.runner.run(dispatch_output, quant_info)
def apply_with_router_logits(

View File

@@ -21,7 +21,6 @@ from sglang.srt.utils import is_npu, set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
@@ -94,9 +93,7 @@ class W4AFp8Config(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.managers.schedule_batch import global_server_args_dict
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
@@ -133,7 +130,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def create_weights(
self,
layer: EPMoE,
layer: Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
@@ -292,7 +289,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: EPMoE,
layer: Module,
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
@@ -303,18 +300,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output
local_topk_ids = topk_ids
if get_moe_expert_parallel_world_size() > 1:
local_topk_ids = torch.where(
topk_ids == -1,
layer.num_experts,
topk_ids,
)
output = cutlass_w4a8_moe(
layer.start_expert_id,
layer.end_expert_id,
layer.num_experts,
x,
layer.w13_weight,
layer.w2_weight,
@@ -322,7 +309,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale_inv,
topk_weights,
topk_ids,
local_topk_ids,
self.a_strides1,
self.b_strides1,
self.c_strides1,

View File

@@ -49,7 +49,6 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.router import fused_moe_router_shim
from sglang.srt.layers.moe.topk import TopK
@@ -176,17 +175,7 @@ class Grok1MoE(nn.Module):
custom_routing_function=custom_routing_function,
)
kwargs = {}
if get_moe_expert_parallel_world_size() > 1:
MoEImpl = EPMoE
else:
MoEImpl = FusedMoE
kwargs["reduce_results"] = reduce_results
kwargs["use_presharded_weights"] = use_presharded_weights
kwargs["inplace"] = inplace
kwargs["no_combine"] = no_combine
self.experts = MoEImpl(
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
layer_id=layer_id,
@@ -195,7 +184,10 @@ class Grok1MoE(nn.Module):
params_dtype=params_dtype,
quant_config=quant_config,
activation="gelu",
**kwargs,
reduce_results=reduce_results,
use_presharded_weights=use_presharded_weights,
inplace=inplace,
no_combine=no_combine,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

View File

@@ -36,7 +36,6 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -94,8 +93,7 @@ class MixtralMoE(nn.Module):
renormalize=True,
)
MoEImpl = EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE
self.experts = MoEImpl(
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
layer_id=layer_id,

View File

@@ -121,6 +121,17 @@ NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
MOE_RUNNER_BACKEND_CHOICES = [
"auto",
"deep_gemm",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
"flashinfer_cutedsl",
]
# Allow external code to add more choices
def add_load_format_choices(choices):
@@ -143,6 +154,10 @@ def add_grammar_backend_choices(choices):
GRAMMAR_BACKEND_CHOICES.extend(choices)
def add_moe_runner_backend_choices(choices):
MOE_RUNNER_BACKEND_CHOICES.extend(choices)
def add_deterministic_attention_backend_choices(choices):
DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices)
@@ -315,14 +330,7 @@ class ServerArgs:
# Expert parallelism
ep_size: int = 1
moe_a2a_backend: Literal["none", "deepep"] = "none"
moe_runner_backend: Literal[
"auto",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
] = "auto"
moe_runner_backend: str = "auto"
flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
enable_flashinfer_allreduce_fusion: bool = False
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
@@ -2191,15 +2199,7 @@ class ServerArgs:
parser.add_argument(
"--moe-runner-backend",
type=str,
choices=[
"auto",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
"flashinfer_cutedsl",
],
choices=MOE_RUNNER_BACKEND_CHOICES,
default=ServerArgs.moe_runner_backend,
help="Choose the runner backend for MoE.",
)

View File

@@ -1,358 +0,0 @@
import itertools
import random
import unittest
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from sglang.srt.layers.moe.ep_moe.kernels import (
grouped_gemm_triton,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
run_moe_ep_preproess,
silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.test.test_utils import CustomTestCase
# For test
def ep_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
topk_config: TopKConfig,
# ep config
num_experts: int = 256,
fp8_dtype: torch.types = torch.float8_e4m3fn,
num_experts_per_partition: int = 128,
start_expert_id: int = 0,
end_expert_id: int = 127,
use_fp8_w8a8: bool = False,
w1_scale_inv: Optional[torch.Tensor] = None,
w2_scale_inv: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
):
use_blockwise_fp8 = block_shape is not None
top_k = topk_config.top_k
topk_output = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
topk_config=topk_config,
)
topk_weights, topk_ids, _ = topk_output
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
gateup_input = torch.empty(
(int(hidden_states.shape[0] * top_k), hidden_states.shape[1]),
device=hidden_states.device,
dtype=(
fp8_dtype
if (use_fp8_w8a8 and not use_blockwise_fp8)
else hidden_states.dtype
),
)
if use_fp8_w8a8 and not use_blockwise_fp8:
max_value = (
torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32)
)
w1_input_scale = max_value / torch.finfo(fp8_dtype).max
else:
w1_input_scale = None
# PreReorder
pre_reorder_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
gateup_input,
src2dst,
topk_ids,
w1_input_scale,
start_expert_id,
end_expert_id,
top_k,
hidden_states.shape[1],
BLOCK_SIZE=512,
use_per_token_if_dynamic=True,
)
seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2]
weight_indices_cur_rank = torch.arange(
0,
num_experts_per_partition,
device=hidden_states.device,
dtype=torch.int64,
)
# GroupGemm-0
gateup_output = torch.empty(
gateup_input.shape[0],
w1.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
gateup_output = grouped_gemm_triton(
a=gateup_input,
b=w1,
c=gateup_output,
batch_size=num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=use_fp8_w8a8,
scale_a=w1_input_scale,
scale_b=w1_scale_inv,
block_shape=block_shape,
)
# Act
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
fp8_dtype
if (use_fp8_w8a8 and not use_blockwise_fp8)
else hidden_states.dtype
),
)
if use_fp8_w8a8 and not use_blockwise_fp8:
w2_input_scale = torch.ones(
num_experts_per_partition,
dtype=torch.float32,
device=hidden_states.device,
)
else:
w2_input_scale = None
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
w2_input_scale,
start_expert_id,
end_expert_id,
BLOCK_SIZE=512,
)
# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],
w2.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
down_output = grouped_gemm_triton(
a=down_input,
b=w2,
c=down_output,
batch_size=num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=use_fp8_w8a8,
scale_a=w2_input_scale,
scale_b=w2_scale_inv,
block_shape=block_shape,
)
# PostReorder
output = torch.empty_like(hidden_states)
post_reorder_triton_kernel[(hidden_states.size(0),)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
start_expert_id,
end_expert_id,
top_k,
hidden_states.size(1),
0,
BLOCK_SIZE=512,
)
return output
# test util
def block_dequant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function converts block-wise quantization to tensor-wise quantization.
The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
and the block size.
The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
Note only float8 is supported for now.
"""
# process 3D tensor
if x_q_block.dim() == 3:
batch_size = x_q_block.size(0)
return torch.stack(
[block_dequant(x_q_block[b], x_s[b], block_size) for b in range(batch_size)]
)
block_n, block_k = block_size[0], block_size[1]
n, k = x_q_block.shape
n_tiles = (n + block_n - 1) // block_n
k_tiles = (k + block_k - 1) // block_k
assert n_tiles == x_s.shape[0]
assert k_tiles == x_s.shape[1]
x_dq_block = x_q_block.to(torch.float32)
x_dq_block_tiles = [
[
x_dq_block[
j * block_n : min((j + 1) * block_n, n),
i * block_k : min((i + 1) * block_k, k),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
for i in range(k_tiles):
for j in range(n_tiles):
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
return x_dq_block
class TestW8A8BlockFP8EPMoE(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
M = [1, 222, 1024, 2048]
N = [128, 1024, 2048]
K = [256, 4096, 5120]
E = [8, 16]
ep_size = [2, 4]
TOP_KS = [2, 4]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _w8a8_block_fp8_ep_moe(
self, M, N, K, E, ep_size, topk, block_size, dtype, seed
):
torch.manual_seed(seed)
random.seed(seed)
# NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a = torch.randn((M, K), dtype=dtype) / 10
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max
w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
w2_fp32 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max
w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = (2 * N + block_n - 1) // block_n
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k
w1_s = (
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
* factor_for_scale
)
w2_s = (
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
* factor_for_scale
)
w1_ref = block_dequant(w1, w1_s, block_size).to(dtype)
w2_ref = block_dequant(w2, w2_s, block_size).to(dtype)
score = torch.randn((M, E), dtype=dtype)
num_experts_per_partition = E // ep_size
cur_rank = random.randint(0, ep_size - 1)
start_id = cur_rank * num_experts_per_partition
end_id = start_id + num_experts_per_partition - 1
topk_config = TopKConfig(
top_k=topk,
renormalize=False,
)
with torch.inference_mode():
out = ep_moe(
hidden_states=a,
w1=w1,
w2=w2,
router_logits=score,
topk_config=topk_config,
use_fp8_w8a8=True,
w1_scale_inv=w1_s,
w2_scale_inv=w2_s,
block_shape=block_size,
num_experts=E,
num_experts_per_partition=num_experts_per_partition,
start_expert_id=start_id,
end_expert_id=end_id,
)
ref_out = ep_moe(
hidden_states=a,
w1=w1_ref,
w2=w2_ref,
router_logits=score,
topk_config=topk_config,
use_fp8_w8a8=False,
w1_scale_inv=None,
w2_scale_inv=None,
block_shape=None,
num_experts=E,
num_experts_per_partition=num_experts_per_partition,
start_expert_id=start_id,
end_expert_id=end_id,
)
self.assertTrue(
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
/ (torch.mean(torch.abs(ref_out.to(torch.float32))) + 1e-6)
< 0.06
)
def test_w8a8_block_fp8_ep_moe(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.E,
self.ep_size,
self.TOP_KS,
self.BLOCK_SIZE,
self.DTYPES,
self.SEEDS,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
E=params[3],
ep_size=params[4],
topk=params[5],
block_size=params[6],
dtype=params[7],
seed=params[8],
):
self._w8a8_block_fp8_ep_moe(*params)
torch.cuda.empty_cache()
if __name__ == "__main__":
unittest.main(verbosity=2)

View File

@@ -120,7 +120,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
)
topk_weights, topk_ids, _ = topk_output
expert_map = torch.arange(E, dtype=torch.int32, device=device)
expert_map[local_e:] = E
expert_map[local_e:] = -1
output = cutlass_moe(
a,
@@ -138,9 +138,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
c_strides2,
s_strides13,
s_strides2,
0,
local_e - 1,
E,
local_e,
a1_scale,
a2_scale,
expert_map,
@@ -178,7 +176,7 @@ def cutlass_moe(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids_: torch.Tensor,
topk_ids: torch.Tensor,
a_strides1: torch.Tensor,
b_strides1: torch.Tensor,
c_strides1: torch.Tensor,
@@ -187,40 +185,32 @@ def cutlass_moe(
c_strides2: torch.Tensor,
s_strides13: torch.Tensor,
s_strides2: torch.Tensor,
start_expert_id: int,
end_expert_id: int,
E: int,
num_local_experts: int,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
):
local_topk_ids = topk_ids_
local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E)
topk_ids = expert_map[topk_ids]
device = a.device
local_num_experts = end_expert_id - start_expert_id + 1
expert_offsets = torch.empty(
(local_num_experts + 1), dtype=torch.int32, device=device
(num_local_experts + 1), dtype=torch.int32, device=device
)
problem_sizes1 = torch.empty(
(local_num_experts, 3), dtype=torch.int32, device=device
(num_local_experts, 3), dtype=torch.int32, device=device
)
problem_sizes2 = torch.empty(
(local_num_experts, 3), dtype=torch.int32, device=device
(num_local_experts, 3), dtype=torch.int32, device=device
)
return cutlass_w4a8_moe(
start_expert_id,
end_expert_id,
E,
a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids_,
local_topk_ids,
topk_ids,
a_strides1,
b_strides1,
c_strides1,