Add fp8 fused_experts kernel for CPU in sgl-kernel and add UT (#6404)
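Shown here is the unit-test half of the change: a new test file exercising the bf16, int8 w8a8, and new fp8 w8a16 paths of `fused_experts_cpu`, plus plain-PyTorch reference implementations (`torch_naive_fused_moe`, `torch_w8a8_per_column_fused_moe`, `native_fp8_fused_moe`) added to the shared CPU test utilities for the tests to validate against.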
 test/srt/cpu/test_moe.py | 259 ++++++++++ (new file)
@@ -0,0 +1,259 @@
import itertools
import math
import unittest

# TODO: use interface in cpu.py
import sgl_kernel  # registers torch.ops.sgl_kernel
import torch
from utils import (
    BLOCK_K,
    BLOCK_N,
    factor_for_scale,
    fp8_max,
    fp8_min,
    native_fp8_fused_moe,
    precision,
    scaled_weight,
    torch_naive_fused_moe,
    torch_w8a8_per_column_fused_moe,
)

from sglang.test.test_utils import CustomTestCase

kernel = torch.ops.sgl_kernel


def fused_moe(a, w1, w2, score, topk, renormalize, prepack):
    G = 1
    topk_group = 1

    topk_weights, topk_ids = kernel.grouped_topk_cpu(
        a, score, topk, renormalize, G, topk_group
    )

    packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
    packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2

    inplace = True
    return kernel.fused_experts_cpu(
        a,
        packed_w1,
        packed_w2,
        topk_weights,
        topk_ids,
        inplace,
        False,  # int8 w8a8 path off
        False,  # fp8 w8a16 path off
        None,  # w1 scales
        None,  # w2 scales
        None,  # block size
        None,  # a1 scales
        None,  # a2 scales
        prepack,
    )


class TestFusedExperts(CustomTestCase):
    # bf16 shapes
    M = [2, 114]
    N = [32]
    K = [32]
    E = [4]
    topk = [2]
    renormalize = [False, True]

    # int8 w8a8 shapes
    M_int8 = [1, 39]
    N_int8 = [128]
    K_int8 = [256]
    E_int8 = [8]
    topk_int8 = [3]

    # fp8 w8a16 shapes
    M_fp8 = [2, 121]
    N_fp8 = [512]
    K_fp8 = [256]
    E_fp8 = [8]
    topk_fp8 = [4]

    def _bf16_moe(self, m, n, k, e, topk, renormalize):
        dtype = torch.bfloat16
        prepack = True

        a = torch.randn((m, k), device="cpu", dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device="cpu", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cpu", dtype=dtype) / 10
        score = torch.randn((m, e), device="cpu", dtype=dtype)

        torch_output = torch_naive_fused_moe(a, w1, w2, score, topk, renormalize)
        fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack)

        atol = rtol = precision[torch_output.dtype]
        self.assertTrue(
            torch.allclose(torch_output, fused_output, atol=atol, rtol=rtol)
        )

    def test_bf16_moe(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.topk,
            self.renormalize,
        ):
            with self.subTest(
                m=params[0],
                n=params[1],
                k=params[2],
                e=params[3],
                topk=params[4],
                renormalize=params[5],
            ):
                self._bf16_moe(*params)

    def _int8_moe(self, M, N, K, E, topk):
        dtype = torch.bfloat16
        prepack = True

        # int8 quantization parameters
        int8_factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128

        # Input tensor: M x K
        a = torch.randn((M, K), dtype=dtype) / math.sqrt(K)

        # Generate int8 weights
        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
        w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
        w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        # Per-column quantization: one scale per output column
        w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * int8_factor_for_scale
        w2_s = torch.rand(E, K, device=w2_fp32.device) * int8_factor_for_scale

        # Routing
        score = torch.randn((M, E), dtype=dtype)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)

        ref_out = torch_w8a8_per_column_fused_moe(
            a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk
        )

        inplace = True
        packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
        packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
        out = kernel.fused_experts_cpu(
            a,
            packed_w1,
            packed_w2,
            topk_weight,
            topk_ids.to(torch.int32),
            inplace,
            True,  # int8 w8a8 path on
            False,  # fp8 w8a16 path off
            w1_s,
            w2_s,
            None,  # block size
            None,  # a1 scales
            None,  # a2 scales
            prepack,
        )

        atol = rtol = precision[ref_out.dtype]
        # Loosen the tolerance for larger input shapes
        if M > 35:
            atol = rtol = 0.02
        self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))

    def test_int8_moe(self):
        for params in itertools.product(
            self.M_int8,
            self.N_int8,
            self.K_int8,
            self.E_int8,
            self.topk_int8,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
            ):
                self._int8_moe(*params)

    def _fp8_moe(self, M, N, K, E, topk):
        dtype = torch.bfloat16

        a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)

        # Generate fp8 (e4m3) weights
        w1_fp32 = torch.randn(E, 2 * N, K)
        w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        w2_fp32 = torch.randn(E, K, N)
        w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        # Blockwise scales: one per BLOCK_N x BLOCK_K tile
        w1s = torch.randn(E, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
        w2s = torch.randn(E, K // BLOCK_N, N // BLOCK_K) * factor_for_scale

        # Dequantized weights for the reference path
        w1_scaled = scaled_weight(w1, w1s)
        w2_scaled = scaled_weight(w2, w2s)

        # Routing
        score = torch.randn((M, E), dtype=dtype)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)

        w1 = kernel.convert_weight_packed(w1)
        w2 = kernel.convert_weight_packed(w2)

        ref_out = native_fp8_fused_moe(
            a, w1_scaled, w2_scaled, topk_weight, topk_ids, topk
        )
        out = kernel.fused_experts_cpu(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids.to(torch.int32),
            False,  # inplace off
            False,  # int8 w8a8 path off
            True,  # fp8 w8a16 path on
            w1s,
            w2s,
            [BLOCK_N, BLOCK_K],
            None,  # a1 scales
            None,  # a2 scales
            True,  # weights are prepacked
        )

        atol = rtol = precision[dtype]
        self.assertTrue(torch.allclose(ref_out.bfloat16(), out, atol=atol, rtol=rtol))

    def test_fp8_moe(self):
        for params in itertools.product(
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.E_fp8,
            self.topk_fp8,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
            ):
                self._fp8_moe(*params)


if __name__ == "__main__":
    unittest.main()
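For reference, the suite uses the standard unittest entry points; a usage sketch, assuming it is run from test/srt/cpu/ so the local `utils` module resolves:

# python3 test_moe.py
# python3 -m unittest test_moe.TestFusedExperts.test_fp8_moe -v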
(Hunk below is from the local utils module imported by the test above.)

@@ -148,3 +148,99 @@ def scaled_weight(weight, scales):
        .contiguous()
        .view(E, N, K)
    )


def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)

    if renormalize:
        topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)

    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
                0, 1
            )
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)

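`SiluAndMul` and the int8 helpers used below are defined earlier in the utils module, outside this hunk. For orientation, a minimal sketch of the usual `SiluAndMul` convention (silu on the first half of the last dimension, multiplied by the second half); the actual helper may differ in detail:

import torch.nn.functional as F


def SiluAndMul(x):
    # x packs the [gate, up] halves along the last dim: out = silu(gate) * up
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]
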
def torch_w8a8_per_column_fused_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk):
    """Reference fused MoE with per-column int8 weights and per-token int8 activations."""

    B, D = a.shape
    # Per-token quantization of the input
    a_q, a_s = per_token_quant_int8(a)
    # Repeat tokens to match topk
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    # Repeat the per-token scales the same way
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]

    out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device)

    # Flatten routing
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # Process each expert
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer; note that a_s is per-token
            inter_out = native_w8a8_per_token_matmul(
                a_q[mask],
                w1[i],
                a_s[mask],
                w1_s[i],
                bias=None,
                output_dtype=torch.float32,
            )
            # Activation function
            act_out = SiluAndMul(inter_out)
            # Re-quantize the activation output per token
            act_out_q, act_out_s = per_token_quant_int8(act_out)
            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(
                act_out_q,
                w2[i],
                act_out_s,
                w2_s[i],
                bias=None,
                output_dtype=torch.float32,
            )
    # Apply routing weights and reduce over the topk experts
    return (
        (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
        .sum(dim=1)
        .to(a.dtype)
    )

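The two quantization helpers above also live outside this hunk. Minimal sketches consistent with the call sites (symmetric per-token quantization and a dequantize-then-matmul reference; the real definitions may differ):

def per_token_quant_int8(x):
    # Symmetric per-row (per-token) int8 quantization: q = round(x / scale)
    absmax = x.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = absmax / 127.0
    q = (x.float() / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale


def native_w8a8_per_token_matmul(a_q, w_q, a_s, w_s, bias=None, output_dtype=torch.float32):
    # Dequantized reference matmul: (a_q * a_s) @ (w_q * w_s).T
    out = torch.matmul(a_q.float(), w_q.float().t()) * a_s * w_s.view(1, -1)
    if bias is not None:
        out = out + bias
    return out.to(output_dtype)
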
def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D).float()
    out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device)

    # Flatten routing
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            ic0 = torch.matmul(a[mask], w1[i].transpose(0, 1))
            ic1 = SiluAndMul(ic0)
            out[mask] = torch.matmul(ic1, w2[i].transpose(0, 1))

    return (
        (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
        .sum(dim=1)
        .to(a.dtype)
    )
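
As a concrete check of the blockwise fp8 scale layout used in `_fp8_moe` (the BLOCK_N = BLOCK_K = 128 values here are an assumption; the real constants come from the utils module):

# Hypothetical block sizes; BLOCK_N / BLOCK_K are imported from utils in the test.
E, N, K = 8, 512, 256
BLOCK_N = BLOCK_K = 128
w1s_shape = (E, 2 * N // BLOCK_N, K // BLOCK_K)  # (8, 8, 2): one scale per 128x128 tile of w1
w2s_shape = (E, K // BLOCK_N, N // BLOCK_K)      # (8, 2, 4): one scale per 128x128 tile of w2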