feat: integrate deepgemm into EPMoE (#6821)
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: TianQiLin666666 <1834987979@qq.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
@@ -478,11 +478,13 @@ def post_reorder_triton_kernel(
|
||||
end_expert_id,
|
||||
topk,
|
||||
hidden_size,
|
||||
dst_start,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
InDtype = down_output_ptr.dtype.element_ty
|
||||
|
||||
src_idx = tl.program_id(0)
|
||||
src_idx_int32 = tl.program_id(0)
|
||||
src_idx = src_idx_int32.to(tl.int64)
|
||||
src2dst_ptr = src2dst_ptr + src_idx * topk
|
||||
topk_ids_ptr = topk_ids_ptr + src_idx * topk
|
||||
topk_weights_ptr = topk_weights_ptr + src_idx * topk
|
||||
@@ -501,7 +503,9 @@ def post_reorder_triton_kernel(
|
||||
expert_id = tl.load(topk_ids_ptr + idx)
|
||||
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
||||
computed = True
|
||||
dst_idx = tl.load(src2dst_ptr + idx)
|
||||
dst_idx_int32 = tl.load(src2dst_ptr + idx)
|
||||
dst_idx = dst_idx_int32.to(tl.int64)
|
||||
dst_idx = dst_idx - dst_start
|
||||
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
|
||||
load_ptr = down_output_ptr + dst_idx * hidden_size
|
||||
in_data = tl.load(load_ptr + offset, mask=mask)
|
||||
@@ -1086,3 +1090,156 @@ def tma_align_input_scale(input_scale: torch.Tensor):
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
)
|
||||
return output.t()[:m]
|
||||
|
||||
|
||||
@triton.jit
def compute_masked_m_triton_kernel(seg_indptr, masked_m):
    # One program per expert: masked_m[e] = length of expert e's segment in
    # the sorted token ordering, i.e. how many (token, choice) pairs were
    # routed to expert e. Used as the per-group row count for DeepGEMM's
    # masked grouped GEMM.
    expert_id = tl.program_id(0)
    start = tl.load(seg_indptr + expert_id)
    end = tl.load(seg_indptr + expert_id + 1)
    tl.store(masked_m + expert_id, (end - start))
|
||||
|
||||
|
||||
@triton.jit
def deepgemm_compute_src2dst_triton_kernel(
    topk_ids,
    reorder_ids,
    seg_indptr,
    src2dst,
    m_max,
    num_toks,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute the source-to-destination index map for the padded per-expert
    layout consumed by DeepGEMM masked grouped GEMM.

    Each expert owns a fixed slab of m_max rows; the (token, choice) pair at
    offset o within expert e's segment is mapped to global row
    e * m_max + o, stored into src2dst at its original (unsorted) position.
    """
    pid = tl.program_id(axis=0)
    # Each program handles BLOCK_SIZE consecutive positions of the sorted order.
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks
    # reorder_ids maps sorted position -> original flat (token, choice) index.
    src_id = tl.load(reorder_ids + dst_id, mask=mask)
    expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
    # Offset of this entry inside its expert's segment.
    expert_dst_start = tl.load(seg_indptr + expert_id)
    expert_dst_offset = dst_id - expert_dst_start
    dst_id = expert_id * m_max + expert_dst_offset
    tl.store(src2dst + src_id, dst_id, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
def fill_gateup_input_triton_kernel(
    input_ptr,
    scale_ptr,
    gateup_input_ptr,
    gateup_input_scale_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    m_max,
    hidden_size,
    scale_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter quantized tokens and their per-group scales into the padded
    per-expert gateup input buffers, copying only the top-k choices routed to
    experts in [start_expert_id, end_expert_id] (this EP rank's partition)."""

    # One program per source token. Widen the index to int64 before the
    # src_idx * hidden_size multiply so large batches cannot overflow int32.
    src_idx_int32 = tl.program_id(0)
    src_idx = src_idx_int32.to(tl.int64)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    src_ptr = input_ptr + src_idx * hidden_size
    scale_src_ptr = scale_ptr + src_idx * scale_size

    vec = tl.arange(0, BLOCK_SIZE)
    for idx in range(topk):
        expert_id = tl.load(topk_ids_ptr + idx)
        if expert_id >= start_expert_id and expert_id <= end_expert_id:
            dst_idx_int32 = tl.load(src2dst_ptr + idx)
            dst_idx = dst_idx_int32.to(tl.int64)
            # src2dst holds a global row (expert_id * m_max + offset);
            # rebase it to this rank's local expert slabs.
            dst_idx = dst_idx - start_expert_id * m_max
            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
            # Copy the quantized activation row in BLOCK_SIZE chunks.
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                offset = start_offset + vec
                mask = offset < hidden_size
                in_data = tl.load(src_ptr + offset, mask=mask)
                tl.store(dst_ptr + offset, in_data, mask=mask)
            # Copy the matching per-group quantization scales.
            scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
            for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
                offset = start_offset + vec
                mask = offset < scale_size
                in_scale = tl.load(scale_src_ptr + offset, mask=mask)
                tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
|
||||
|
||||
|
||||
def moe_ep_deepgemm_preprocess(
    topk_ids: torch.Tensor,
    num_experts: int,
    hidden_states: torch.Tensor,
    top_k: int,
    start_expert_id,
    end_expert_id,
    block_shape,
    output_dtype: torch.dtype = torch.float8_e4m3fn,
):
    """Reorder and fp8-quantize tokens into the padded per-expert layout
    required by DeepGEMM masked grouped GEMM.

    Args:
        topk_ids: (num_tokens, top_k) routed expert ids.
        num_experts: total experts across all EP ranks.
        hidden_states: (num_tokens, hidden) activations to quantize/scatter.
        top_k: choices per token.
        start_expert_id / end_expert_id: inclusive expert range of this rank.
        block_shape: [block_n, block_k] quant block, or None for [128, 128].
        output_dtype: fp8 dtype of the scattered gateup input.

    Returns:
        (m_max, masked_m[local experts], expected_m, src2dst,
         gateup_input, gateup_input_scale)
    """
    # Stable sort groups flat (token, choice) entries by expert while keeping
    # the original order within each expert's segment.
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
    masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)

    compute_seg_indptr_triton_kernel[(num_experts,)](
        reorder_topk_ids, seg_indptr, topk_ids.numel()
    )

    grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
    compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)

    # For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
    block_m = 256  # DeepGEMM's block-M alignment requirement
    m_max = (hidden_states.size(0) + block_m - 1) // block_m * block_m
    expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
    # One m_max-row slab per local expert.
    gateup_input = torch.empty(
        (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
        device=hidden_states.device,
        dtype=output_dtype,
    )

    deepgemm_compute_src2dst_triton_kernel[grid](
        topk_ids,
        reorder_ids,
        seg_indptr,
        src2dst,
        m_max,
        topk_ids.numel(),
        BLOCK_SIZE=256,
    )

    if block_shape is None:
        block_shape = [128, 128]
    assert len(block_shape) == 2
    # Only the K-block size matters for per-token-group activation quant.
    block_k = block_shape[1]
    hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)

    gateup_input_scale = torch.empty(
        (gateup_input.size(0), gateup_input.size(1), scale.size(1)),
        device=hidden_states.device,
        dtype=scale.dtype,
    )

    fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
        hidden_states,
        scale,
        gateup_input,
        gateup_input_scale,
        src2dst,
        topk_ids,
        start_expert_id,
        end_expert_id,
        top_k,
        m_max,
        hidden_states.size(1),
        scale.size(1),
        BLOCK_SIZE=1024,
    )

    return (
        m_max,
        masked_m[start_expert_id : (end_expert_id + 1)],
        expected_m,
        src2dst,
        gateup_input,
        gateup_input_scale,
    )
