Add fp8 fused_experts kernel for CPU in sgl-kernel and add UT (#6404)

This commit is contained in:
Chunyuan WU
2025-05-23 17:01:55 +08:00
committed by GitHub
parent 4ba1eea83f
commit 3ded6235c9
7 changed files with 752 additions and 157 deletions

259
test/srt/cpu/test_moe.py Normal file
View File

@@ -0,0 +1,259 @@
import itertools
import math
import unittest
# TODO: use interface in cpu.py
import sgl_kernel
import torch
kernel = torch.ops.sgl_kernel
from utils import (
BLOCK_K,
BLOCK_N,
factor_for_scale,
fp8_max,
fp8_min,
native_fp8_fused_moe,
precision,
scaled_weight,
torch_naive_fused_moe,
torch_w8a8_per_column_fused_moe,
)
from sglang.test.test_utils import CustomTestCase
def fused_moe(a, w1, w2, score, topk, renormalize, prepack):
G = 1
topk_group = 1
B, D = a.shape
topk_weights = torch.empty(B, topk, dtype=torch.float32)
topk_ids = torch.empty(B, topk, dtype=torch.int32)
topk_weights, topk_ids = kernel.grouped_topk_cpu(
a, score, topk, renormalize, G, topk_group
)
packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
inplace = True
return kernel.fused_experts_cpu(
a,
packed_w1,
packed_w2,
topk_weights,
topk_ids,
inplace,
False,
False,
None,
None,
None,
None,
None,
prepack,
)
class TestFusedExperts(CustomTestCase):
M = [2, 114]
N = [32]
K = [32]
E = [4]
topk = [2]
renormalize = [False, True]
M_int8 = [1, 39]
N_int8 = [128]
K_int8 = [256]
E_int8 = [8]
topk_int8 = [3]
M_fp8 = [2, 121]
N_fp8 = [512]
K_fp8 = [256]
E_fp8 = [8]
topk_fp8 = [4]
def _bf16_moe(self, m, n, k, e, topk, renormalize):
dtype = torch.bfloat16
prepack = True
a = torch.randn((m, k), device="cpu", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cpu", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cpu", dtype=dtype) / 10
score = torch.randn((m, e), device="cpu", dtype=dtype)
torch_output = torch_naive_fused_moe(a, w1, w2, score, topk, renormalize)
fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack)
atol = rtol = precision[torch_output.dtype]
self.assertTrue(
torch.allclose(torch_output, fused_output, atol=atol, rtol=rtol)
)
def test_bf16_moe(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.E,
self.topk,
self.renormalize,
):
with self.subTest(
m=params[0],
n=params[1],
k=params[2],
e=params[3],
topk=params[4],
renormalize=params[5],
):
self._bf16_moe(*params)
def _int8_moe(self, M, N, K, E, topk):
dtype = torch.bfloat16
prepack = True
# Initialize int8 quantization parameters
int8_factor_for_scale = 1e-2
int8_max = 127
int8_min = -128
# Input tensor
# M * K
a = torch.randn((M, K), dtype=dtype) / math.sqrt(K)
# Generate int8 weights
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
# Generate scale for each column (per-column quantization)
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * int8_factor_for_scale
w2_s = torch.rand(E, K, device=w2_fp32.device) * int8_factor_for_scale
# Calculate routing
score = torch.randn((M, E), dtype=dtype)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
ref_out = torch_w8a8_per_column_fused_moe(
a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk
)
inplace = True
packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
out = kernel.fused_experts_cpu(
a,
packed_w1,
packed_w2,
topk_weight,
topk_ids.to(torch.int32),
inplace,
True,
False,
w1_s,
w2_s,
None,
None,
None,
prepack,
)
atol = rtol = precision[ref_out.dtype]
# Increase the tolerance for large input shapes
if M > 35:
atol = rtol = 0.02
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
def test_int8_moe(self):
for params in itertools.product(
self.M_int8,
self.N_int8,
self.K_int8,
self.E_int8,
self.topk_int8,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
E=params[3],
topk=params[4],
):
self._int8_moe(*params)
def _fp8_moe(self, M, N, K, E, topk):
dtype = torch.bfloat16
a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
w1_fp32 = torch.randn(E, 2 * N, K)
w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
w2_fp32 = torch.randn(E, K, N)
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
w1s = torch.randn(E, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
w2s = torch.randn(E, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
w1_scaled = scaled_weight(w1, w1s)
w2_scaled = scaled_weight(w2, w2s)
score = torch.randn((M, E), dtype=dtype)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
w1 = kernel.convert_weight_packed(w1)
w2 = kernel.convert_weight_packed(w2)
ref_out = native_fp8_fused_moe(
a, w1_scaled, w2_scaled, topk_weight, topk_ids, topk
)
out = kernel.fused_experts_cpu(
a,
w1,
w2,
topk_weight,
topk_ids.to(torch.int32),
False,
False,
True,
w1s,
w2s,
[BLOCK_N, BLOCK_K],
None,
None,
True,
)
atol = rtol = precision[dtype]
self.assertTrue(torch.allclose(ref_out.bfloat16(), out, atol=atol, rtol=rtol))
def test_fp8_moe(self):
for params in itertools.product(
self.M_fp8,
self.N_fp8,
self.K_fp8,
self.E_fp8,
self.topk_fp8,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
E=params[3],
topk=params[4],
):
self._fp8_moe(*params)
if __name__ == "__main__":
unittest.main()

View File

@@ -148,3 +148,99 @@ def scaled_weight(weight, scales):
.contiguous()
.view(E, N, K)
)
def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
if renormalize:
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
0, 1
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def torch_w8a8_per_column_fused_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk):
"""This function performs fused moe with per-column int8 quantization using native torch."""
B, D = a.shape
# Perform per-token quantization
a_q, a_s = per_token_quant_int8(a)
# Repeat tokens to match topk
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
# Also repeat the scale
a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1]
out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device)
# Calculate routing
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
# Process each expert
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
# First MLP layer: note that a_s is now per-token
inter_out = native_w8a8_per_token_matmul(
a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
bias=None,
output_dtype=torch.float32,
)
# Activation function
act_out = SiluAndMul(inter_out)
# Quantize activation output with per-token
act_out_q, act_out_s = per_token_quant_int8(act_out)
# Second MLP layer
out[mask] = native_w8a8_per_token_matmul(
act_out_q,
w2[i],
act_out_s,
w2_s[i],
bias=None,
output_dtype=torch.float32,
)
# Apply routing weights and sum
return (
(out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
.sum(dim=1)
.to(a.dtype)
)
def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D).float()
out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device)
# Calculate routing
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
ic0 = torch.matmul(a[mask], w1[i].transpose(0, 1))
ic1 = SiluAndMul(ic0)
out[mask] = torch.matmul(ic1, w2[i].transpose(0, 1))
return (
(out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
.sum(dim=1)
.to(a.dtype)
)