Add a CUDA kernel for fusing mapping and weighted sum for MoE. (#6916)

Co-authored-by: Elfie Guo <elfiegxf@gmail.com>
This commit is contained in:
Elfie Guo
2025-06-07 15:24:39 -07:00
committed by GitHub
parent 62fec60d81
commit 3e56f557fd
7 changed files with 146 additions and 12 deletions

View File

@@ -15,6 +15,7 @@ _is_cuda = is_cuda()
if _is_cuda:
import sgl_kernel
from sgl_kernel import (
apply_shuffle_mul_sum,
cutlass_fp4_group_mm,
fp8_blockwise_scaled_grouped_mm,
prepare_moe_input,
@@ -151,8 +152,8 @@ def cutlass_fused_experts_fp8(
k,
)
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
rep_a1_scales = a1_scale[a_map]
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
@@ -206,9 +207,9 @@ def cutlass_fused_experts_fp8(
expert_offsets[:-1],
workspace,
)
return (
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
).sum(dim=1)
result = torch.empty((m, k), device=device, dtype=out_dtype)
return apply_shuffle_mul_sum(c2, result, c_map, topk_weights)
FLOAT4_E2M1_MAX = 6.0