[8/N] MoE Refactor: deprecate EPMoE (#11211)
This commit is contained in:
@@ -13,22 +13,18 @@ from sgl_kernel import (
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
post_reorder_triton_kernel_for_cutlass_moe,
|
||||
pre_reorder_triton_kernel_for_cutlass_moe,
|
||||
run_cutlass_moe_ep_preproess,
|
||||
run_moe_ep_preproess,
|
||||
)
|
||||
|
||||
|
||||
def cutlass_w4a8_moe(
|
||||
start_expert_id: int,
|
||||
end_expert_id: int,
|
||||
total_num_experts: int,
|
||||
a: torch.Tensor,
|
||||
w1_q: torch.Tensor,
|
||||
w2_q: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids_: torch.Tensor,
|
||||
local_topk_ids: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
a_strides1: torch.Tensor,
|
||||
b_strides1: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
@@ -64,6 +60,7 @@ def cutlass_w4a8_moe(
|
||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
||||
Shape: [num_experts, N // 512, K * 4]
|
||||
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
||||
- topk_ids (torch.Tensor): The ids of each token->expert mapping.
|
||||
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
|
||||
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
|
||||
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
|
||||
@@ -83,7 +80,7 @@ def cutlass_w4a8_moe(
|
||||
Returns:
|
||||
- torch.Tensor: The fp8 output tensor after applying the MoE layer.
|
||||
"""
|
||||
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert w1_q.dtype == torch.int8
|
||||
assert w2_q.dtype == torch.int8
|
||||
assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
|
||||
@@ -96,20 +93,21 @@ def cutlass_w4a8_moe(
|
||||
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
|
||||
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
|
||||
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
|
||||
num_experts = w1_q.size(0)
|
||||
num_local_experts = w1_q.size(0)
|
||||
m = a.size(0)
|
||||
k = w1_q.size(2) * 2 # w1_q is transposed and packed
|
||||
n = w2_q.size(2) * 2 # w2_q is transposed and packed
|
||||
topk = topk_ids_.size(1)
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
|
||||
|
||||
device = a.device
|
||||
topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
|
||||
|
||||
_, src2dst, _ = run_cutlass_moe_ep_preproess(
|
||||
local_topk_ids,
|
||||
num_experts,
|
||||
_, src2dst, _ = run_moe_ep_preproess(
|
||||
topk_ids,
|
||||
num_local_experts,
|
||||
)
|
||||
|
||||
gateup_input = torch.empty(
|
||||
@@ -122,9 +120,9 @@ def cutlass_w4a8_moe(
|
||||
a,
|
||||
gateup_input,
|
||||
src2dst,
|
||||
local_topk_ids,
|
||||
topk_ids,
|
||||
a1_scale,
|
||||
total_num_experts,
|
||||
num_local_experts,
|
||||
topk,
|
||||
k,
|
||||
BLOCK_SIZE=512,
|
||||
@@ -133,16 +131,16 @@ def cutlass_w4a8_moe(
|
||||
# NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
|
||||
# they are kept to allow for a quick switch of the permutation logic
|
||||
# from the current triton kernel implementation to the cutlass-based one if needed.
|
||||
a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
get_cutlass_w4a8_moe_mm_data(
|
||||
local_topk_ids,
|
||||
topk_ids,
|
||||
expert_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
a_map,
|
||||
c_map,
|
||||
num_experts,
|
||||
num_local_experts,
|
||||
n,
|
||||
k,
|
||||
)
|
||||
@@ -195,12 +193,11 @@ def cutlass_w4a8_moe(
|
||||
c2,
|
||||
output,
|
||||
src2dst,
|
||||
local_topk_ids,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
num_experts,
|
||||
topk,
|
||||
num_local_experts,
|
||||
k,
|
||||
0,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -130,28 +130,30 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
|
||||
|
||||
@triton.jit
|
||||
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
|
||||
expert = tl.program_id(0)
|
||||
expert_id_minus_1 = tl.program_id(0) - 1
|
||||
low = 0
|
||||
high = num_toks - 1
|
||||
target_location = -1
|
||||
while low <= high:
|
||||
mid = (low + high) // 2
|
||||
|
||||
if tl.load(reorder_topk_ids + mid) > expert:
|
||||
if tl.load(reorder_topk_ids + mid) > expert_id_minus_1:
|
||||
high = mid - 1
|
||||
else:
|
||||
low = mid + 1
|
||||
target_location = mid
|
||||
tl.store(seg_indptr + expert + 1, target_location + 1)
|
||||
tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1)
|
||||
|
||||
|
||||
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
|
||||
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_local_experts: int):
|
||||
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
|
||||
|
||||
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
|
||||
seg_indptr = torch.zeros(
|
||||
num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
|
||||
)
|
||||
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
|
||||
|
||||
compute_seg_indptr_triton_kernel[(num_experts,)](
|
||||
compute_seg_indptr_triton_kernel[(num_local_experts,)](
|
||||
reorder_topk_ids, seg_indptr, topk_ids.numel()
|
||||
)
|
||||
|
||||
@@ -164,25 +166,6 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
|
||||
return reorder_topk_ids, src2dst, seg_indptr
|
||||
|
||||
|
||||
def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
|
||||
reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
|
||||
|
||||
seg_indptr = torch.zeros(
|
||||
local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
|
||||
)
|
||||
src2dst = torch.empty(
|
||||
local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
|
||||
)
|
||||
|
||||
BLOCK_SIZE = 512
|
||||
grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
|
||||
compute_src2dst_triton_kernel[grid](
|
||||
reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
|
||||
)
|
||||
|
||||
return reorder_topk_ids, src2dst, seg_indptr
|
||||
|
||||
|
||||
@triton.jit
|
||||
def pre_reorder_triton_kernel_for_cutlass_moe(
|
||||
input_ptr,
|
||||
@@ -190,52 +173,13 @@ def pre_reorder_triton_kernel_for_cutlass_moe(
|
||||
src2dst_ptr,
|
||||
topk_ids_ptr,
|
||||
a1_scales_ptr,
|
||||
num_experts,
|
||||
num_local_experts,
|
||||
topk,
|
||||
hidden_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
OutDtype = gateup_input_ptr.dtype.element_ty
|
||||
|
||||
src_idx = tl.program_id(0)
|
||||
src2dst_ptr = src2dst_ptr + src_idx * topk
|
||||
topk_ids_ptr = topk_ids_ptr + src_idx * topk
|
||||
|
||||
src_ptr = input_ptr + src_idx * hidden_size
|
||||
for idx in range(topk):
|
||||
expert_id = tl.load(topk_ids_ptr + idx)
|
||||
if expert_id != num_experts:
|
||||
if a1_scales_ptr is not None:
|
||||
scale = 1.0 / tl.load(a1_scales_ptr)
|
||||
else:
|
||||
scale = 1.0
|
||||
|
||||
dst_idx = tl.load(src2dst_ptr + idx)
|
||||
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
|
||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offset < hidden_size
|
||||
in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
|
||||
out_data = (in_data * scale).to(OutDtype)
|
||||
tl.store(dst_ptr + offset, out_data, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def pre_reorder_triton_kernel(
|
||||
input_ptr,
|
||||
gateup_input_ptr,
|
||||
src2dst_ptr,
|
||||
topk_ids_ptr,
|
||||
a1_scales_ptr,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
hidden_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
use_per_token_if_dynamic: tl.constexpr,
|
||||
):
|
||||
OutDtype = gateup_input_ptr.dtype.element_ty
|
||||
|
||||
src_idx_int32 = tl.program_id(0)
|
||||
src_idx = src_idx_int32.to(tl.int64)
|
||||
src2dst_ptr = src2dst_ptr + src_idx * topk
|
||||
@@ -244,15 +188,11 @@ def pre_reorder_triton_kernel(
|
||||
|
||||
vec = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
if a1_scales_ptr is not None and use_per_token_if_dynamic:
|
||||
scale = 1.0 / tl.load(a1_scales_ptr + src_idx)
|
||||
|
||||
for idx in range(topk):
|
||||
expert_id = tl.load(topk_ids_ptr + idx)
|
||||
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
||||
if expert_id != num_local_experts:
|
||||
if a1_scales_ptr is not None:
|
||||
if not use_per_token_if_dynamic:
|
||||
scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
|
||||
scale = 1.0 / tl.load(a1_scales_ptr)
|
||||
else:
|
||||
scale = 1.0
|
||||
|
||||
@@ -267,52 +207,6 @@ def pre_reorder_triton_kernel(
|
||||
tl.store(dst_ptr + offset, out_data, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def silu_and_mul_triton_kernel(
|
||||
gateup_output,
|
||||
down_input,
|
||||
hidden_size,
|
||||
reorder_topk_ids,
|
||||
scales,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
InDtype = gateup_output.dtype.element_ty
|
||||
OutDtype = down_input.dtype.element_ty
|
||||
|
||||
half_hidden_size = hidden_size // 2
|
||||
|
||||
pid = tl.program_id(0)
|
||||
expert_id = tl.load(reorder_topk_ids + pid)
|
||||
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
||||
gateup_output_ptr = gateup_output + pid * hidden_size
|
||||
gate_output_ptr = gateup_output_ptr
|
||||
up_output_ptr = gateup_output_ptr + half_hidden_size
|
||||
down_input_ptr = down_input + pid * half_hidden_size
|
||||
|
||||
if scales is not None:
|
||||
scale = tl.load(scales + expert_id - start_expert_id)
|
||||
scale = (1 / scale).to(InDtype)
|
||||
else:
|
||||
scale = 1
|
||||
|
||||
for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
|
||||
offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offset < half_hidden_size
|
||||
|
||||
gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
|
||||
up_output = tl.load(up_output_ptr + offset, mask=mask)
|
||||
|
||||
# silu & mul & quantize
|
||||
gate_output = gate_output * tl.sigmoid(gate_output)
|
||||
gate_output = gate_output.to(InDtype)
|
||||
|
||||
silu_mul_output = gate_output * up_output * scale
|
||||
silu_mul_output = silu_mul_output.to(OutDtype)
|
||||
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
|
||||
|
||||
|
||||
# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
|
||||
@triton.jit
|
||||
def _silu_and_mul_post_quant_kernel(
|
||||
@@ -461,70 +355,44 @@ def silu_and_mul_masked_post_quant_fwd(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def tanh(x):
|
||||
return 2 * tl.sigmoid(2 * x) - 1
|
||||
|
||||
|
||||
@triton.jit
|
||||
def gelu_and_mul_triton_kernel(
|
||||
gateup_output,
|
||||
down_input,
|
||||
def post_reorder_triton_kernel_for_cutlass_moe(
|
||||
down_output_ptr,
|
||||
output_ptr,
|
||||
src2dst_ptr,
|
||||
topk_ids_ptr,
|
||||
topk_weights_ptr,
|
||||
topk,
|
||||
num_local_experts,
|
||||
hidden_size,
|
||||
reorder_topk_ids,
|
||||
scales,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
InDtype = gateup_output.dtype.element_ty
|
||||
OutDtype = down_input.dtype.element_ty
|
||||
InDtype = down_output_ptr.dtype.element_ty
|
||||
|
||||
half_hidden_size = hidden_size // 2
|
||||
src_idx_int32 = tl.program_id(0)
|
||||
src_idx = src_idx_int32.to(tl.int64)
|
||||
src2dst_ptr = src2dst_ptr + src_idx * topk
|
||||
topk_ids_ptr = topk_ids_ptr + src_idx * topk
|
||||
topk_weights_ptr = topk_weights_ptr + src_idx * topk
|
||||
|
||||
pid = tl.program_id(0)
|
||||
expert_id = tl.load(reorder_topk_ids + pid)
|
||||
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
||||
gateup_output_ptr = gateup_output + pid * hidden_size
|
||||
gate_output_ptr = gateup_output_ptr
|
||||
up_output_ptr = gateup_output_ptr + half_hidden_size
|
||||
down_input_ptr = down_input + pid * half_hidden_size
|
||||
store_ptr = output_ptr + src_idx * hidden_size
|
||||
|
||||
if scales is not None:
|
||||
scale = tl.load(scales + expert_id - start_expert_id)
|
||||
scale = (1 / scale).to(InDtype)
|
||||
else:
|
||||
scale = 1
|
||||
vec = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
|
||||
offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offset < half_hidden_size
|
||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
offset = start_offset + vec
|
||||
mask = offset < hidden_size
|
||||
|
||||
gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
|
||||
up_output = tl.load(up_output_ptr + offset, mask=mask)
|
||||
|
||||
# gelu & mul & quantize
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
|
||||
# sqrt(2/pi)
|
||||
kAlpha = 0.7978845608028654
|
||||
gate_output = (
|
||||
0.5
|
||||
* gate_output
|
||||
* (
|
||||
1
|
||||
+ tanh(
|
||||
kAlpha
|
||||
* (
|
||||
gate_output
|
||||
+ 0.044715 * gate_output * gate_output * gate_output
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
gate_output = gate_output.to(InDtype)
|
||||
|
||||
gelu_mul_output = gate_output * up_output * scale
|
||||
gelu_mul_output = gelu_mul_output.to(OutDtype)
|
||||
tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
|
||||
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
|
||||
for idx in range(topk):
|
||||
expert_id = tl.load(topk_ids_ptr + idx)
|
||||
if expert_id != num_local_experts:
|
||||
dst_idx_int32 = tl.load(src2dst_ptr + idx)
|
||||
dst_idx = dst_idx_int32.to(tl.int64)
|
||||
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
|
||||
load_ptr = down_output_ptr + dst_idx * hidden_size
|
||||
in_data = tl.load(load_ptr + offset, mask=mask)
|
||||
sum_vec += in_data * weigh_scale
|
||||
tl.store(store_ptr + offset, sum_vec, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -534,64 +402,8 @@ def post_reorder_triton_kernel(
|
||||
src2dst_ptr,
|
||||
topk_ids_ptr,
|
||||
topk_weights_ptr,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
hidden_size,
|
||||
dst_start,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
InDtype = down_output_ptr.dtype.element_ty
|
||||
|
||||
src_idx_int32 = tl.program_id(0)
|
||||
src_idx = src_idx_int32.to(tl.int64)
|
||||
src2dst_ptr = src2dst_ptr + src_idx * topk
|
||||
topk_ids_ptr = topk_ids_ptr + src_idx * topk
|
||||
topk_weights_ptr = topk_weights_ptr + src_idx * topk
|
||||
|
||||
computed = False
|
||||
store_ptr = output_ptr + src_idx * hidden_size
|
||||
|
||||
vec = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
offset = start_offset + vec
|
||||
mask = offset < hidden_size
|
||||
|
||||
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
|
||||
for idx in range(topk):
|
||||
expert_id = tl.load(topk_ids_ptr + idx)
|
||||
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
||||
computed = True
|
||||
dst_idx_int32 = tl.load(src2dst_ptr + idx)
|
||||
dst_idx = dst_idx_int32.to(tl.int64)
|
||||
dst_idx = dst_idx - dst_start
|
||||
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
|
||||
load_ptr = down_output_ptr + dst_idx * hidden_size
|
||||
in_data = tl.load(load_ptr + offset, mask=mask)
|
||||
sum_vec += in_data * weigh_scale
|
||||
tl.store(store_ptr + offset, sum_vec, mask=mask)
|
||||
|
||||
if computed == False:
|
||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
offset = start_offset + vec
|
||||
mask = offset < hidden_size
|
||||
tl.store(
|
||||
store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def post_reorder_triton_kernel_for_cutlass_moe(
|
||||
down_output_ptr,
|
||||
output_ptr,
|
||||
src2dst_ptr,
|
||||
topk_ids_ptr,
|
||||
topk_weights_ptr,
|
||||
num_experts,
|
||||
topk,
|
||||
hidden_size,
|
||||
dst_start,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
InDtype = down_output_ptr.dtype.element_ty
|
||||
@@ -613,10 +425,9 @@ def post_reorder_triton_kernel_for_cutlass_moe(
|
||||
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
|
||||
for idx in range(topk):
|
||||
expert_id = tl.load(topk_ids_ptr + idx)
|
||||
if expert_id != num_experts:
|
||||
if expert_id > 0:
|
||||
dst_idx_int32 = tl.load(src2dst_ptr + idx)
|
||||
dst_idx = dst_idx_int32.to(tl.int64)
|
||||
dst_idx = dst_idx - dst_start
|
||||
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
|
||||
load_ptr = down_output_ptr + dst_idx * hidden_size
|
||||
in_data = tl.load(load_ptr + offset, mask=mask)
|
||||
@@ -624,232 +435,6 @@ def post_reorder_triton_kernel_for_cutlass_moe(
|
||||
tl.store(store_ptr + offset, sum_vec, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def compute_m_range(
|
||||
pid,
|
||||
batch_size,
|
||||
seg_indptr,
|
||||
weight_indices,
|
||||
m_num_tiles_indptr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
):
|
||||
idx = 0
|
||||
for bs in range(batch_size):
|
||||
tiles = tl.load(m_num_tiles_indptr + bs)
|
||||
if pid >= tiles:
|
||||
idx = bs
|
||||
|
||||
idx_start = tl.load(m_num_tiles_indptr + idx)
|
||||
|
||||
m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
|
||||
m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
|
||||
expert_id = tl.load(weight_indices + idx)
|
||||
return m_range_start, m_range_end, expert_id
|
||||
|
||||
|
||||
@triton.jit
|
||||
def grouped_gemm_triton_kernel(
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
batch_size,
|
||||
N,
|
||||
K,
|
||||
seg_indptr,
|
||||
weight_indices,
|
||||
m_num_tiles_indptr,
|
||||
scale_a,
|
||||
scale_b,
|
||||
use_fp8_w8a8: tl.constexpr,
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
a_stride_0: tl.constexpr,
|
||||
b_stride_0: tl.constexpr,
|
||||
b_stride_1: tl.constexpr,
|
||||
as_stride_0: tl.constexpr,
|
||||
as_stride_1: tl.constexpr,
|
||||
bs_stride_0: tl.constexpr,
|
||||
bs_stride_2: tl.constexpr,
|
||||
bs_stride_1: tl.constexpr,
|
||||
use_per_token_if_dynamic: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
c_dtype = c.dtype.element_ty
|
||||
|
||||
pid_m = tl.program_id(0)
|
||||
pid_n = tl.program_id(1)
|
||||
total_m_block = tl.load(m_num_tiles_indptr + batch_size)
|
||||
if pid_m >= total_m_block:
|
||||
return
|
||||
|
||||
m_range_start, m_range_end, expert_id = compute_m_range(
|
||||
pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
|
||||
)
|
||||
if m_range_end - m_range_start == 0:
|
||||
return
|
||||
|
||||
n_range_start = pid_n * BLOCK_SIZE_N
|
||||
n_range_end = min(n_range_start + BLOCK_SIZE_N, N)
|
||||
|
||||
offs_am = tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
|
||||
offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
|
||||
offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
|
||||
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
|
||||
a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
|
||||
b_ptr = b + (
|
||||
(expert_id * b_stride_0)
|
||||
+ (n_range_start + offs_bn[:, None]) * b_stride_1
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
|
||||
if group_k > 0 and group_n > 0:
|
||||
a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
|
||||
offs_bsn = (n_range_start + offs_bn) // group_n
|
||||
b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a_tile = tl.load(
|
||||
a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
|
||||
)
|
||||
b_tile = tl.load(
|
||||
b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
|
||||
)
|
||||
|
||||
if group_k > 0 and group_n > 0:
|
||||
k_start = k * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
|
||||
b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
|
||||
accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
|
||||
else:
|
||||
accumulator = tl.dot(a_tile, b_tile.T, accumulator)
|
||||
a_ptr += BLOCK_SIZE_K
|
||||
b_ptr += BLOCK_SIZE_K
|
||||
|
||||
if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
|
||||
if use_per_token_if_dynamic:
|
||||
scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None]))
|
||||
else:
|
||||
scale_a_value = tl.load(scale_a + expert_id)
|
||||
scale_b_value = tl.load(scale_b + expert_id)
|
||||
accumulator *= scale_a_value * scale_b_value
|
||||
|
||||
c_tile = accumulator.to(c_dtype)
|
||||
|
||||
offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
|
||||
tl.store(c_ptr, c_tile, mask=c_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def compute_m_num_tiles_indptr(
|
||||
m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
|
||||
):
|
||||
for bs in range(batch_size):
|
||||
m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
|
||||
cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
|
||||
pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
|
||||
tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
|
||||
|
||||
|
||||
def grouped_gemm_triton(
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
batch_size: int,
|
||||
weight_column_major: bool,
|
||||
seg_indptr: Optional[torch.Tensor] = None,
|
||||
weight_indices: Optional[torch.Tensor] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
scale_a: torch.Tensor = None,
|
||||
scale_b: torch.Tensor = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
c_dtype=None,
|
||||
use_per_token_if_dynamic: bool = True,
|
||||
):
|
||||
assert weight_column_major == True # TODO: more
|
||||
if use_fp8_w8a8 and block_shape is None:
|
||||
assert scale_a is not None and scale_b is not None
|
||||
|
||||
if block_shape is not None:
|
||||
a_original = a
|
||||
|
||||
assert len(block_shape) == 2
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
a, scale_a = per_token_group_quant_fp8(a, block_k)
|
||||
|
||||
assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
|
||||
assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
|
||||
assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
|
||||
|
||||
dispose_tensor(a_original)
|
||||
|
||||
# TODO: adjust config or tune kernel
|
||||
# Reduce block size to prevent L40 shared memory overflow.
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
}
|
||||
|
||||
m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
|
||||
compute_m_num_tiles_indptr[(1,)](
|
||||
m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
|
||||
)
|
||||
|
||||
if c is None:
|
||||
assert c_dtype is not None
|
||||
c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
|
||||
triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
if use_fp8_w8a8 and block_shape is None and use_per_token_if_dynamic:
|
||||
assert (
|
||||
scale_a.shape[0] == a.shape[0]
|
||||
), f"scale_a.shape: {scale_a.shape}, a.shape: {a.shape}"
|
||||
|
||||
grouped_gemm_triton_kernel[grid](
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
batch_size,
|
||||
b.size(1),
|
||||
b.size(2),
|
||||
seg_indptr,
|
||||
weight_indices,
|
||||
m_num_tiles_indptr,
|
||||
scale_a,
|
||||
scale_b,
|
||||
use_fp8_w8a8,
|
||||
0 if block_shape is None else block_shape[0],
|
||||
0 if block_shape is None else block_shape[1],
|
||||
a.stride(0),
|
||||
b.stride(0),
|
||||
b.stride(1),
|
||||
scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0,
|
||||
scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0,
|
||||
scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
|
||||
scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
|
||||
scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
|
||||
use_per_token_if_dynamic,
|
||||
**config,
|
||||
)
|
||||
return c
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_ep_scatter_1(
|
||||
num_recv_tokens_per_expert,
|
||||
@@ -1234,7 +819,7 @@ def deepgemm_compute_src2dst_triton_kernel(
|
||||
mask = dst_id < num_toks
|
||||
src_id = tl.load(reorder_ids + dst_id, mask=mask)
|
||||
expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
|
||||
expert_dst_start = tl.load(seg_indptr + expert_id)
|
||||
expert_dst_start = tl.load(seg_indptr + expert_id, mask=(expert_id >= 0))
|
||||
expert_dst_offset = dst_id - expert_dst_start
|
||||
dst_id = expert_id * m_max + expert_dst_offset
|
||||
tl.store(src2dst + src_id, dst_id, mask=mask)
|
||||
@@ -1248,10 +833,7 @@ def fill_gateup_input_triton_kernel(
|
||||
gateup_input_scale_ptr,
|
||||
src2dst_ptr,
|
||||
topk_ids_ptr,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
m_max,
|
||||
hidden_size,
|
||||
scale_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
@@ -1267,10 +849,9 @@ def fill_gateup_input_triton_kernel(
|
||||
vec = tl.arange(0, BLOCK_SIZE)
|
||||
for idx in range(topk):
|
||||
expert_id = tl.load(topk_ids_ptr + idx)
|
||||
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
||||
if expert_id >= 0:
|
||||
dst_idx_int32 = tl.load(src2dst_ptr + idx)
|
||||
dst_idx = dst_idx_int32.to(tl.int64)
|
||||
dst_idx = dst_idx - start_expert_id * m_max
|
||||
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
|
||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
offset = start_offset + vec
|
||||
@@ -1287,31 +868,31 @@ def fill_gateup_input_triton_kernel(
|
||||
|
||||
def moe_ep_deepgemm_preprocess(
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
num_local_experts: int,
|
||||
hidden_states: torch.Tensor,
|
||||
top_k: int,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
block_shape,
|
||||
output_dtype: torch.dtype = torch.float8_e4m3fn,
|
||||
):
|
||||
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
|
||||
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
|
||||
seg_indptr = torch.zeros(
|
||||
num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
|
||||
)
|
||||
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
|
||||
masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
|
||||
masked_m = torch.empty(num_local_experts, device=topk_ids.device, dtype=torch.int32)
|
||||
|
||||
compute_seg_indptr_triton_kernel[(num_experts,)](
|
||||
compute_seg_indptr_triton_kernel[(num_local_experts + 1,)](
|
||||
reorder_topk_ids, seg_indptr, topk_ids.numel()
|
||||
)
|
||||
|
||||
grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
|
||||
compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
|
||||
compute_masked_m_triton_kernel[(num_local_experts,)](seg_indptr, masked_m)
|
||||
|
||||
# For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
|
||||
m_max = (hidden_states.size(0) + 255) // 256 * 256
|
||||
expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
|
||||
m_max = (hidden_states.size(0) // 256 + 1) * 256
|
||||
expected_m = (topk_ids.numel() - 1) // num_local_experts + 1
|
||||
gateup_input = torch.empty(
|
||||
(int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
|
||||
(num_local_experts, m_max, hidden_states.size(1)),
|
||||
device=hidden_states.device,
|
||||
dtype=output_dtype,
|
||||
)
|
||||
@@ -1330,6 +911,8 @@ def moe_ep_deepgemm_preprocess(
|
||||
block_shape = [128, 128]
|
||||
assert len(block_shape) == 2
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
|
||||
# TODO: fuse this with the preprocess
|
||||
hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
|
||||
|
||||
gateup_input_scale = torch.empty(
|
||||
@@ -1345,18 +928,14 @@ def moe_ep_deepgemm_preprocess(
|
||||
gateup_input_scale,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
top_k,
|
||||
m_max,
|
||||
hidden_states.size(1),
|
||||
scale.size(1),
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
return (
|
||||
m_max,
|
||||
masked_m[start_expert_id : (end_expert_id + 1)],
|
||||
masked_m,
|
||||
expected_m,
|
||||
src2dst,
|
||||
gateup_input,
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
||||
from sglang.srt.layers.moe import (
|
||||
get_deepep_mode,
|
||||
get_moe_a2a_backend,
|
||||
@@ -18,13 +14,10 @@ from sglang.srt.layers.moe import (
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
ep_gather,
|
||||
ep_scatter,
|
||||
moe_ep_deepgemm_preprocess,
|
||||
post_reorder_triton_kernel,
|
||||
silu_and_mul_masked_post_quant_fwd,
|
||||
tma_align_input_scale,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||
@@ -36,19 +29,10 @@ from sglang.srt.layers.quantization.modelopt_quant import (
|
||||
CUTEDSL_MOE_NVFP4_DISPATCH,
|
||||
ModelOptNvFp4FusedMoEMethod,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.offloader import get_offloader
|
||||
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
||||
from sglang.srt.utils import (
|
||||
ceil_div,
|
||||
dispose_tensor,
|
||||
get_bool_env_var,
|
||||
get_int_env_var,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
is_npu,
|
||||
)
|
||||
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
@@ -72,275 +56,7 @@ if _use_aiter:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO(kaixih@nvidia): ideally we should merge this logic into
|
||||
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
|
||||
@torch.compile
|
||||
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
|
||||
temp = x.to(torch.float32).view(torch.int32)
|
||||
exp = torch.bitwise_right_shift(temp, 23)
|
||||
mant = torch.bitwise_and(temp, 0x7FFFFF)
|
||||
is_ru = torch.logical_and(
|
||||
torch.logical_and((mant > 0), (exp != 0xFE)),
|
||||
~torch.logical_and((exp == 0), (mant <= 0x400000)),
|
||||
)
|
||||
exp = torch.where(is_ru, exp + 1, exp)
|
||||
new_x = exp.to(torch.uint8).view(torch.int)
|
||||
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
|
||||
|
||||
|
||||
class EPMoE(FusedMoE):
|
||||
"""
|
||||
MoE Expert Parallel Impl
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
layer_id: int,
|
||||
num_fused_shared_experts: int = 0,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
activation: str = "silu",
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_clamp_limit: Optional[float] = None,
|
||||
with_bias: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
num_experts=num_experts,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
layer_id=layer_id,
|
||||
top_k=top_k,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
activation=activation,
|
||||
# apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
gemm1_alpha=gemm1_alpha,
|
||||
gemm1_clamp_limit=gemm1_clamp_limit,
|
||||
with_bias=with_bias,
|
||||
)
|
||||
|
||||
self.intermediate_size = intermediate_size
|
||||
|
||||
if isinstance(quant_config, Fp8Config):
|
||||
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||
self.block_shape = (
|
||||
self.quant_method.quant_config.weight_block_size
|
||||
if self.use_block_quant
|
||||
else None
|
||||
)
|
||||
self.use_fp8_w8a8 = True
|
||||
self.fp8_dtype = torch.float8_e4m3fn
|
||||
self.activation_scheme = quant_config.activation_scheme
|
||||
else:
|
||||
self.use_fp8_w8a8 = False
|
||||
self.use_block_quant = False
|
||||
self.block_shape = None
|
||||
self.activation_scheme = None
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
||||
return self.forward_deepgemm(hidden_states, topk_output)
|
||||
else:
|
||||
return super().forward(hidden_states, topk_output)
|
||||
|
||||
def forward_deepgemm(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
):
|
||||
|
||||
self.w13_weight_fp8 = (
|
||||
self.w13_weight,
|
||||
(
|
||||
self.w13_weight_scale_inv
|
||||
if self.use_block_quant
|
||||
else self.w13_weight_scale
|
||||
),
|
||||
)
|
||||
self.w2_weight_fp8 = (
|
||||
self.w2_weight,
|
||||
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
|
||||
)
|
||||
|
||||
assert self.quant_method is not None
|
||||
assert self.moe_runner_config.activation == "silu"
|
||||
|
||||
hidden_states_shape = hidden_states.shape
|
||||
hidden_states_dtype = hidden_states.dtype
|
||||
hidden_states_device = hidden_states.device
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
|
||||
if not self.use_block_quant:
|
||||
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
|
||||
scale_block_size = 128
|
||||
w13_weight_scale_n = 2 * (
|
||||
(self.intermediate_size + scale_block_size - 1) // scale_block_size
|
||||
)
|
||||
w13_weight_scale_k = (
|
||||
hidden_states_shape[-1] + scale_block_size - 1
|
||||
) // scale_block_size
|
||||
w13_weight_scale = (
|
||||
self.w13_weight_scale.unsqueeze(1)
|
||||
.repeat_interleave(w13_weight_scale_n, dim=1)
|
||||
.unsqueeze(2)
|
||||
.repeat_interleave(w13_weight_scale_k, dim=2)
|
||||
)
|
||||
self.w13_weight_fp8 = (
|
||||
self.w13_weight,
|
||||
w13_weight_scale,
|
||||
)
|
||||
w2_weight_scale_n = (
|
||||
hidden_states_shape[-1] + scale_block_size - 1
|
||||
) // scale_block_size
|
||||
w2_weight_scale_k = (
|
||||
self.intermediate_size + scale_block_size - 1
|
||||
) // scale_block_size
|
||||
w2_weight_scale = (
|
||||
self.w2_weight_scale.unsqueeze(1)
|
||||
.repeat_interleave(w2_weight_scale_n, dim=1)
|
||||
.unsqueeze(2)
|
||||
.repeat_interleave(w2_weight_scale_k, dim=2)
|
||||
)
|
||||
self.w2_weight_fp8 = (
|
||||
self.w2_weight,
|
||||
w2_weight_scale,
|
||||
)
|
||||
|
||||
# PreReorder
|
||||
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
|
||||
moe_ep_deepgemm_preprocess(
|
||||
topk_ids,
|
||||
self.num_experts,
|
||||
hidden_states,
|
||||
self.top_k,
|
||||
self.start_expert_id,
|
||||
self.end_expert_id,
|
||||
self.block_shape,
|
||||
)
|
||||
)
|
||||
|
||||
dispose_tensor(hidden_states)
|
||||
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
|
||||
b, s_mn, s_k = gateup_input_scale.shape
|
||||
assert (
|
||||
s_mn % 4 == 0 and s_k % 4 == 0
|
||||
), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
|
||||
|
||||
# GroupGemm-0
|
||||
gateup_input_fp8 = (
|
||||
gateup_input,
|
||||
(
|
||||
_cast_to_e8m0_with_rounding_up(gateup_input_scale)
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
||||
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
|
||||
gateup_input_scale
|
||||
)
|
||||
),
|
||||
)
|
||||
num_groups, m, k = gateup_input_fp8[0].size()
|
||||
n = self.w13_weight.size(1)
|
||||
gateup_output = torch.empty(
|
||||
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
|
||||
)
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
||||
gateup_input_fp8,
|
||||
self.w13_weight_fp8,
|
||||
gateup_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
del gateup_input
|
||||
del gateup_input_fp8
|
||||
|
||||
# Act
|
||||
down_input = torch.empty(
|
||||
(
|
||||
gateup_output.shape[0],
|
||||
gateup_output.shape[1],
|
||||
gateup_output.shape[2] // 2,
|
||||
),
|
||||
device=hidden_states_device,
|
||||
dtype=self.fp8_dtype,
|
||||
)
|
||||
scale_block_size = 128
|
||||
down_input_scale = torch.empty(
|
||||
(
|
||||
gateup_output.shape[0],
|
||||
gateup_output.shape[1],
|
||||
gateup_output.shape[2] // 2 // scale_block_size,
|
||||
),
|
||||
device=hidden_states_device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
silu_and_mul_masked_post_quant_fwd(
|
||||
gateup_output,
|
||||
down_input,
|
||||
down_input_scale,
|
||||
scale_block_size,
|
||||
masked_m,
|
||||
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
|
||||
)
|
||||
del gateup_output
|
||||
|
||||
# GroupGemm-1
|
||||
n = self.w2_weight.size(1)
|
||||
down_input_fp8 = (
|
||||
down_input,
|
||||
(
|
||||
down_input_scale
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
||||
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
|
||||
),
|
||||
)
|
||||
down_output = torch.empty(
|
||||
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
|
||||
)
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
||||
down_input_fp8,
|
||||
self.w2_weight_fp8,
|
||||
down_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
del down_input
|
||||
del down_input_fp8
|
||||
|
||||
# PostReorder
|
||||
output = torch.empty(
|
||||
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
|
||||
)
|
||||
post_reorder_triton_kernel[(hidden_states_shape[0],)](
|
||||
down_output,
|
||||
output,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
self.start_expert_id,
|
||||
self.end_expert_id,
|
||||
self.top_k,
|
||||
hidden_states_shape[1],
|
||||
m_max * self.start_expert_id,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
if self.moe_runner_config.routed_scaling_factor is not None:
|
||||
output *= self.moe_runner_config.routed_scaling_factor
|
||||
return output
|
||||
|
||||
|
||||
class DeepEPMoE(EPMoE):
|
||||
class DeepEPMoE(FusedMoE):
|
||||
"""
|
||||
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
|
||||
"""
|
||||
@@ -374,6 +90,15 @@ class DeepEPMoE(EPMoE):
|
||||
activation=activation,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
if isinstance(quant_config, Fp8Config):
|
||||
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||
self.use_fp8_w8a8 = True
|
||||
self.fp8_dtype = torch.float8_e4m3fn
|
||||
else:
|
||||
self.use_fp8_w8a8 = False
|
||||
self.use_block_quant = False
|
||||
|
||||
self.deepep_mode = get_deepep_mode()
|
||||
|
||||
# TODO: move to the beginning of the file
|
||||
@@ -567,7 +292,6 @@ class DeepEPMoE(EPMoE):
|
||||
N = self.w13_weight.size(1)
|
||||
scale_block_size = 128
|
||||
|
||||
# TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
|
||||
w13_weight_fp8 = (
|
||||
self.w13_weight,
|
||||
(
|
||||
@@ -988,8 +712,6 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
|
||||
return FlashInferFusedMoE
|
||||
if get_moe_runner_backend().is_flashinfer_cutlass():
|
||||
return FusedMoE
|
||||
if get_moe_expert_parallel_world_size() > 1:
|
||||
return EPMoE
|
||||
return FusedMoE
|
||||
|
||||
|
||||
|
||||
@@ -156,8 +156,7 @@ class FusedMoE(torch.nn.Module):
|
||||
self.moe_tp_rank = get_moe_tensor_parallel_rank()
|
||||
assert num_experts % self.moe_ep_size == 0
|
||||
self.num_local_experts = num_experts // self.moe_ep_size
|
||||
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
|
||||
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
|
||||
|
||||
if self.moe_ep_size > 1:
|
||||
# TODO(ch-wan): support shared experts fusion
|
||||
# Create a tensor of size num_experts filled with -1
|
||||
|
||||
304
python/sglang/srt/layers/moe/moe_runner/deep_gemm.py
Normal file
304
python/sglang/srt/layers/moe/moe_runner/deep_gemm.py
Normal file
@@ -0,0 +1,304 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.moe_runner.base import (
|
||||
MoeQuantInfo,
|
||||
MoeRunnerConfig,
|
||||
MoeRunnerCore,
|
||||
RunnerInput,
|
||||
RunnerOutput,
|
||||
register_post_permute,
|
||||
register_pre_permute,
|
||||
)
|
||||
from sglang.srt.layers.moe.utils import MoeRunnerBackend
|
||||
from sglang.srt.utils import dispose_tensor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import (
|
||||
StandardCombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
|
||||
# TODO(kaixih@nvidia): ideally we should merge this logic into
|
||||
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
|
||||
@torch.compile
|
||||
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
|
||||
temp = x.to(torch.float32).view(torch.int32)
|
||||
exp = torch.bitwise_right_shift(temp, 23)
|
||||
mant = torch.bitwise_and(temp, 0x7FFFFF)
|
||||
is_ru = torch.logical_and(
|
||||
torch.logical_and((mant > 0), (exp != 0xFE)),
|
||||
~torch.logical_and((exp == 0), (mant <= 0x400000)),
|
||||
)
|
||||
exp = torch.where(is_ru, exp + 1, exp)
|
||||
new_x = exp.to(torch.uint8).view(torch.int)
|
||||
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepGemmRunnerInput(RunnerInput):
|
||||
hidden_states: torch.Tensor
|
||||
hidden_states_scale: torch.Tensor
|
||||
masked_m: torch.Tensor
|
||||
expected_m: int
|
||||
use_masked_gemm: bool
|
||||
|
||||
@property
|
||||
def runner_backend(self) -> MoeRunnerBackend:
|
||||
return MoeRunnerBackend.DEEP_GEMM
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepGemmRunnerOutput(RunnerOutput):
|
||||
hidden_states: torch.Tensor
|
||||
|
||||
@property
|
||||
def runner_backend(self) -> MoeRunnerBackend:
|
||||
return MoeRunnerBackend.DEEP_GEMM
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepGemmMoeQuantInfo(MoeQuantInfo):
|
||||
w13_weight: torch.Tensor
|
||||
w2_weight: torch.Tensor
|
||||
use_fp8: bool
|
||||
w13_scale: Optional[torch.Tensor] = None
|
||||
w2_scale: Optional[torch.Tensor] = None
|
||||
block_shape: Optional[List[int]] = None
|
||||
|
||||
|
||||
class DeepGemmRunnerCore(MoeRunnerCore):
|
||||
def __init__(self, config: MoeRunnerConfig):
|
||||
super().__init__(config)
|
||||
assert self.config.activation == "silu"
|
||||
|
||||
def run(
|
||||
self,
|
||||
runner_input: DeepGemmRunnerInput,
|
||||
quant_info: DeepGemmMoeQuantInfo,
|
||||
running_state: dict,
|
||||
) -> DeepGemmRunnerOutput:
|
||||
|
||||
if runner_input.use_masked_gemm:
|
||||
hidden_states = self._run_masked_gemm(
|
||||
runner_input,
|
||||
quant_info,
|
||||
running_state,
|
||||
)
|
||||
else:
|
||||
hidden_states = self._run_contiguous_gemm(
|
||||
runner_input,
|
||||
quant_info,
|
||||
running_state,
|
||||
)
|
||||
return DeepGemmRunnerOutput(hidden_states=hidden_states)
|
||||
|
||||
def _run_masked_gemm(
|
||||
self,
|
||||
runner_input: DeepGemmRunnerInput,
|
||||
quant_info: DeepGemmMoeQuantInfo,
|
||||
running_state: dict,
|
||||
) -> torch.Tensor:
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
silu_and_mul_masked_post_quant_fwd,
|
||||
)
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
|
||||
hidden_states = runner_input.hidden_states
|
||||
hidden_states_scale = runner_input.hidden_states_scale
|
||||
masked_m = runner_input.masked_m
|
||||
expected_m = runner_input.expected_m
|
||||
|
||||
w13_weight = quant_info.w13_weight
|
||||
w2_weight = quant_info.w2_weight
|
||||
w13_scale = quant_info.w13_scale
|
||||
w2_scale = quant_info.w2_scale
|
||||
|
||||
hidden_states_device = running_state["hidden_states_device"]
|
||||
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
|
||||
b, s_mn, s_k = hidden_states_scale.shape
|
||||
assert (
|
||||
s_mn % 4 == 0 and s_k % 4 == 0
|
||||
), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
|
||||
|
||||
# GroupGemm-0
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
|
||||
hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
|
||||
else:
|
||||
hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
|
||||
hidden_states_scale
|
||||
)
|
||||
|
||||
num_groups, m, k = hidden_states.shape
|
||||
n = w13_weight.size(1)
|
||||
gateup_output = torch.empty(
|
||||
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
|
||||
)
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
||||
(hidden_states, hidden_states_scale),
|
||||
(w13_weight, w13_scale),
|
||||
gateup_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
dispose_tensor(hidden_states)
|
||||
|
||||
# Act
|
||||
down_input = torch.empty(
|
||||
(
|
||||
gateup_output.shape[0],
|
||||
gateup_output.shape[1],
|
||||
gateup_output.shape[2] // 2,
|
||||
),
|
||||
device=hidden_states_device,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
)
|
||||
scale_block_size = 128
|
||||
down_input_scale = torch.empty(
|
||||
(
|
||||
gateup_output.shape[0],
|
||||
gateup_output.shape[1],
|
||||
gateup_output.shape[2] // 2 // scale_block_size,
|
||||
),
|
||||
device=hidden_states_device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
silu_and_mul_masked_post_quant_fwd(
|
||||
gateup_output,
|
||||
down_input,
|
||||
down_input_scale,
|
||||
scale_block_size,
|
||||
masked_m,
|
||||
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
|
||||
)
|
||||
del gateup_output
|
||||
|
||||
# GroupGemm-1
|
||||
n = w2_weight.shape[1]
|
||||
|
||||
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
|
||||
down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
|
||||
down_input_scale
|
||||
)
|
||||
|
||||
down_output = torch.empty(
|
||||
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
|
||||
)
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
||||
(down_input, down_input_scale),
|
||||
(w2_weight, w2_scale),
|
||||
down_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
del down_input
|
||||
|
||||
return down_output
|
||||
|
||||
def _run_contiguous_gemm(
|
||||
self,
|
||||
runner_input: DeepGemmRunnerInput,
|
||||
quant_info: DeepGemmMoeQuantInfo,
|
||||
running_state: dict,
|
||||
) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
@property
|
||||
def runner_backend(self) -> MoeRunnerBackend:
|
||||
return MoeRunnerBackend.DEEP_GEMM
|
||||
|
||||
|
||||
@register_pre_permute("standard", "deep_gemm")
|
||||
def pre_permute_standard_to_deep_gemm(
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
quant_info: DeepGemmMoeQuantInfo,
|
||||
runner_config: MoeRunnerConfig,
|
||||
running_state: dict,
|
||||
) -> DeepGemmRunnerInput:
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
|
||||
|
||||
hidden_states, topk_output = dispatch_output
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
|
||||
hidden_states_shape = hidden_states.shape
|
||||
hidden_states_dtype = hidden_states.dtype
|
||||
hidden_states_device = hidden_states.device
|
||||
hidden_states_ref = hidden_states
|
||||
|
||||
topk_weights, topk_ids = topk_weights, topk_ids
|
||||
|
||||
# PreReorder
|
||||
masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = (
|
||||
moe_ep_deepgemm_preprocess(
|
||||
topk_ids,
|
||||
runner_config.num_local_experts,
|
||||
hidden_states,
|
||||
runner_config.top_k,
|
||||
quant_info.block_shape,
|
||||
)
|
||||
)
|
||||
|
||||
dispose_tensor(hidden_states_ref)
|
||||
|
||||
running_state["topk_ids"] = topk_ids
|
||||
running_state["topk_weights"] = topk_weights
|
||||
running_state["hidden_states_shape"] = hidden_states_shape
|
||||
running_state["hidden_states_dtype"] = hidden_states_dtype
|
||||
running_state["hidden_states_device"] = hidden_states_device
|
||||
running_state["src2dst"] = src2dst
|
||||
|
||||
return DeepGemmRunnerInput(
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=hidden_states_scale,
|
||||
masked_m=masked_m,
|
||||
expected_m=expected_m,
|
||||
use_masked_gemm=True,
|
||||
)
|
||||
|
||||
|
||||
@register_post_permute("deep_gemm", "standard")
|
||||
def post_permute_deep_gemm_to_standard(
|
||||
runner_output: DeepGemmRunnerOutput,
|
||||
quant_info: DeepGemmMoeQuantInfo,
|
||||
runner_config: MoeRunnerConfig,
|
||||
running_state: dict,
|
||||
) -> StandardCombineInput:
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
|
||||
|
||||
hidden_states_shape = running_state["hidden_states_shape"]
|
||||
hidden_states_dtype = running_state["hidden_states_dtype"]
|
||||
hidden_states_device = running_state["hidden_states_device"]
|
||||
src2dst = running_state["src2dst"]
|
||||
topk_ids = running_state["topk_ids"]
|
||||
topk_weights = running_state["topk_weights"]
|
||||
|
||||
output = torch.empty(
|
||||
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
|
||||
)
|
||||
post_reorder_triton_kernel[(hidden_states_shape[0],)](
|
||||
runner_output.hidden_states,
|
||||
output,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
runner_config.top_k,
|
||||
hidden_states_shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
|
||||
dispose_tensor(runner_output.hidden_states)
|
||||
|
||||
if runner_config.routed_scaling_factor is not None:
|
||||
output *= runner_config.routed_scaling_factor
|
||||
|
||||
return StandardCombineInput(
|
||||
hidden_states=output,
|
||||
)
|
||||
@@ -9,6 +9,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
|
||||
MoeRunnerConfig,
|
||||
PermuteMethodPool,
|
||||
)
|
||||
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
|
||||
from sglang.srt.layers.moe.utils import get_moe_a2a_backend
|
||||
|
||||
@@ -30,6 +31,8 @@ class MoeRunner:
|
||||
|
||||
if runner_backend.is_triton():
|
||||
self.runner_core = TritonRunnerCore(config)
|
||||
elif runner_backend.is_deep_gemm():
|
||||
self.runner_core = DeepGemmRunnerCore(config)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ class MoeA2ABackend(Enum):
|
||||
class MoeRunnerBackend(Enum):
|
||||
|
||||
AUTO = "auto"
|
||||
DEEP_GEMM = "deep_gemm"
|
||||
TRITON = "triton"
|
||||
TRITON_KERNEL = "triton_kernel"
|
||||
FLASHINFER_TRTLLM = "flashinfer_trtllm"
|
||||
@@ -54,6 +55,9 @@ class MoeRunnerBackend(Enum):
|
||||
def is_auto(self):
|
||||
return self == MoeRunnerBackend.AUTO
|
||||
|
||||
def is_deep_gemm(self):
|
||||
return self == MoeRunnerBackend.DEEP_GEMM
|
||||
|
||||
def is_triton(self):
|
||||
return self == MoeRunnerBackend.TRITON
|
||||
|
||||
@@ -147,7 +151,9 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
|
||||
def get_moe_runner_backend() -> MoeRunnerBackend:
|
||||
global MOE_RUNNER_BACKEND
|
||||
if MOE_RUNNER_BACKEND is None:
|
||||
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
|
||||
logger.warning(
|
||||
"MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected"
|
||||
)
|
||||
MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
|
||||
return MOE_RUNNER_BACKEND
|
||||
|
||||
|
||||
@@ -31,8 +31,8 @@ except ImportError:
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
|
||||
from sglang.srt.layers.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
@@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
|
||||
from sglang.srt.layers.moe.utils import (
|
||||
get_moe_a2a_backend,
|
||||
get_moe_runner_backend,
|
||||
)
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
moe_runner_backend = get_moe_runner_backend()
|
||||
|
||||
if moe_runner_backend.is_auto():
|
||||
if (
|
||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||
and get_moe_a2a_backend().is_deepep()
|
||||
):
|
||||
moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
|
||||
else:
|
||||
moe_runner_backend = MoeRunnerBackend.TRITON
|
||||
if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton():
|
||||
self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
|
||||
else:
|
||||
# TODO(cwan): refactor other backends
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -1087,22 +1108,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
quant_info = TritonMoeQuantInfo(
|
||||
w13_weight=layer.w13_weight,
|
||||
w2_weight=layer.w2_weight,
|
||||
use_fp8_w8a8=True,
|
||||
w13_scale=(
|
||||
layer.w13_weight_scale_inv
|
||||
if self.block_quant
|
||||
else layer.w13_weight_scale
|
||||
),
|
||||
w2_scale=(
|
||||
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
|
||||
),
|
||||
a13_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
)
|
||||
if self.runner.runner_backend.is_deep_gemm():
|
||||
|
||||
w13_weight = layer.w13_weight
|
||||
w2_weight = layer.w2_weight
|
||||
|
||||
if self.block_quant:
|
||||
block_shape = self.quant_config.weight_block_size
|
||||
w13_scale = layer.w13_weight_scale_inv
|
||||
w2_scale = layer.w2_weight_scale_inv
|
||||
else:
|
||||
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
|
||||
scale_block_size = 128
|
||||
block_shape = [scale_block_size, scale_block_size]
|
||||
w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1
|
||||
w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1
|
||||
w13_scale = (
|
||||
layer.w13_weight_scale.unsqueeze(1)
|
||||
.repeat_interleave(w13_scale_n, dim=1)
|
||||
.unsqueeze(2)
|
||||
.repeat_interleave(w13_scale_k, dim=2)
|
||||
)
|
||||
w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1
|
||||
w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1
|
||||
w2_scale = (
|
||||
layer.w2_weight_scale.unsqueeze(1)
|
||||
.repeat_interleave(w2_scale_n, dim=1)
|
||||
.unsqueeze(2)
|
||||
.repeat_interleave(w2_scale_k, dim=2)
|
||||
)
|
||||
quant_info = DeepGemmMoeQuantInfo(
|
||||
w13_weight=w13_weight,
|
||||
w2_weight=w2_weight,
|
||||
use_fp8=True,
|
||||
w13_scale=w13_scale,
|
||||
w2_scale=w2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
elif self.runner.runner_backend.is_triton():
|
||||
quant_info = TritonMoeQuantInfo(
|
||||
w13_weight=layer.w13_weight,
|
||||
w2_weight=layer.w2_weight,
|
||||
use_fp8_w8a8=True,
|
||||
w13_scale=(
|
||||
layer.w13_weight_scale_inv
|
||||
if self.block_quant
|
||||
else layer.w13_weight_scale
|
||||
),
|
||||
w2_scale=(
|
||||
layer.w2_weight_scale_inv
|
||||
if self.block_quant
|
||||
else layer.w2_weight_scale
|
||||
),
|
||||
a13_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unsupported runner backend: %s" % self.runner.runner_backend
|
||||
)
|
||||
|
||||
return self.runner.run(dispatch_output, quant_info)
|
||||
|
||||
def apply_with_router_logits(
|
||||
|
||||
@@ -21,7 +21,6 @@ from sglang.srt.utils import is_npu, set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
@@ -94,9 +93,7 @@ class W4AFp8Config(QuantizationConfig):
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.ignored_layers):
|
||||
@@ -133,7 +130,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: EPMoE,
|
||||
layer: Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
@@ -292,7 +289,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: EPMoE,
|
||||
layer: Module,
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
@@ -303,18 +300,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
local_topk_ids = topk_ids
|
||||
if get_moe_expert_parallel_world_size() > 1:
|
||||
local_topk_ids = torch.where(
|
||||
topk_ids == -1,
|
||||
layer.num_experts,
|
||||
topk_ids,
|
||||
)
|
||||
|
||||
output = cutlass_w4a8_moe(
|
||||
layer.start_expert_id,
|
||||
layer.end_expert_id,
|
||||
layer.num_experts,
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -322,7 +309,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight_scale_inv,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
local_topk_ids,
|
||||
self.a_strides1,
|
||||
self.b_strides1,
|
||||
self.c_strides1,
|
||||
|
||||
@@ -49,7 +49,6 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.router import fused_moe_router_shim
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
@@ -176,17 +175,7 @@ class Grok1MoE(nn.Module):
|
||||
custom_routing_function=custom_routing_function,
|
||||
)
|
||||
|
||||
kwargs = {}
|
||||
if get_moe_expert_parallel_world_size() > 1:
|
||||
MoEImpl = EPMoE
|
||||
else:
|
||||
MoEImpl = FusedMoE
|
||||
kwargs["reduce_results"] = reduce_results
|
||||
kwargs["use_presharded_weights"] = use_presharded_weights
|
||||
kwargs["inplace"] = inplace
|
||||
kwargs["no_combine"] = no_combine
|
||||
|
||||
self.experts = MoEImpl(
|
||||
self.experts = FusedMoE(
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
layer_id=layer_id,
|
||||
@@ -195,7 +184,10 @@ class Grok1MoE(nn.Module):
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
activation="gelu",
|
||||
**kwargs,
|
||||
reduce_results=reduce_results,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
inplace=inplace,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@@ -36,7 +36,6 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
@@ -94,8 +93,7 @@ class MixtralMoE(nn.Module):
|
||||
renormalize=True,
|
||||
)
|
||||
|
||||
MoEImpl = EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE
|
||||
self.experts = MoEImpl(
|
||||
self.experts = FusedMoE(
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
layer_id=layer_id,
|
||||
|
||||
@@ -121,6 +121,17 @@ NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"
|
||||
|
||||
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
|
||||
|
||||
MOE_RUNNER_BACKEND_CHOICES = [
|
||||
"auto",
|
||||
"deep_gemm",
|
||||
"triton",
|
||||
"triton_kernel",
|
||||
"flashinfer_trtllm",
|
||||
"flashinfer_cutlass",
|
||||
"flashinfer_mxfp4",
|
||||
"flashinfer_cutedsl",
|
||||
]
|
||||
|
||||
|
||||
# Allow external code to add more choices
|
||||
def add_load_format_choices(choices):
|
||||
@@ -143,6 +154,10 @@ def add_grammar_backend_choices(choices):
|
||||
GRAMMAR_BACKEND_CHOICES.extend(choices)
|
||||
|
||||
|
||||
def add_moe_runner_backend_choices(choices):
|
||||
MOE_RUNNER_BACKEND_CHOICES.extend(choices)
|
||||
|
||||
|
||||
def add_deterministic_attention_backend_choices(choices):
|
||||
DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices)
|
||||
|
||||
@@ -315,14 +330,7 @@ class ServerArgs:
|
||||
# Expert parallelism
|
||||
ep_size: int = 1
|
||||
moe_a2a_backend: Literal["none", "deepep"] = "none"
|
||||
moe_runner_backend: Literal[
|
||||
"auto",
|
||||
"triton",
|
||||
"triton_kernel",
|
||||
"flashinfer_trtllm",
|
||||
"flashinfer_cutlass",
|
||||
"flashinfer_mxfp4",
|
||||
] = "auto"
|
||||
moe_runner_backend: str = "auto"
|
||||
flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
|
||||
enable_flashinfer_allreduce_fusion: bool = False
|
||||
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
|
||||
@@ -2191,15 +2199,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--moe-runner-backend",
|
||||
type=str,
|
||||
choices=[
|
||||
"auto",
|
||||
"triton",
|
||||
"triton_kernel",
|
||||
"flashinfer_trtllm",
|
||||
"flashinfer_cutlass",
|
||||
"flashinfer_mxfp4",
|
||||
"flashinfer_cutedsl",
|
||||
],
|
||||
choices=MOE_RUNNER_BACKEND_CHOICES,
|
||||
default=ServerArgs.moe_runner_backend,
|
||||
help="Choose the runner backend for MoE.",
|
||||
)
|
||||
|
||||
@@ -1,358 +0,0 @@
|
||||
import itertools
|
||||
import random
|
||||
import unittest
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
grouped_gemm_triton,
|
||||
post_reorder_triton_kernel,
|
||||
pre_reorder_triton_kernel,
|
||||
run_moe_ep_preproess,
|
||||
silu_and_mul_triton_kernel,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
# For test
|
||||
def ep_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
topk_config: TopKConfig,
|
||||
# ep config
|
||||
num_experts: int = 256,
|
||||
fp8_dtype: torch.types = torch.float8_e4m3fn,
|
||||
num_experts_per_partition: int = 128,
|
||||
start_expert_id: int = 0,
|
||||
end_expert_id: int = 127,
|
||||
use_fp8_w8a8: bool = False,
|
||||
w1_scale_inv: Optional[torch.Tensor] = None,
|
||||
w2_scale_inv: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
):
|
||||
use_blockwise_fp8 = block_shape is not None
|
||||
top_k = topk_config.top_k
|
||||
topk_output = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
topk_config=topk_config,
|
||||
)
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
|
||||
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
|
||||
|
||||
gateup_input = torch.empty(
|
||||
(int(hidden_states.shape[0] * top_k), hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=(
|
||||
fp8_dtype
|
||||
if (use_fp8_w8a8 and not use_blockwise_fp8)
|
||||
else hidden_states.dtype
|
||||
),
|
||||
)
|
||||
|
||||
if use_fp8_w8a8 and not use_blockwise_fp8:
|
||||
max_value = (
|
||||
torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32)
|
||||
)
|
||||
w1_input_scale = max_value / torch.finfo(fp8_dtype).max
|
||||
else:
|
||||
w1_input_scale = None
|
||||
|
||||
# PreReorder
|
||||
pre_reorder_triton_kernel[(hidden_states.shape[0],)](
|
||||
hidden_states,
|
||||
gateup_input,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
w1_input_scale,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
top_k,
|
||||
hidden_states.shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
use_per_token_if_dynamic=True,
|
||||
)
|
||||
|
||||
seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2]
|
||||
weight_indices_cur_rank = torch.arange(
|
||||
0,
|
||||
num_experts_per_partition,
|
||||
device=hidden_states.device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
# GroupGemm-0
|
||||
gateup_output = torch.empty(
|
||||
gateup_input.shape[0],
|
||||
w1.shape[1],
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
gateup_output = grouped_gemm_triton(
|
||||
a=gateup_input,
|
||||
b=w1,
|
||||
c=gateup_output,
|
||||
batch_size=num_experts_per_partition,
|
||||
weight_column_major=True,
|
||||
seg_indptr=seg_indptr_cur_rank,
|
||||
weight_indices=weight_indices_cur_rank,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
scale_a=w1_input_scale,
|
||||
scale_b=w1_scale_inv,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
# Act
|
||||
down_input = torch.empty(
|
||||
gateup_output.shape[0],
|
||||
gateup_output.shape[1] // 2,
|
||||
device=gateup_output.device,
|
||||
dtype=(
|
||||
fp8_dtype
|
||||
if (use_fp8_w8a8 and not use_blockwise_fp8)
|
||||
else hidden_states.dtype
|
||||
),
|
||||
)
|
||||
if use_fp8_w8a8 and not use_blockwise_fp8:
|
||||
w2_input_scale = torch.ones(
|
||||
num_experts_per_partition,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
else:
|
||||
w2_input_scale = None
|
||||
|
||||
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
||||
gateup_output,
|
||||
down_input,
|
||||
gateup_output.shape[1],
|
||||
reorder_topk_ids,
|
||||
w2_input_scale,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
|
||||
# GroupGemm-1
|
||||
down_output = torch.empty(
|
||||
down_input.shape[0],
|
||||
w2.shape[1],
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
down_output = grouped_gemm_triton(
|
||||
a=down_input,
|
||||
b=w2,
|
||||
c=down_output,
|
||||
batch_size=num_experts_per_partition,
|
||||
weight_column_major=True,
|
||||
seg_indptr=seg_indptr_cur_rank,
|
||||
weight_indices=weight_indices_cur_rank,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
scale_a=w2_input_scale,
|
||||
scale_b=w2_scale_inv,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
# PostReorder
|
||||
output = torch.empty_like(hidden_states)
|
||||
post_reorder_triton_kernel[(hidden_states.size(0),)](
|
||||
down_output,
|
||||
output,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
top_k,
|
||||
hidden_states.size(1),
|
||||
0,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
# test util
|
||||
def block_dequant(
|
||||
x_q_block: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
block_size: List[int],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""This function converts block-wise quantization to tensor-wise quantization.
|
||||
The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
|
||||
and the block size.
|
||||
The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
|
||||
Note only float8 is supported for now.
|
||||
"""
|
||||
|
||||
# process 3D tensor
|
||||
if x_q_block.dim() == 3:
|
||||
batch_size = x_q_block.size(0)
|
||||
return torch.stack(
|
||||
[block_dequant(x_q_block[b], x_s[b], block_size) for b in range(batch_size)]
|
||||
)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n, k = x_q_block.shape
|
||||
n_tiles = (n + block_n - 1) // block_n
|
||||
k_tiles = (k + block_k - 1) // block_k
|
||||
assert n_tiles == x_s.shape[0]
|
||||
assert k_tiles == x_s.shape[1]
|
||||
|
||||
x_dq_block = x_q_block.to(torch.float32)
|
||||
|
||||
x_dq_block_tiles = [
|
||||
[
|
||||
x_dq_block[
|
||||
j * block_n : min((j + 1) * block_n, n),
|
||||
i * block_k : min((i + 1) * block_k, k),
|
||||
]
|
||||
for i in range(k_tiles)
|
||||
]
|
||||
for j in range(n_tiles)
|
||||
]
|
||||
|
||||
for i in range(k_tiles):
|
||||
for j in range(n_tiles):
|
||||
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
|
||||
|
||||
return x_dq_block
|
||||
|
||||
|
||||
class TestW8A8BlockFP8EPMoE(CustomTestCase):
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
M = [1, 222, 1024, 2048]
|
||||
N = [128, 1024, 2048]
|
||||
K = [256, 4096, 5120]
|
||||
E = [8, 16]
|
||||
ep_size = [2, 4]
|
||||
TOP_KS = [2, 4]
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
SEEDS = [0]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not torch.cuda.is_available():
|
||||
raise unittest.SkipTest("CUDA is not available")
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
def _w8a8_block_fp8_ep_moe(
|
||||
self, M, N, K, E, ep_size, topk, block_size, dtype, seed
|
||||
):
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
# NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
|
||||
factor_for_scale = 1e-2
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
|
||||
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max
|
||||
w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
w2_fp32 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max
|
||||
w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles_w1 = (2 * N + block_n - 1) // block_n
|
||||
n_tiles_w2 = (K + block_n - 1) // block_n
|
||||
k_tiles_w1 = (K + block_k - 1) // block_k
|
||||
k_tiles_w2 = (N + block_k - 1) // block_k
|
||||
|
||||
w1_s = (
|
||||
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
|
||||
* factor_for_scale
|
||||
)
|
||||
w2_s = (
|
||||
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
|
||||
* factor_for_scale
|
||||
)
|
||||
|
||||
w1_ref = block_dequant(w1, w1_s, block_size).to(dtype)
|
||||
w2_ref = block_dequant(w2, w2_s, block_size).to(dtype)
|
||||
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
num_experts_per_partition = E // ep_size
|
||||
cur_rank = random.randint(0, ep_size - 1)
|
||||
start_id = cur_rank * num_experts_per_partition
|
||||
end_id = start_id + num_experts_per_partition - 1
|
||||
|
||||
topk_config = TopKConfig(
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
out = ep_moe(
|
||||
hidden_states=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
router_logits=score,
|
||||
topk_config=topk_config,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale_inv=w1_s,
|
||||
w2_scale_inv=w2_s,
|
||||
block_shape=block_size,
|
||||
num_experts=E,
|
||||
num_experts_per_partition=num_experts_per_partition,
|
||||
start_expert_id=start_id,
|
||||
end_expert_id=end_id,
|
||||
)
|
||||
ref_out = ep_moe(
|
||||
hidden_states=a,
|
||||
w1=w1_ref,
|
||||
w2=w2_ref,
|
||||
router_logits=score,
|
||||
topk_config=topk_config,
|
||||
use_fp8_w8a8=False,
|
||||
w1_scale_inv=None,
|
||||
w2_scale_inv=None,
|
||||
block_shape=None,
|
||||
num_experts=E,
|
||||
num_experts_per_partition=num_experts_per_partition,
|
||||
start_expert_id=start_id,
|
||||
end_expert_id=end_id,
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
|
||||
/ (torch.mean(torch.abs(ref_out.to(torch.float32))) + 1e-6)
|
||||
< 0.06
|
||||
)
|
||||
|
||||
def test_w8a8_block_fp8_ep_moe(self):
|
||||
for params in itertools.product(
|
||||
self.M,
|
||||
self.N,
|
||||
self.K,
|
||||
self.E,
|
||||
self.ep_size,
|
||||
self.TOP_KS,
|
||||
self.BLOCK_SIZE,
|
||||
self.DTYPES,
|
||||
self.SEEDS,
|
||||
):
|
||||
with self.subTest(
|
||||
M=params[0],
|
||||
N=params[1],
|
||||
K=params[2],
|
||||
E=params[3],
|
||||
ep_size=params[4],
|
||||
topk=params[5],
|
||||
block_size=params[6],
|
||||
dtype=params[7],
|
||||
seed=params[8],
|
||||
):
|
||||
self._w8a8_block_fp8_ep_moe(*params)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
@@ -120,7 +120,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
|
||||
)
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
expert_map = torch.arange(E, dtype=torch.int32, device=device)
|
||||
expert_map[local_e:] = E
|
||||
expert_map[local_e:] = -1
|
||||
|
||||
output = cutlass_moe(
|
||||
a,
|
||||
@@ -138,9 +138,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
|
||||
c_strides2,
|
||||
s_strides13,
|
||||
s_strides2,
|
||||
0,
|
||||
local_e - 1,
|
||||
E,
|
||||
local_e,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
expert_map,
|
||||
@@ -178,7 +176,7 @@ def cutlass_moe(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids_: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
a_strides1: torch.Tensor,
|
||||
b_strides1: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
@@ -187,40 +185,32 @@ def cutlass_moe(
|
||||
c_strides2: torch.Tensor,
|
||||
s_strides13: torch.Tensor,
|
||||
s_strides2: torch.Tensor,
|
||||
start_expert_id: int,
|
||||
end_expert_id: int,
|
||||
E: int,
|
||||
num_local_experts: int,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
):
|
||||
local_topk_ids = topk_ids_
|
||||
local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E)
|
||||
topk_ids = expert_map[topk_ids]
|
||||
device = a.device
|
||||
|
||||
local_num_experts = end_expert_id - start_expert_id + 1
|
||||
expert_offsets = torch.empty(
|
||||
(local_num_experts + 1), dtype=torch.int32, device=device
|
||||
(num_local_experts + 1), dtype=torch.int32, device=device
|
||||
)
|
||||
problem_sizes1 = torch.empty(
|
||||
(local_num_experts, 3), dtype=torch.int32, device=device
|
||||
(num_local_experts, 3), dtype=torch.int32, device=device
|
||||
)
|
||||
problem_sizes2 = torch.empty(
|
||||
(local_num_experts, 3), dtype=torch.int32, device=device
|
||||
(num_local_experts, 3), dtype=torch.int32, device=device
|
||||
)
|
||||
return cutlass_w4a8_moe(
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
E,
|
||||
a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids_,
|
||||
local_topk_ids,
|
||||
topk_ids,
|
||||
a_strides1,
|
||||
b_strides1,
|
||||
c_strides1,
|
||||
|
||||
Reference in New Issue
Block a user