From 5dd62c3a6f97ebffc022d5ae74353510f1a6631f Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Mon, 19 May 2025 03:42:15 +0800 Subject: [PATCH] Add fp8 shared_expert kernel for CPU in sgl-kernel and add UT (#6339) Co-authored-by: Jiang, Yanbing Co-authored-by: mingfeima --- sgl-kernel/csrc/cpu/gemm.h | 35 +++ sgl-kernel/csrc/cpu/gemm_fp8.cpp | 72 ++++--- sgl-kernel/csrc/cpu/moe.cpp | 43 ++++ sgl-kernel/csrc/cpu/moe_fp8.cpp | 205 ++++++++++++++++++ sgl-kernel/csrc/cpu/torch_extension_cpu.cpp | 2 + sgl-kernel/setup_cpu.py | 1 + test/srt/cpu/test_shared_expert.py | 223 ++++++++++++++++++++ test/srt/cpu/utils.py | 54 +++++ 8 files changed, 603 insertions(+), 32 deletions(-) create mode 100644 sgl-kernel/csrc/cpu/moe_fp8.cpp create mode 100644 test/srt/cpu/test_shared_expert.py diff --git a/sgl-kernel/csrc/cpu/gemm.h b/sgl-kernel/csrc/cpu/gemm.h index e945cec04..5c3ff26bb 100644 --- a/sgl-kernel/csrc/cpu/gemm.h +++ b/sgl-kernel/csrc/cpu/gemm.h @@ -104,6 +104,24 @@ void shared_expert_int8_kernel_impl( int64_t N, int64_t K); +template +void shared_expert_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K); + // tinygemm interface template void tinygemm_kernel( @@ -134,3 +152,20 @@ void tinygemm_kernel( int64_t ldb, int64_t ldc, bool brg); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K); diff --git a/sgl-kernel/csrc/cpu/gemm_fp8.cpp b/sgl-kernel/csrc/cpu/gemm_fp8.cpp index ae5d56cee..20dfed2da 100644 --- a/sgl-kernel/csrc/cpu/gemm_fp8.cpp +++ b/sgl-kernel/csrc/cpu/gemm_fp8.cpp @@ -248,38 +248,6 @@ struct brgemm { } }; -template -struct brgemm { - static inline void apply( - const scalar_t* __restrict__ A, - const scalar_t* __restrict__ B, - scalar_t* __restrict__ C, - scalar_t* __restrict__ Btmp, - float* __restrict__ Ctmp, - const float* __restrict__ bias, - const float* __restrict__ scale, - int M, - int N, - int K, - int lda, - int ldb, - int ldc) { - UNUSED(scale); - - constexpr int BLOCK_N = block_size_n(); - at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp); - - // copy from Ctmp to C - for (int m = 0; m < M; ++m) { - if constexpr (has_bias) { - copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); - } else { - copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); - } - } - } -}; - template struct brgemm { static inline void apply( @@ -469,6 +437,46 @@ void fp8_scaled_mm_kernel_impl( } // anonymous namespace +// tinygemm interface +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K) { + tinygemm_kernel(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const TYPE* __restrict__ A, \ + const at::Float8_e4m3fn* __restrict__ B, \ + TYPE* __restrict__ C, \ + TYPE* __restrict__ Btmp, \ + float* __restrict__ Ctmp, \ + const float* __restrict__ scale, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t lda, \ + int64_t ldb, \ + int64_t ldc, \ + bool brg, \ + int64_t block_size_K) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + at::Tensor fp8_scaled_mm_cpu( at::Tensor& mat1, at::Tensor& mat2, diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp index 05825e04f..6e12f1e38 100644 --- a/sgl-kernel/csrc/cpu/moe.cpp +++ b/sgl-kernel/csrc/cpu/moe.cpp @@ -1137,8 +1137,10 @@ at::Tensor shared_expert_cpu( double routed_scaling_factor, bool inplace, bool use_int8_w8a8, + bool use_fp8_w8a16, std::optional& w1_scale, std::optional& w2_scale, + std::optional> block_size, std::optional& a1_scale, std::optional& a2_scale, bool is_vnni) { @@ -1180,6 +1182,11 @@ at::Tensor shared_expert_cpu( TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported."); TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported."); } + if (use_fp8_w8a16) { + TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16."); + TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16."); + TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16."); + } at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states); @@ -1191,12 +1198,18 @@ at::Tensor shared_expert_cpu( // 3. Aq_tmp : [M, K] or [M, N] // 4. As_tmp : [M] // + // for fp8 w8a16: + // 5. intermediate_cache0 : [M, 2N] + // int num_threads = at::get_num_threads(); int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); if (use_int8_w8a8) { buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float); } + if (use_fp8_w8a16) { + buffer_size_nbytes += M * 2 * N * 2; + } auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "share_experts_kernel_impl", [&] { @@ -1228,6 +1241,36 @@ at::Tensor shared_expert_cpu( M, N, K); + } else if (use_fp8_w8a16) { + scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + + auto w1s = w1_scale.value(); + auto w2s = w2_scale.value(); + auto block_size_val = block_size.value(); + TORCH_CHECK(block_size_val.size() == 2, "shared_expert: expect block_size.size() to be 2."); + int64_t block_size_N = block_size_val[0]; + int64_t block_size_K = block_size_val[1]; + TORCH_CHECK(w1s.size(0) == 2 * N / block_size_N); + TORCH_CHECK(w1s.size(1) == K / block_size_K); + TORCH_CHECK(w2s.size(0) == K / block_size_N); + TORCH_CHECK(w2s.size(1) == N / block_size_K); + + shared_expert_fp8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache0, + intermediate_cache1, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + block_size_N, + block_size_K, + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); } else { shared_expert_kernel_impl( out_hidden_states.data_ptr(), diff --git a/sgl-kernel/csrc/cpu/moe_fp8.cpp b/sgl-kernel/csrc/cpu/moe_fp8.cpp new file mode 100644 index 000000000..77bf5fbb2 --- /dev/null +++ b/sgl-kernel/csrc/cpu/moe_fp8.cpp @@ -0,0 +1,205 @@ +#include "common.h" +#include "gemm.h" +#include "vec.h" + +namespace { + +// out = input + input2 * scale +template +inline void add_mul_stub( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ input2, + float scale, + int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + bVec x_bvec = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x_bvec); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +template +inline void silu_and_mul_stub( + scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const scalar_t* __restrict__ input2, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + const fVec one = fVec(1.f); + + // no remainder +#pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += bVec::size()) { + bVec x = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x); + bVec y = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y); + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + x0 = x0 * y0; + x1 = x1 * y1; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } +} + +} // anonymous namespace + +template +void shared_expert_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache0 = hidden_states @ w1 + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(2 * N, BLOCK_N); + int64_t scale_size_K = div_up(K, block_size_K); + int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const bool use_brgemm = can_use_brgemm(M); + + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K]; + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + int64_t mb_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t nb_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); + + tinygemm_kernel( + /* A */ input + mb * BLOCK_M * K, + /* B */ packed_w1 + nb * BLOCK_N * K, + /* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N, + /* Btmp */ Btmp, + /* Ctmp */ Ctmp, + /* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ K, + /* ldb */ nb_size, + /* ldc */ 2 * N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 1.5: intermediate_cache1 = silu(intermediate_cache0) + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + silu_and_mul_stub(ic1 + m * N, ic0 + m * 2 * N, ic0 + m * 2 * N + N, N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(K, BLOCK_N); + scale_size_K = div_up(N, block_size_K); + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + alignas(64) scalar_t Btmp[BLOCK_K * BLOCK_N]; + alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; + alignas(64) float Ctmp[BLOCK_M * BLOCK_K]; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + int64_t mb_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t nb_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ ic1 + mb * BLOCK_M * N, + /* B */ packed_w2 + nb * BLOCK_N * N, + /* C */ C, + /* Btmp */ Btmp, + /* Ctmp */ Ctmp, + /* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K, + /* M */ mb_size, + /* N */ nb_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ nb_size, + /* ldc */ BLOCK_N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < mb_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, nb_size); + } + } + }); + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } +} + +#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE) \ + template void shared_expert_fp8_kernel_impl( \ + TYPE* __restrict__ output, \ + TYPE* __restrict__ ic0, \ + TYPE* __restrict__ ic1, \ + const TYPE* __restrict__ input, \ + const at::Float8_e4m3fn* __restrict__ packed_w1, \ + const at::Float8_e4m3fn* __restrict__ packed_w2, \ + const float* __restrict__ w1s, \ + const float* __restrict__ w2s, \ + int64_t block_size_N, \ + int64_t block_size_K, \ + const TYPE* __restrict__ fused_experts_out, \ + float routed_scaling_factor, \ + int64_t M, \ + int64_t N, \ + int64_t K) + +INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16); +INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half); diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index 1300e818e..efaa12aca 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -139,8 +139,10 @@ at::Tensor shared_expert_cpu( double routed_scaling_factor, bool inplace, bool use_int8_w8a8, + bool use_fp8_w8a16, std::optional& w1_scale, std::optional& w2_scale, + std::optional> block_size, std::optional& a1_scale, std::optional& a2_scale, bool is_vnni); diff --git a/sgl-kernel/setup_cpu.py b/sgl-kernel/setup_cpu.py index 7c1f2234e..b5f182dc2 100644 --- a/sgl-kernel/setup_cpu.py +++ b/sgl-kernel/setup_cpu.py @@ -61,6 +61,7 @@ sources = [ "csrc/cpu/gemm_fp8.cpp", "csrc/cpu/gemm_int8.cpp", "csrc/cpu/moe.cpp", + "csrc/cpu/moe_fp8.cpp", "csrc/cpu/moe_int8.cpp", "csrc/cpu/norm.cpp", "csrc/cpu/qkv_proj.cpp", diff --git a/test/srt/cpu/test_shared_expert.py b/test/srt/cpu/test_shared_expert.py new file mode 100644 index 000000000..900d985f4 --- /dev/null +++ b/test/srt/cpu/test_shared_expert.py @@ -0,0 +1,223 @@ +import itertools +import math +import unittest + +import torch +import torch.nn as nn + +# TODO: use interface in cpu.py +from sgl_kernel.common_ops import convert_weight_packed +from sgl_kernel.common_ops import shared_expert_cpu as shared_expert +from utils import ( + BLOCK_K, + BLOCK_N, + SiluAndMul, + factor_for_scale, + fp8_max, + fp8_min, + per_token_quant_int8, + precision, + scaled_weight, + torch_naive_moe, + torch_w8a8_per_column_moe, +) + +from sglang.test.test_utils import CustomTestCase + + +class TestSharedExpert(CustomTestCase): + M = [2, 121] + N = [32, 32 * 4] + K = [32, 32 * 2] + routed_scaling_factor = [16] + + M_fp8 = [2, 12] + N_fp8 = [512] + K_fp8 = [256] + + def _bf16_shared_expert(self, m, n, k, routed_scaling_factor): + dtype = torch.bfloat16 + prepack = True + + hidden_states = torch.randn(m, k, dtype=dtype) / k + w1 = torch.randn(2 * n, k, dtype=dtype) + w2 = torch.randn(k, n, dtype=dtype) + fused_output = torch.randn(m, k, dtype=dtype) / k + + # fused moe mutates content in hs + hidden_states2 = hidden_states.clone() + + # bfloat16 + ref = torch_naive_moe( + hidden_states.float(), + w1.float(), + w2.float(), + fused_output.float(), + routed_scaling_factor, + ).to(dtype=dtype) + res = shared_expert( + hidden_states, + w1, + w2, + fused_output, + routed_scaling_factor, + True, + False, + False, + None, + None, + None, + None, + None, + False, + ) + + atol = rtol = precision[ref.dtype] + self.assertTrue(torch.allclose(ref, res, atol=atol, rtol=rtol)) + + def test_bf16_shared_expert(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.routed_scaling_factor, + ): + with self.subTest( + m=params[0], + n=params[1], + k=params[2], + routed_scaling_factor=params[3], + ): + self._bf16_shared_expert(*params) + + def _int8_shared_expert(self, m, n, k, routed_scaling_factor): + dtype = torch.bfloat16 + prepack = True + + hidden_states = torch.randn(m, k, dtype=dtype) / k + w1 = torch.randn(2 * n, k, dtype=dtype) + w2 = torch.randn(k, n, dtype=dtype) + fused_output = torch.randn(m, k, dtype=dtype) / k + + # fused moe mutates content in hs + hidden_states2 = hidden_states.clone() + + w1_q, w1_s = per_token_quant_int8(w1) + w2_q, w2_s = per_token_quant_int8(w2) + ref2 = torch_w8a8_per_column_moe( + hidden_states2.float(), + w1_q, + w2_q, + w1_s, + w2_s, + fused_output.float(), + routed_scaling_factor, + ).to(dtype=dtype) + res2 = shared_expert( + hidden_states2, + w1_q, + w2_q, + fused_output, + routed_scaling_factor, + True, + True, + False, + w1_s, + w2_s, + None, + None, + None, + False, + ) + + atol = rtol = precision[ref2.dtype] + self.assertTrue(torch.allclose(ref2, res2, atol=atol, rtol=rtol)) + + def test_int8_shared_expert(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.routed_scaling_factor, + ): + with self.subTest( + m=params[0], + n=params[1], + k=params[2], + routed_scaling_factor=params[3], + ): + self._int8_shared_expert(*params) + + def _fp8_shared_expert(self, M, N, K, routed_scaling_factor): + dtype = torch.bfloat16 + prepack = True + + a = torch.randn(M, K, dtype=dtype) / math.sqrt(K) + + w1_fp32 = torch.randn(1, 2 * N, K) + w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + w2_fp32 = torch.randn(1, K, N) + w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + w1s = torch.randn(1, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale + w2s = torch.randn(1, K // BLOCK_N, N // BLOCK_K) * factor_for_scale + + w1_scaled = scaled_weight(w1, w1s).view(2 * N, K) + w2_scaled = scaled_weight(w2, w2s).view(K, N) + + # change back to 2D + w1, w2 = w1.squeeze(0), w2.squeeze(0) + w1s, w2s = w1s.squeeze(0), w2s.squeeze(0) + w1_scaled, w2_scaled = w1_scaled.squeeze(0), w2_scaled.squeeze(0) + + fused_out = torch.randn(M, K, dtype=dtype) / math.sqrt(K) + a2 = a.clone() + + # ref + ic0 = torch.matmul(a.float(), w1_scaled.transpose(0, 1)) + ic1 = SiluAndMul(ic0) + shared_out = torch.matmul(ic1, w2_scaled.transpose(0, 1)) + ref_out = shared_out + fused_out.float() * routed_scaling_factor + ref_out = ref_out.to(dtype=dtype) + + w1 = convert_weight_packed(w1) # [2N, K] + w2 = convert_weight_packed(w2) # [K, N] + out = shared_expert( + a2, + w1, + w2, + fused_out, + routed_scaling_factor, + True, + False, + True, + w1s, + w2s, + [BLOCK_N, BLOCK_K], + None, + None, + True, + ) + + atol = rtol = precision[ref_out.dtype] + self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol)) + + def test_fp8_shared_expert(self): + for params in itertools.product( + self.M_fp8, + self.N_fp8, + self.K_fp8, + self.routed_scaling_factor, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + routed_scaling_factor=params[3], + ): + self._fp8_shared_expert(*params) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/cpu/utils.py b/test/srt/cpu/utils.py index 551f5dedf..4665eb2cf 100644 --- a/test/srt/cpu/utils.py +++ b/test/srt/cpu/utils.py @@ -1,6 +1,7 @@ import math import torch +import torch.nn.functional as F precision = { torch.bfloat16: 1e-2, @@ -9,6 +10,16 @@ precision = { } +BLOCK_N, BLOCK_K = 64, 128 +factor_for_scale = 1e-3 +fp8_max, fp8_min = 400, -400 + + +def SiluAndMul(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + def per_token_quant_int8(x): x = x.float() absmax = x.abs().max(dim=-1).values @@ -94,3 +105,46 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, bias, output_dtype=torch.bfloat16 C.add_(bias.view(1, -1)) return C.reshape(origin_C_shape).to(output_dtype) + + +def torch_naive_moe(a, w1, w2, b, routed_scaling_factor): + + ic1 = torch.matmul(a, w1.transpose(0, 1)) + ic2 = SiluAndMul(ic1) + ic3 = torch.matmul(ic2, w2.transpose(0, 1)) + + return ic3 + b * routed_scaling_factor + + +def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_factor): + + # Perform per-token quantization + a_q, a_s = per_token_quant_int8(a) + + ic1 = native_w8a8_per_token_matmul( + a_q, w1_q, a_s, w1_s, bias=None, output_dtype=torch.float32 + ) + ic2 = SiluAndMul(ic1) + + a1_q, a1_s = per_token_quant_int8(ic2) + ic3 = native_w8a8_per_token_matmul( + a1_q, w2_q, a1_s, w2_s, bias=None, output_dtype=torch.float32 + ) + + return ic3 + b * routed_scaling_factor + + +def scaled_weight(weight, scales): + E, N, K = weight.shape + weight_block = ( + weight.view(E, N // BLOCK_N, BLOCK_N, K // BLOCK_K, BLOCK_K) + .permute(0, 1, 3, 2, 4) + .float() + .contiguous() + ) + return ( + (weight_block * scales.view(E, N // BLOCK_N, K // BLOCK_K, 1, 1)) + .permute(0, 1, 3, 2, 4) + .contiguous() + .view(E, N, K) + )