Add a CUDA kernel for fusing mapping and weighted sum for MoE. (#6916)
Co-authored-by: Elfie Guo <elfiegxf@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@ _is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
import sgl_kernel
|
||||
from sgl_kernel import (
|
||||
apply_shuffle_mul_sum,
|
||||
cutlass_fp4_group_mm,
|
||||
fp8_blockwise_scaled_grouped_mm,
|
||||
prepare_moe_input,
|
||||
@@ -151,8 +152,8 @@ def cutlass_fused_experts_fp8(
|
||||
k,
|
||||
)
|
||||
|
||||
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
|
||||
rep_a1_scales = a1_scale[a_map]
|
||||
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
|
||||
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
|
||||
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
||||
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
|
||||
@@ -206,9 +207,9 @@ def cutlass_fused_experts_fp8(
|
||||
expert_offsets[:-1],
|
||||
workspace,
|
||||
)
|
||||
return (
|
||||
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
result = torch.empty((m, k), device=device, dtype=out_dtype)
|
||||
return apply_shuffle_mul_sum(c2, result, c_map, topk_weights)
|
||||
|
||||
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
|
||||
Reference in New Issue
Block a user