Add fp8 shared_expert kernel for CPU in sgl-kernel and add UT (#6339)
Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com> Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
223
test/srt/cpu/test_shared_expert.py
Normal file
223
test/srt/cpu/test_shared_expert.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import itertools
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
from sgl_kernel.common_ops import convert_weight_packed
|
||||
from sgl_kernel.common_ops import shared_expert_cpu as shared_expert
|
||||
from utils import (
|
||||
BLOCK_K,
|
||||
BLOCK_N,
|
||||
SiluAndMul,
|
||||
factor_for_scale,
|
||||
fp8_max,
|
||||
fp8_min,
|
||||
per_token_quant_int8,
|
||||
precision,
|
||||
scaled_weight,
|
||||
torch_naive_moe,
|
||||
torch_w8a8_per_column_moe,
|
||||
)
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestSharedExpert(CustomTestCase):
|
||||
M = [2, 121]
|
||||
N = [32, 32 * 4]
|
||||
K = [32, 32 * 2]
|
||||
routed_scaling_factor = [16]
|
||||
|
||||
M_fp8 = [2, 12]
|
||||
N_fp8 = [512]
|
||||
K_fp8 = [256]
|
||||
|
||||
def _bf16_shared_expert(self, m, n, k, routed_scaling_factor):
|
||||
dtype = torch.bfloat16
|
||||
prepack = True
|
||||
|
||||
hidden_states = torch.randn(m, k, dtype=dtype) / k
|
||||
w1 = torch.randn(2 * n, k, dtype=dtype)
|
||||
w2 = torch.randn(k, n, dtype=dtype)
|
||||
fused_output = torch.randn(m, k, dtype=dtype) / k
|
||||
|
||||
# fused moe mutates content in hs
|
||||
hidden_states2 = hidden_states.clone()
|
||||
|
||||
# bfloat16
|
||||
ref = torch_naive_moe(
|
||||
hidden_states.float(),
|
||||
w1.float(),
|
||||
w2.float(),
|
||||
fused_output.float(),
|
||||
routed_scaling_factor,
|
||||
).to(dtype=dtype)
|
||||
res = shared_expert(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
fused_output,
|
||||
routed_scaling_factor,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
)
|
||||
|
||||
atol = rtol = precision[ref.dtype]
|
||||
self.assertTrue(torch.allclose(ref, res, atol=atol, rtol=rtol))
|
||||
|
||||
def test_bf16_shared_expert(self):
|
||||
for params in itertools.product(
|
||||
self.M,
|
||||
self.N,
|
||||
self.K,
|
||||
self.routed_scaling_factor,
|
||||
):
|
||||
with self.subTest(
|
||||
m=params[0],
|
||||
n=params[1],
|
||||
k=params[2],
|
||||
routed_scaling_factor=params[3],
|
||||
):
|
||||
self._bf16_shared_expert(*params)
|
||||
|
||||
def _int8_shared_expert(self, m, n, k, routed_scaling_factor):
|
||||
dtype = torch.bfloat16
|
||||
prepack = True
|
||||
|
||||
hidden_states = torch.randn(m, k, dtype=dtype) / k
|
||||
w1 = torch.randn(2 * n, k, dtype=dtype)
|
||||
w2 = torch.randn(k, n, dtype=dtype)
|
||||
fused_output = torch.randn(m, k, dtype=dtype) / k
|
||||
|
||||
# fused moe mutates content in hs
|
||||
hidden_states2 = hidden_states.clone()
|
||||
|
||||
w1_q, w1_s = per_token_quant_int8(w1)
|
||||
w2_q, w2_s = per_token_quant_int8(w2)
|
||||
ref2 = torch_w8a8_per_column_moe(
|
||||
hidden_states2.float(),
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_s,
|
||||
w2_s,
|
||||
fused_output.float(),
|
||||
routed_scaling_factor,
|
||||
).to(dtype=dtype)
|
||||
res2 = shared_expert(
|
||||
hidden_states2,
|
||||
w1_q,
|
||||
w2_q,
|
||||
fused_output,
|
||||
routed_scaling_factor,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
w1_s,
|
||||
w2_s,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
)
|
||||
|
||||
atol = rtol = precision[ref2.dtype]
|
||||
self.assertTrue(torch.allclose(ref2, res2, atol=atol, rtol=rtol))
|
||||
|
||||
def test_int8_shared_expert(self):
|
||||
for params in itertools.product(
|
||||
self.M,
|
||||
self.N,
|
||||
self.K,
|
||||
self.routed_scaling_factor,
|
||||
):
|
||||
with self.subTest(
|
||||
m=params[0],
|
||||
n=params[1],
|
||||
k=params[2],
|
||||
routed_scaling_factor=params[3],
|
||||
):
|
||||
self._int8_shared_expert(*params)
|
||||
|
||||
def _fp8_shared_expert(self, M, N, K, routed_scaling_factor):
|
||||
dtype = torch.bfloat16
|
||||
prepack = True
|
||||
|
||||
a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
|
||||
|
||||
w1_fp32 = torch.randn(1, 2 * N, K)
|
||||
w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
w2_fp32 = torch.randn(1, K, N)
|
||||
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
w1s = torch.randn(1, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
|
||||
w2s = torch.randn(1, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
|
||||
|
||||
w1_scaled = scaled_weight(w1, w1s).view(2 * N, K)
|
||||
w2_scaled = scaled_weight(w2, w2s).view(K, N)
|
||||
|
||||
# change back to 2D
|
||||
w1, w2 = w1.squeeze(0), w2.squeeze(0)
|
||||
w1s, w2s = w1s.squeeze(0), w2s.squeeze(0)
|
||||
w1_scaled, w2_scaled = w1_scaled.squeeze(0), w2_scaled.squeeze(0)
|
||||
|
||||
fused_out = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
|
||||
a2 = a.clone()
|
||||
|
||||
# ref
|
||||
ic0 = torch.matmul(a.float(), w1_scaled.transpose(0, 1))
|
||||
ic1 = SiluAndMul(ic0)
|
||||
shared_out = torch.matmul(ic1, w2_scaled.transpose(0, 1))
|
||||
ref_out = shared_out + fused_out.float() * routed_scaling_factor
|
||||
ref_out = ref_out.to(dtype=dtype)
|
||||
|
||||
w1 = convert_weight_packed(w1) # [2N, K]
|
||||
w2 = convert_weight_packed(w2) # [K, N]
|
||||
out = shared_expert(
|
||||
a2,
|
||||
w1,
|
||||
w2,
|
||||
fused_out,
|
||||
routed_scaling_factor,
|
||||
True,
|
||||
False,
|
||||
True,
|
||||
w1s,
|
||||
w2s,
|
||||
[BLOCK_N, BLOCK_K],
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
)
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
|
||||
|
||||
def test_fp8_shared_expert(self):
|
||||
for params in itertools.product(
|
||||
self.M_fp8,
|
||||
self.N_fp8,
|
||||
self.K_fp8,
|
||||
self.routed_scaling_factor,
|
||||
):
|
||||
with self.subTest(
|
||||
M=params[0],
|
||||
N=params[1],
|
||||
K=params[2],
|
||||
routed_scaling_factor=params[3],
|
||||
):
|
||||
self._fp8_shared_expert(*params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
precision = {
|
||||
torch.bfloat16: 1e-2,
|
||||
@@ -9,6 +10,16 @@ precision = {
|
||||
}
|
||||
|
||||
|
||||
BLOCK_N, BLOCK_K = 64, 128
|
||||
factor_for_scale = 1e-3
|
||||
fp8_max, fp8_min = 400, -400
|
||||
|
||||
|
||||
def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
return F.silu(x[..., :d]) * x[..., d:]
|
||||
|
||||
|
||||
def per_token_quant_int8(x):
|
||||
x = x.float()
|
||||
absmax = x.abs().max(dim=-1).values
|
||||
@@ -94,3 +105,46 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, bias, output_dtype=torch.bfloat16
|
||||
C.add_(bias.view(1, -1))
|
||||
|
||||
return C.reshape(origin_C_shape).to(output_dtype)
|
||||
|
||||
|
||||
def torch_naive_moe(a, w1, w2, b, routed_scaling_factor):
|
||||
|
||||
ic1 = torch.matmul(a, w1.transpose(0, 1))
|
||||
ic2 = SiluAndMul(ic1)
|
||||
ic3 = torch.matmul(ic2, w2.transpose(0, 1))
|
||||
|
||||
return ic3 + b * routed_scaling_factor
|
||||
|
||||
|
||||
def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_factor):
|
||||
|
||||
# Perform per-token quantization
|
||||
a_q, a_s = per_token_quant_int8(a)
|
||||
|
||||
ic1 = native_w8a8_per_token_matmul(
|
||||
a_q, w1_q, a_s, w1_s, bias=None, output_dtype=torch.float32
|
||||
)
|
||||
ic2 = SiluAndMul(ic1)
|
||||
|
||||
a1_q, a1_s = per_token_quant_int8(ic2)
|
||||
ic3 = native_w8a8_per_token_matmul(
|
||||
a1_q, w2_q, a1_s, w2_s, bias=None, output_dtype=torch.float32
|
||||
)
|
||||
|
||||
return ic3 + b * routed_scaling_factor
|
||||
|
||||
|
||||
def scaled_weight(weight, scales):
|
||||
E, N, K = weight.shape
|
||||
weight_block = (
|
||||
weight.view(E, N // BLOCK_N, BLOCK_N, K // BLOCK_K, BLOCK_K)
|
||||
.permute(0, 1, 3, 2, 4)
|
||||
.float()
|
||||
.contiguous()
|
||||
)
|
||||
return (
|
||||
(weight_block * scales.view(E, N // BLOCK_N, K // BLOCK_K, 1, 1))
|
||||
.permute(0, 1, 3, 2, 4)
|
||||
.contiguous()
|
||||
.view(E, N, K)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user