Add fp8 shared_expert kernel for CPU in sgl-kernel and add UT (#6339)
Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
Co-authored-by: mingfeima <mingfei.ma@intel.com>
@@ -1,6 +1,7 @@
import math

import torch
import torch.nn.functional as F

precision = {
    torch.bfloat16: 1e-2,
@@ -9,6 +10,16 @@ precision = {
}


BLOCK_N, BLOCK_K = 64, 128
factor_for_scale = 1e-3
fp8_max, fp8_min = 400, -400
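
These constants bound the random test data: fp8_max and fp8_min keep values inside the float8_e4m3fn range (whose true maximum magnitude is 448), and factor_for_scale keeps the per-block scales small. A minimal sketch of how such constants are typically combined to build block-quantized fp8 test weights; the sizes and names below are illustrative assumptions, not taken from this diff:

    E, N, K = 8, 128, 256  # assumed toy sizes, divisible by BLOCK_N / BLOCK_K
    w = (torch.rand(E, N, K) - 0.5) * 2 * fp8_max  # uniform in [-fp8_max, fp8_max]
    w_fp8 = w.clamp(fp8_min, fp8_max).to(torch.float8_e4m3fn)
    w_s = torch.rand(E, N // BLOCK_N, K // BLOCK_K) * factor_for_scale  # per-block scales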

def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
    # Fused activation: the first half of the last dim is the gate
    # projection, the second half the up projection.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]
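
SiluAndMul halves the last dimension, matching a fused gate_up projection. For example:

    x = torch.randn(4, 2 * 16)
    y = SiluAndMul(x)  # silu(gate) * up
    assert y.shape == (4, 16)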


def per_token_quant_int8(x):
    x = x.float()
    absmax = x.abs().max(dim=-1).values
@@ -94,3 +105,46 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, bias, output_dtype=torch.bfloat16
    C.add_(bias.view(1, -1))

    return C.reshape(origin_C_shape).to(output_dtype)


def torch_naive_moe(a, w1, w2, b, routed_scaling_factor):
    # Full-precision reference: gate_up matmul, fused SiLU-and-mul,
    # down matmul, then add `b` (e.g. the routed experts' output)
    # scaled by routed_scaling_factor.
    ic1 = torch.matmul(a, w1.transpose(0, 1))
    ic2 = SiluAndMul(ic1)
    ic3 = torch.matmul(ic2, w2.transpose(0, 1))
    return ic3 + b * routed_scaling_factor
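
A minimal shape sketch for this reference (all sizes illustrative): `a` is (M, K), `w1` fuses gate and up projections as (2N, K), `w2` is (K, N), and `b` matches the output shape (M, K):

    M, K, N = 4, 32, 16
    a = torch.randn(M, K)       # input activations
    w1 = torch.randn(2 * N, K)  # fused gate_up_proj
    w2 = torch.randn(K, N)      # down_proj
    b = torch.randn(M, K)       # extra term added after scaling
    ref = torch_naive_moe(a, w1, w2, b, routed_scaling_factor=2.5)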


def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_factor):

    # Perform per-token quantization of the activations; both matmuls go
    # through the int8 w8a8 per-token reference path, kept in float32 so
    # no intermediate rounding happens before the final comparison.
    a_q, a_s = per_token_quant_int8(a)

    ic1 = native_w8a8_per_token_matmul(
        a_q, w1_q, a_s, w1_s, bias=None, output_dtype=torch.float32
    )
    ic2 = SiluAndMul(ic1)

    a1_q, a1_s = per_token_quant_int8(ic2)
    ic3 = native_w8a8_per_token_matmul(
        a1_q, w2_q, a1_s, w2_s, bias=None, output_dtype=torch.float32
    )

    return ic3 + b * routed_scaling_factor
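
A sketch of how such a reference is typically checked against the fused kernel's output, using the dtype-keyed tolerance from `precision`; `fused_out` and the quantized inputs are hypothetical stand-ins here, not names from this diff:

    ref = torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, 2.5)
    ref = ref.to(torch.bfloat16)
    atol = rtol = precision[ref.dtype]
    torch.testing.assert_close(fused_out, ref, atol=atol, rtol=rtol)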


def scaled_weight(weight, scales):
    # Dequantize a block-wise quantized weight: view it as
    # (E, N // BLOCK_N, BLOCK_N, K // BLOCK_K, BLOCK_K), broadcast one
    # scale per (BLOCK_N, BLOCK_K) block, and fold back to (E, N, K).
    E, N, K = weight.shape
    weight_block = (
        weight.view(E, N // BLOCK_N, BLOCK_N, K // BLOCK_K, BLOCK_K)
        .permute(0, 1, 3, 2, 4)
        .float()
        .contiguous()
    )
    return (
        (weight_block * scales.view(E, N // BLOCK_N, K // BLOCK_K, 1, 1))
        .permute(0, 1, 3, 2, 4)
        .contiguous()
        .view(E, N, K)
    )
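
Continuing the earlier weight sketch, scaled_weight recovers the float weights that a full-precision reference path would consume:

    w_ref = scaled_weight(w_fp8, w_s)  # float32, dequantized
    assert w_ref.shape == w_fp8.shape  # still (E, N, K)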