|
||||
|
||||
@@ -16,6 +16,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
ep_scatter,
|
||||
gelu_and_mul_triton_kernel,
|
||||
grouped_gemm_triton,
|
||||
moe_ep_deepgemm_preprocess,
|
||||
post_reorder_triton_kernel,
|
||||
pre_reorder_triton_kernel,
|
||||
run_moe_ep_preproess,
|
||||
@@ -178,6 +179,7 @@ class EPMoE(torch.nn.Module):
|
||||
assert (
|
||||
num_fused_shared_experts == 0
|
||||
), "num_fused_shared_experts is not supported in EP"
|
||||
self.num_fused_shared_experts = num_fused_shared_experts
|
||||
self.num_experts_per_partition = self.num_experts // self.tp_size
|
||||
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
|
||||
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
|
||||
@@ -227,13 +229,182 @@ class EPMoE(torch.nn.Module):
|
||||
|
||||
self.grouped_gemm_runner = None
|
||||
|
||||
self.w13_weight_fp8 = (
|
||||
self.w13_weight,
|
||||
(
|
||||
self.w13_weight_scale_inv
|
||||
if self.use_block_quant
|
||||
else self.w13_weight_scale
|
||||
),
|
||||
)
|
||||
self.w2_weight_fp8 = (
|
||||
self.w2_weight,
|
||||
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
    """Route to the DeepGEMM fp8 path when JIT DeepGEMM is enabled and the
    layer uses fp8 w8a8; otherwise fall back to the normal grouped-GEMM path."""
    use_deepgemm = deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
    if use_deepgemm:
        return self.forward_deepgemm(hidden_states, router_logits)
    return self.forward_normal(hidden_states, router_logits)
|
||||
|
||||
def forward_deepgemm(
    self, hidden_states: torch.Tensor, router_logits: torch.Tensor
):
    """EPMoE forward using DeepGEMM masked grouped GEMM.

    Pipeline: select experts -> scatter/quantize tokens into padded
    per-expert buffers -> grouped GEMM 0 (gate/up) -> fused SiLU-and-mul
    with fp8 re-quantization -> grouped GEMM 1 (down) -> gather results
    back to the original token order, weighted by router weights.

    Args:
        hidden_states: input activations; disposed after pre-reorder
            (callers must not reuse the tensor).
        router_logits: gating logits passed to select_experts.

    Returns:
        Tensor with the input's original shape, dtype, and device.

    Fix vs. original: the unused local `k` from the size() unpack is
    discarded with `_`.
    """
    assert self.quant_method is not None
    # The activation stage below is a fused silu_and_mul kernel, so only
    # "silu" is supported here.
    assert self.activation == "silu"
    # Capture shape/dtype/device before hidden_states is disposed.
    hidden_states_shape = hidden_states.shape
    hidden_states_dtype = hidden_states.dtype
    hidden_states_device = hidden_states.device
    topk_weights, topk_ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=self.top_k,
        use_grouped_topk=self.use_grouped_topk,
        renormalize=self.renormalize,
        topk_group=self.topk_group,
        num_expert_group=self.num_expert_group,
        num_fused_shared_experts=self.num_fused_shared_experts,
        correction_bias=self.correction_bias,
        custom_routing_function=self.custom_routing_function,
        routed_scaling_factor=self.routed_scaling_factor,
    )

    if not self.use_block_quant:
        # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
        scale_block_size = 128
        # w13 fuses gate and up projections, hence the factor of 2 on N.
        w13_weight_scale_n = 2 * (
            (self.intermediate_size + scale_block_size - 1) // scale_block_size
        )
        w13_weight_scale_k = (
            hidden_states_shape[-1] + scale_block_size - 1
        ) // scale_block_size
        w13_weight_scale = (
            self.w13_weight_scale.unsqueeze(1)
            .repeat_interleave(w13_weight_scale_n, dim=1)
            .unsqueeze(2)
            .repeat_interleave(w13_weight_scale_k, dim=2)
        )
        self.w13_weight_fp8 = (
            self.w13_weight,
            w13_weight_scale,
        )
        w2_weight_scale_n = (
            hidden_states_shape[-1] + scale_block_size - 1
        ) // scale_block_size
        w2_weight_scale_k = (
            self.intermediate_size + scale_block_size - 1
        ) // scale_block_size
        w2_weight_scale = (
            self.w2_weight_scale.unsqueeze(1)
            .repeat_interleave(w2_weight_scale_n, dim=1)
            .unsqueeze(2)
            .repeat_interleave(w2_weight_scale_k, dim=2)
        )
        self.w2_weight_fp8 = (
            self.w2_weight,
            w2_weight_scale,
        )

    # PreReorder
    m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
        moe_ep_deepgemm_preprocess(
            topk_ids,
            self.num_experts,
            hidden_states,
            self.top_k,
            self.start_expert_id,
            self.end_expert_id,
            self.block_shape,
        )
    )

    # Free the (potentially large) input tensor early.
    dispose_tensor(hidden_states)

    # GroupGemm-0
    gateup_input_fp8 = (
        gateup_input,
        deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
    )
    num_groups, m, _ = gateup_input_fp8[0].size()
    n = self.w13_weight.size(1)
    gateup_output = torch.empty(
        (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
    )
    deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
        gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
    )
    del gateup_input
    del gateup_input_fp8

    # Act
    # Output has half the width of gateup_output: silu(gate) * up.
    down_input = torch.empty(
        (
            gateup_output.shape[0],
            gateup_output.shape[1],
            gateup_output.shape[2] // 2,
        ),
        device=hidden_states_device,
        dtype=self.fp8_dtype,
    )
    scale_block_size = 128
    down_input_scale = torch.empty(
        (
            gateup_output.shape[0],
            gateup_output.shape[1],
            gateup_output.shape[2] // 2 // scale_block_size,
        ),
        device=hidden_states_device,
        dtype=torch.float32,
    )
    silu_and_mul_masked_post_quant_fwd(
        gateup_output,
        down_input,
        down_input_scale,
        scale_block_size,
        masked_m,
    )
    del gateup_output

    # GroupGemm-1
    n = self.w2_weight.size(1)
    down_input_fp8 = (
        down_input,
        deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
    )
    down_output = torch.empty(
        (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
    )
    deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
        down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
    )
    del down_input
    del down_input_fp8

    # PostReorder
    output = torch.empty(
        hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
    )
    # dst_start = m_max * start_expert_id rebases the global src2dst rows
    # to this rank's local down_output slabs.
    post_reorder_triton_kernel[(hidden_states_shape[0],)](
        down_output,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        self.start_expert_id,
        self.end_expert_id,
        self.top_k,
        hidden_states_shape[1],
        m_max * self.start_expert_id,
        BLOCK_SIZE=512,
    )
    return output
|
||||
|
||||
def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
||||
assert self.quant_method is not None
|
||||
|
||||
hidden_states_shape = hidden_states.shape
|
||||
hidden_states_dtype = hidden_states.dtype
|
||||
hidden_states_device = hidden_states.device
|
||||
if self.grouped_gemm_runner is None:
|
||||
self.grouped_gemm_runner = GroupedGemmRunner(
|
||||
hidden_states.device,
|
||||
@@ -249,6 +420,7 @@ class EPMoE(torch.nn.Module):
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
correction_bias=self.correction_bias,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
@@ -440,6 +612,7 @@ class EPMoE(torch.nn.Module):
|
||||
self.end_expert_id,
|
||||
self.top_k,
|
||||
hidden_states_shape[1],
|
||||
0,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -182,6 +182,7 @@ def ep_moe(
|
||||
end_expert_id,
|
||||
top_k,
|
||||
hidden_states.size(1),
|
||||
0,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user