From fb4959b2c5697e33969f73f2989d61efe6a820df Mon Sep 17 00:00:00 2001
From: Chunyuan WU
Date: Fri, 16 May 2025 00:10:40 +0800
Subject: [PATCH] Add fp8 gemm kernel for CPU in sgl-kernel and add gemm UT (#6216)

Co-authored-by: YanbingJiang
Co-authored-by: mingfeima
---
 sgl-kernel/csrc/cpu/common.h                |   6 +-
 sgl-kernel/csrc/cpu/gemm.cpp                |   3 +-
 sgl-kernel/csrc/cpu/gemm.h                  |   5 +
 sgl-kernel/csrc/cpu/gemm_fp8.cpp            | 543 ++++++++++++++++++++
 sgl-kernel/csrc/cpu/torch_extension_cpu.cpp |  13 +
 sgl-kernel/csrc/cpu/vec.h                   |  60 +++
 sgl-kernel/setup_cpu.py                     |   6 +
 test/srt/cpu/test_gemm.py                   | 191 +++++++
 test/srt/cpu/utils.py                       |  96 ++++
 9 files changed, 921 insertions(+), 2 deletions(-)
 create mode 100644 sgl-kernel/csrc/cpu/gemm_fp8.cpp
 create mode 100644 test/srt/cpu/test_gemm.py
 create mode 100644 test/srt/cpu/utils.py

diff --git a/sgl-kernel/csrc/cpu/common.h b/sgl-kernel/csrc/cpu/common.h
index 1acdd64a6..6f09a0922 100644
--- a/sgl-kernel/csrc/cpu/common.h
+++ b/sgl-kernel/csrc/cpu/common.h
@@ -22,7 +22,7 @@ namespace {
     }                                                            \
   }()
 
-// dispatch: bfloat16, float16, int8_t
+// dispatch: bfloat16, float16, int8_t, fp8_e4m3
 #define CPU_DISPATCH_PACKED_TYPES(TYPE, ...)                     \
   [&] {                                                          \
     switch (TYPE) {                                              \
@@ -38,6 +38,10 @@ namespace {
         using packed_t = int8_t;                                 \
         return __VA_ARGS__();                                    \
       }                                                          \
+      case at::ScalarType::Float8_e4m3fn: {                      \
+        using packed_t = at::Float8_e4m3fn;                      \
+        return __VA_ARGS__();                                    \
+      }                                                          \
       default:                                                   \
         TORCH_CHECK(false, "Unsupported floating data type.\n"); \
     }                                                            \
diff --git a/sgl-kernel/csrc/cpu/gemm.cpp b/sgl-kernel/csrc/cpu/gemm.cpp
index 97c0e7935..68dbd4896 100644
--- a/sgl-kernel/csrc/cpu/gemm.cpp
+++ b/sgl-kernel/csrc/cpu/gemm.cpp
@@ -424,7 +424,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight) {
   const int64_t stride = OC * IC;
 
   TORCH_CHECK(
-      st == at::kBFloat16 || st == at::kHalf || st == at::kChar, "expect weight to be bfloat16, float16 or int8.");
+      st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn,
+      "expect weight to be bfloat16, float16, int8 or fp8_e4m3.");
 
   CPU_DISPATCH_PACKED_TYPES(st, [&] {
     // adjust most inner dimension size
diff --git a/sgl-kernel/csrc/cpu/gemm.h b/sgl-kernel/csrc/cpu/gemm.h
index 026e158a0..e945cec04 100644
--- a/sgl-kernel/csrc/cpu/gemm.h
+++ b/sgl-kernel/csrc/cpu/gemm.h
@@ -33,6 +33,11 @@ inline bool can_use_brgemm(int M) {
   return false;
 }
 
+template <>
+inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) {
+  return M > 4;
+}
+
 // work around compiler internal error
 #define BLOCK_K 128  // 4 * TILE_K
 
diff --git a/sgl-kernel/csrc/cpu/gemm_fp8.cpp b/sgl-kernel/csrc/cpu/gemm_fp8.cpp
new file mode 100644
index 000000000..ae5d56cee
--- /dev/null
+++ b/sgl-kernel/csrc/cpu/gemm_fp8.cpp
@@ -0,0 +1,543 @@
+#include "common.h"
+#include "gemm.h"
+#include "vec.h"
+
+namespace {
+
+template <typename scalar_t>
+inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+  constexpr int kVecSize = bVec::size();
+
+  int64_t d;
+#pragma GCC unroll 4
+  for (d = 0; d <= size - kVecSize; d += kVecSize) {
+    fVec data0 = fVec::loadu(input + d);
+    fVec data1 = fVec::loadu(input + d + fVec::size());
+    bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
+    out_vec.store(out + d);
+  }
+  for (; d < size; ++d) {
+    out[d] = static_cast<scalar_t>(input[d]);
+  }
+}
+
+template <typename scalar_t>
+inline void copy_add_stub(
+    scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+  constexpr int kVecSize = bVec::size();
+
+  int64_t d;
+#pragma GCC unroll 4
+  for (d = 0; d <= size - kVecSize; d += kVecSize) {
+    fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
+    fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
+    bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
+    out_vec.store(out + d);
+  }
+  for (; d < size; ++d) {
+    out[d] = static_cast<scalar_t>(input[d] + bias[d]);
+  }
+}
+
+inline void unpack_B(
+    at::BFloat16* __restrict__ Btmp,
+    const at::Float8_e4m3fn* __restrict__ packed_B,
+    int N,
+    int K,
+    int ldb,
+    int ldb_tmp,
+    float scale) {
+#if defined(CPU_CAPABILITY_AVX512)
+  // [K/2, N, 2]
+  const int K2 = K >> 1;
+  const int ldb2 = ldb;  // ldb * 2 >> 1;
+  const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B);
+  const __m512 vd = _mm512_set1_ps(scale);
+
+  constexpr int BLOCK_N = block_size_n();
+  static_assert(BLOCK_N == 32);
+
+  for (int k = 0; k < K2; ++k) {
+    for (int n = 0; n < N; n += 64) {  // BLOCK_N = 32
+      __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + n);
+
+      __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
+      __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
+
+      __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
+      __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
+
+      // Apply scale
+      __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
+      __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
+      __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
+      __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
+
+      f0_lo = _mm512_mul_ps(f0_lo, vd);
+      f0_hi = _mm512_mul_ps(f0_hi, vd);
+      f1_lo = _mm512_mul_ps(f1_lo, vd);
+      f1_hi = _mm512_mul_ps(f1_hi, vd);
+
+      bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
+      bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
+
+      _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + n * 2 + 0, (__m512i)bf16_0);
+      _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + n * 2 + 32, (__m512i)bf16_1);
+    }
+  }
+#else
+  TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
+#endif
+}
+
+template <typename scalar_t, typename packed_t, int BLOCK_M, int BLOCK_N, bool has_bias>
+struct tinygemm_kernel_nn {
+  static inline void apply(
+      const scalar_t* __restrict__ A,
+      const packed_t* __restrict__ B,
+      scalar_t* __restrict__ C,
+      const float* __restrict__ bias,
+      const float* __restrict__ scale,
+      int K,
+      int lda,
+      int ldb,
+      int ldc,
+      int64_t block_size_K) {
+    TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
+  }
+};
+
+#if defined(CPU_CAPABILITY_AVX512)
+template <int BLOCK_M, int BLOCK_N, bool has_bias>
+struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, BLOCK_M, BLOCK_N, has_bias> {
+  static inline void apply(
+      const at::BFloat16* __restrict__ A,
+      const at::Float8_e4m3fn* __restrict__ B,
+      at::BFloat16* __restrict__ C,
+      const float* __restrict__ bias,
+      const float* __restrict__ scale,
+      int K,
+      int lda,
+      int ldb,
+      int ldc,
+      int64_t block_size_K) {
+    constexpr int ROWS = BLOCK_M;
+    constexpr int COLS = BLOCK_N / 16;
+
+    // prefetch distance
+    constexpr int PREFETCH_SIZE_K = 0;
+
+    __m512bh va;
+    __m512bh vb[COLS];
+    __m512 vc[ROWS * COLS];
+
+    auto loadc = [&](auto i) {
+      constexpr int col = i % COLS;
+      if constexpr (has_bias) {
+        vc[i] = _mm512_loadu_ps(bias + col * 16);
+      } else {
+        vc[i] = _mm512_set1_ps(0.f);
+      }
+    };
+    Unroll<ROWS * COLS>{}(loadc);
+
+    const int K2 = K >> 1;
+    const int lda2 = lda >> 1;
+    const int ldb2 = ldb;  // ldb * 2 >> 1;
+    const float* a_ptr = reinterpret_cast<const float*>(A);
+    const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B);
+
+    auto compute = [&](auto i, int k) {
+      constexpr int row = i / COLS;
+      constexpr int col = i % COLS;
+
+      int idx = k * 2 / block_size_K;
+      const __m512 vd = _mm512_set1_ps(scale[idx]);
+
+      if constexpr (col == 0) {
+        va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
+      }
+      if constexpr (row == 0) {
+        if constexpr (col % 2 == 0) {
+          __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
+          if constexpr (PREFETCH_SIZE_K > 0) {
+            _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
+          }
+
+          __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
+          __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
+
+          __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
+          __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
+
+          // Apply scale
+          __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
+          __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
+          __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
+          __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
+
+          f0_lo = _mm512_mul_ps(f0_lo, vd);
+          f0_hi = _mm512_mul_ps(f0_hi, vd);
+          f1_lo = _mm512_mul_ps(f1_lo, vd);
+          f1_hi = _mm512_mul_ps(f1_hi, vd);
+
+          vb[col + 0] = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
+          vb[col + 1] = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
+        }
+      }
+      vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
+    };
+    for (int k = 0; k < K2; ++k) {
+      Unroll<ROWS * COLS>{}(compute, k);
+    }
+
+    auto storec = [&](auto i) {
+      constexpr int row = i / COLS;
+      constexpr int col = i % COLS;
+      // for COLS = 1, 3 use 256bit store
+      // for COLS = 2, 4 use 512bit store
+      if constexpr (COLS % 2 == 0) {
+        if constexpr (col % 2 == 0) {
+          _mm512_storeu_si512(
+              reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
+              (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
+        }
+      } else {
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + row * ldc + col * 16), (__m256i)(_mm512_cvtneps_pbh(vc[i])));
+      }
+    };
+    Unroll<ROWS * COLS>{}(storec);
+  }
+};
+#endif
+
+#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                                     \
+  tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, MB_SIZE, NB_SIZE, has_bias>::apply(  \
+      A + mb_start * lda,                                                               \
+      B + nb_start * 2,                                                                 \
+      C + mb_start * ldc + nb_start,                                                    \
+      has_bias ? bias + nb_start : nullptr,                                             \
+      scale,                                                                            \
+      K,                                                                                \
+      lda,                                                                              \
+      ldb,                                                                              \
+      ldc,                                                                              \
+      block_size_K);
+
+template <typename scalar_t, typename packed_t, bool has_bias>
+struct brgemm {
+  static inline void apply(
+      const scalar_t* __restrict__ A,
+      const packed_t* __restrict__ B,
+      scalar_t* __restrict__ C,
+      scalar_t* __restrict__ Btmp,
+      float* __restrict__ Ctmp,
+      const float* __restrict__ bias,
+      const float* __restrict__ scale,
+      int M,
+      int N,
+      int K,
+      int lda,
+      int ldb,
+      int ldc) {
+    TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
+  }
+};
+
+template <typename scalar_t, bool has_bias>
+struct brgemm<scalar_t, scalar_t, has_bias> {
+  static inline void apply(
+      const scalar_t* __restrict__ A,
+      const scalar_t* __restrict__ B,
+      scalar_t* __restrict__ C,
+      scalar_t* __restrict__ Btmp,
+      float* __restrict__ Ctmp,
+      const float* __restrict__ bias,
+      const float* __restrict__ scale,
+      int M,
+      int N,
+      int K,
+      int lda,
+      int ldb,
+      int ldc) {
+    UNUSED(scale);
+
+    constexpr int BLOCK_N = block_size_n();
+    at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
+
+    // copy from Ctmp to C
+    for (int m = 0; m < M; ++m) {
+      if constexpr (has_bias) {
+        copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
+      } else {
+        copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
+      }
+    }
+  }
+};
+
+template <bool has_bias>
+struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
+  static inline void apply(
+      const at::BFloat16* __restrict__ A,
+      const at::Float8_e4m3fn* __restrict__ B,
+      at::BFloat16* __restrict__ C,
+      at::BFloat16* __restrict__ Btmp,
+      float* __restrict__ Ctmp,
+      const float* __restrict__ bias,
+      const float* __restrict__ scale,
+      int M,
+      int N,
+      int K,
+      int lda,
+      int ldb,
+      int ldc) {
+    constexpr int BLOCK_N = block_size_n();
+
+    // [BLOCK_K, BLOCK_N] -> [BLOCK_K / 2, BLOCK_N * 2]
+    const int ldb_tmp = block_size_n();
+
+    static_assert(BLOCK_K == 128);
+
+    // accumulate across K per BLOCK_K
+    for (int k = 0; k < K; k += BLOCK_K) {
+      int kb_size = std::min(BLOCK_K, K - k);
+
+      int idx = k >> 7;  // k / BLOCK_K where BLOCK_K = 128
+      unpack_B(Btmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
+
+      const bool add_C = (k != 0);
+      at::native::cpublas::brgemm(M, N, kb_size, lda, ldb_tmp, BLOCK_N, add_C, A + k, Btmp, Ctmp);
+    }
+
+    // copy from Ctmp to C
+    for (int m = 0; m < M; ++m) {
+      if constexpr (has_bias) {
+        copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
+      } else {
+        copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
+      }
+    }
+  }
+};
+
+template <typename scalar_t, bool has_bias>
+void tinygemm_kernel(
+    const scalar_t* __restrict__ A,
+    const at::Float8_e4m3fn* __restrict__ B,
+    scalar_t* __restrict__ C,
+    scalar_t* __restrict__ Btmp,
+    float* __restrict__ Ctmp,
+    const float* __restrict__ scale,
+    const float* __restrict__ bias,
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
+    bool brg,
+    int64_t block_size_K) {
+  if (brg) {
+    brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc);
+    return;
+  }
+
+  // pattern: 1-4-16
+  constexpr int64_t BLOCK_M = 4;
+  constexpr int64_t BLOCK_N = 64;
+  const int64_t MB = div_up(M, BLOCK_M);
+  const int64_t NB = div_up(N, BLOCK_N);
+  for (int mb = 0; mb < MB; ++mb) {
+    int64_t mb_start = mb * BLOCK_M;
+    int64_t mb_size = std::min(BLOCK_M, M - mb_start);
+    for (int64_t nb = 0; nb < NB; ++nb) {
+      int64_t nb_start = nb * BLOCK_N;
+      int64_t nb_size = std::min(BLOCK_N, N - nb_start);
+
+      switch (mb_size << 4 | nb_size >> 4) {
+        // mb_size = 1
+        case 0x12:
+          LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
+          break;
+        case 0x14:
+          LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
+          break;
+        // mb_size = 2
+        case 0x22:
+          LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
+          break;
+        case 0x24:
+          LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
+          break;
+        // mb_size = 3
+        case 0x32:
+          LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
+          break;
+        case 0x34:
+          LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
+          break;
+        // mb_size = 4
+        case 0x42:
+          LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
+          break;
+        case 0x44:
+          LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
+          break;
+        default:
+          TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+void fp8_scaled_mm_kernel_impl(
+    scalar_t* __restrict__ out,
+    const scalar_t* __restrict__ mat1,
+    const at::Float8_e4m3fn* __restrict__ mat2,
+    const float* __restrict__ scales2,
+    const float* __restrict__ bias,
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t mat1_strideM,
+    int64_t out_strideM,
+    int64_t block_size_N,
+    int64_t block_size_K) {
+  constexpr int64_t BLOCK_M = block_size_m();
+  constexpr int64_t BLOCK_N = block_size_n();
+  const int64_t MB = div_up(M, BLOCK_M);
+  const int64_t NB = div_up(N, BLOCK_N);
+
+  const int64_t scale_size_K = div_up(K, block_size_K);
+  const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
+
+  const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
+
+  // parallel on [MB, NB]
+  AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
+    at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+      int64_t mb{0}, nb{0};
+      data_index_init(begin, mb, MB, nb, NB);
+
+      // for brgemm, use float32 for accumulate
+      alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
+      // for brgemm when mat2 is float8_e4m3
+      alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
+
+      for (int64_t i = begin; i < end; ++i) {
+        UNUSED(i);
+        const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
+
+        int64_t mb_start = mb * BLOCK_M;
+        int64_t mb_size = std::min(M - mb_start, BLOCK_M);
+        int64_t nb_start = nb * BLOCK_N;
+        int64_t nb_size = std::min(N - nb_start, BLOCK_N);
+
+        tinygemm_kernel<scalar_t, has_bias>(
+            /* A */ mat1 + mb_start * mat1_strideM,
+            /* B */ mat2 + nb_start * K,  // nb * BLOCK_N * K
+            /* C */ out + mb_start * out_strideM + nb_start,
+            /* Btmp */ Btmp,
+            /* Ctmp */ Ctmp,
+            /* scale */ scale_ptr,
+            /* bias */ bias + nb_start,
+            /* M */ mb_size,
+            /* N */ nb_size,
+            /* K */ K,
+            /* lda */ mat1_strideM,
+            /* ldb */ nb_size,
+            /* ldc */ out_strideM,
+            /* brg */ use_brgemm,
+            /* block_size_K */ block_size_K);
+
+        // move to the next index
+        data_index_step(mb, MB, nb, NB);
+      }
+
+      if (use_brgemm) {
+        at::native::cpublas::brgemm_release();
+      }
+    });
+  });
+}
+
+}  // anonymous namespace
+
+at::Tensor fp8_scaled_mm_cpu(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    at::Tensor& scales2,
+    std::vector<int64_t> block_size,
+    std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni) {
+  RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
+
+  auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
+
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
+  CHECK_INPUT(mat2);
+  CHECK_INPUT(scales2);
+  TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales2 to be float32.");
+
+  int64_t M = mat1.size(0);
+  int64_t N = mat2.size(0);
+  int64_t K = mat2.size(1);
+
+  CHECK_EQ(mat1.size(1), K);
+  CHECK_DIM(2, mat1);
+  CHECK_DIM(2, mat2);
+
+  TORCH_CHECK(block_size.size() == 2, "fp8_scaled_mm_cpu: expect block_size.size() to be 2.");
+
+  int64_t block_size_N = block_size[0];
+  int64_t block_size_K = block_size[1];
+
+  constexpr int64_t BLOCK_N = block_size_n();
+  TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
+  TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
+  CHECK_EQ(scales2.size(0), div_up(N, block_size_N));
+  CHECK_EQ(scales2.size(1), div_up(K, block_size_K));
+
+  const auto st = mat1.scalar_type();
+  TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "fp8_scaled_mm_cpu: expect A to be bfloat16 or half.");
+  TORCH_CHECK(st == out_dtype, "fp8_scaled_mm_cpu: expect A has same dtype with out_dtype.");
+  TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn, "fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3.");
+  TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales to be float32.");
+  auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
+
+  // strides
+  int64_t mat1_strideM = mat1.stride(0);
+  int64_t out_strideM = out.stride(0);
+
+  const bool has_bias = bias.has_value();
+  const float* bias_data = nullptr;
+  if (has_bias) {
+    CHECK_EQ(bias.value().size(0), N);
+    bias_data = bias.value().data_ptr<float>();
+  }
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
+    fp8_scaled_mm_kernel_impl<scalar_t>(
+        out.data_ptr<scalar_t>(),
+        mat1.data_ptr<scalar_t>(),
+        packed_w.data_ptr<at::Float8_e4m3fn>(),
+        scales2.data_ptr<float>(),
+        bias_data,
+        M,
+        N,
+        K,
+        mat1_strideM,
+        out_strideM,
+        block_size_N,
+        block_size_K);
+  });
+
+  return out;
+}
diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
index 02db7f61e..1300e818e 100644
--- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
+++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -94,6 +94,16 @@ at::Tensor int8_scaled_mm_cpu(
     at::ScalarType out_dtype,
     bool is_vnni);
 
+// fp8 gemm
+at::Tensor fp8_scaled_mm_cpu(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    at::Tensor& scales2,
+    std::vector<int64_t> block_size,
+    std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni);
+
 // quant + igemm
 at::Tensor int8_scaled_mm_with_quant(
     at::Tensor& mat1,
@@ -198,6 +208,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // igemm
   m.def("int8_scaled_mm_cpu", &int8_scaled_mm_cpu, "int8 weight packed linear for intel AMX");
 
+  // fp8 gemm
+  m.def("fp8_scaled_mm_cpu", &fp8_scaled_mm_cpu, "fp8 weight packed linear for intel AMX");
+
   // quant + igemm
   m.def(
       "int8_scaled_mm_with_quant", &int8_scaled_mm_with_quant, "fused per row quant and int8 scaled mm for intel AMX");
diff --git a/sgl-kernel/csrc/cpu/vec.h b/sgl-kernel/csrc/cpu/vec.h
index e058bd716..78e8b8f17 100644
--- a/sgl-kernel/csrc/cpu/vec.h
+++ b/sgl-kernel/csrc/cpu/vec.h
@@ -30,6 +30,66 @@ convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b)
 
 #define CVT_FP16_TO_FP32(a) _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
 
+inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) {
+  // The following conversion is without denorm behavior, that is to say,
+  // Max
subnorm : S.0000.111 = 0.875 ∗ 2**(−6) + // Min subnorm : S.0000.001 = 2**(−9) + // 0.0019 ~ 0.0137 cannot be converted correctly. + __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + auto mask = _mm512_cmpneq_epi16_mask( + _mm512_and_si512(x, _mm512_set1_epi16(127)), + _mm512_setzero_si512()); // mask = x & 0x7f + auto mask_nan = _mm512_cmpneq_epi16_mask( + _mm512_and_si512(x, _mm512_set1_epi16(127)), + _mm512_set1_epi16(127)); // mask_nan = x & 0x7f + auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4); // mantissa = (x & 7) << 4 + auto exponent = _mm512_add_epi16( + _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), + _mm512_set1_epi16(120)); // exponent = (((x >> 3) & 15) + 120) + auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7))); + nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign); // deal with Nan + return (__m512bh)(_mm512_or_si512( + nonsign, + _mm512_slli_epi16( + _mm512_and_si512(x, _mm512_set1_epi16(128)), + 8))); // add sign (x & 128) << 8 +} + +inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) { + __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + __m512i lg2mant = _mm512_mask_mov_epi16( + _mm512_mask_mov_epi16( + _mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)), + _mm512_test_epi16_mask(x, _mm512_set1_epi16(4)), + _mm512_set1_epi16(2)); + return (__m512bh)(_mm512_or_si512( + _mm512_maskz_mov_epi16( + _mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()), + _mm512_mask_blend_epi16( + _mm512_test_epi16_mask(x, _mm512_set1_epi16(120)), + _mm512_or_si512( + _mm512_and_si512( + _mm512_sllv_epi16( + _mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)), + _mm512_set1_epi16(0x007f)), + _mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)), + _mm512_or_si512( + _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4), + _mm512_slli_epi16( + _mm512_add_epi16( + _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)), + 7)))), + _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8))); +} + +inline __m512bh CVT_FP8_TO_BF16(__m256i a) { +#ifdef SGLANG_CPU_FP8_CVT_FTZ + return cvt_e4m3_bf16_intrinsic_without_denorm(a); +#else + return cvt_e4m3_bf16_intrinsic_with_denorm(a); +#endif +} + #endif // vector to scalar reduction diff --git a/sgl-kernel/setup_cpu.py b/sgl-kernel/setup_cpu.py index 6237011bc..7c1f2234e 100644 --- a/sgl-kernel/setup_cpu.py +++ b/sgl-kernel/setup_cpu.py @@ -47,6 +47,8 @@ def _get_version(): return line.split("=")[1].strip().strip('"') +cpu_fp8_ftz = os.getenv("SGLANG_CPU_FP8_CVT_FTZ", "1") == "1" + operator_namespace = "sgl_kernel" include_dirs = [] @@ -56,6 +58,7 @@ sources = [ "csrc/cpu/decode.cpp", "csrc/cpu/extend.cpp", "csrc/cpu/gemm.cpp", + "csrc/cpu/gemm_fp8.cpp", "csrc/cpu/gemm_int8.cpp", "csrc/cpu/moe.cpp", "csrc/cpu/moe_int8.cpp", @@ -76,6 +79,9 @@ extra_compile_args = { "-fopenmp", ] } +if cpu_fp8_ftz: + extra_compile_args["cxx"].append("-DSGLANG_CPU_FP8_CVT_FTZ") + libraries = ["c10", "torch", "torch_python"] cmdclass = { "build_ext": BuildExtension.with_options(use_ninja=True), diff --git a/test/srt/cpu/test_gemm.py b/test/srt/cpu/test_gemm.py new file mode 100644 index 000000000..cc94bd3a0 --- /dev/null +++ b/test/srt/cpu/test_gemm.py @@ -0,0 +1,191 @@ +import itertools +import unittest + +import torch +import torch.nn as nn + 
+# TODO: use interface in cpu.py +from sgl_kernel.common_ops import ( + convert_weight_packed, + fp8_scaled_mm_cpu, + int8_scaled_mm_cpu, + int8_scaled_mm_with_quant, + per_token_quant_int8_cpu, + weight_packed_linear, +) +from utils import ( + convert_weight, + native_w8a8_per_token_matmul, + per_token_quant_int8, + precision, +) + +from sglang.test.test_utils import CustomTestCase + + +class Mod(nn.Module): + def __init__(self, input_channel, output_channel, has_bias): + super(Mod, self).__init__() + self.linear = torch.nn.Linear(input_channel, output_channel, has_bias) + + def forward(self, x): + return self.linear(x) + + +class TestGemm(CustomTestCase): + M = [1, 101] + N = [32 * 13] + K = [32 * 16] + has_bias = [False, True] + + M_int8 = [2, 128] + N_int8 = [32 * 12] + K_int8 = [32 * 17] + + M_fp8 = [1, 11] + N_fp8 = [128, 224] + K_fp8 = [512, 576] + + def _bf16_gemm(self, M, N, K, has_bias): + + mat1 = torch.randn(M, K, dtype=torch.bfloat16) + mat2 = torch.randn(N, K, dtype=torch.bfloat16) + + ref = torch.matmul(mat1.float(), mat2.float().t()) + if has_bias: + bias = torch.randn(N, dtype=torch.float32) + ref.add_(bias.bfloat16()) + + ref = ref.bfloat16() + + out = weight_packed_linear(mat1, mat2, bias if has_bias else None, False) + + packed_mat2 = convert_weight_packed(mat2) + out2 = weight_packed_linear(mat1, packed_mat2, bias if has_bias else None, True) + + atol = rtol = precision[ref.dtype] + self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(ref, out2, atol=atol, rtol=rtol)) + + def test_bf16_gemm(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.has_bias, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + has_bias=params[3], + ): + self._bf16_gemm(*params) + + def _int8_gemm(self, M, N, K, has_bias): + dtype = torch.bfloat16 + A = torch.randn((M, K), dtype=dtype) / 10 + Aq, As = per_token_quant_int8(A) + + factor_for_scale = 1e-2 + int8_max = 127 + int8_min = -128 + + B = (torch.rand((N, K), dtype=torch.float32) - 0.5) * 2 + Bq = (B * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8) + Bs = torch.rand(N) * factor_for_scale + + bias = torch.randn(N) if has_bias else None + ref_out = native_w8a8_per_token_matmul(Aq, Bq, As, Bs, bias, dtype) + + atol = rtol = precision[ref_out.dtype] + + Aq2, As2 = per_token_quant_int8_cpu(A) + out = int8_scaled_mm_cpu( + Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False + ) + self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol)) + + # test the fused version + fused_out = int8_scaled_mm_with_quant( + A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False + ) + self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol)) + + def test_int8_gemm(self): + for params in itertools.product( + self.M_int8, + self.N_int8, + self.K_int8, + self.has_bias, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + has_bias=params[3], + ): + self._int8_gemm(*params) + + def _fp8_gemm(self, M, N, K, has_bias): + prepack = True + chunk = False + scale_block_size_N = 64 + scale_block_size_K = 128 + assert scale_block_size_N <= N + assert scale_block_size_K <= K + A_dtype = torch.bfloat16 + + model = Mod(K, N, has_bias).eval() + if chunk: + data = torch.randn(M, K + 6, dtype=A_dtype).narrow(1, 0, K) + else: + data = torch.randn(M, K, dtype=A_dtype) + + weight = model.linear.weight # (N, K) + + if has_bias: + bias = model.linear.bias + + fp8_weight, scales, dq_weight = convert_weight( + weight, 
[scale_block_size_N, scale_block_size_K], A_dtype + ) + + if has_bias: + ref = torch.matmul(data.to(A_dtype), dq_weight.T) + bias.to(A_dtype) + else: + ref = torch.matmul(data.to(A_dtype), dq_weight.T) + + if prepack: + fp8_weight = convert_weight_packed(fp8_weight) + + opt = fp8_scaled_mm_cpu( + data, + fp8_weight, + scales, + [scale_block_size_N, scale_block_size_K], + bias if has_bias else None, + data.dtype, + prepack, + ) + atol = rtol = precision[ref.dtype] + self.assertTrue(torch.allclose(ref, opt, atol=atol, rtol=rtol)) + + def test_fp8_gemm(self): + for params in itertools.product( + self.M_fp8, + self.N_fp8, + self.K_fp8, + self.has_bias, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + has_bias=params[3], + ): + self._fp8_gemm(*params) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/cpu/utils.py b/test/srt/cpu/utils.py new file mode 100644 index 000000000..551f5dedf --- /dev/null +++ b/test/srt/cpu/utils.py @@ -0,0 +1,96 @@ +import math + +import torch + +precision = { + torch.bfloat16: 1e-2, + torch.float16: 1e-3, + torch.float32: 1e-5, +} + + +def per_token_quant_int8(x): + x = x.float() + absmax = x.abs().max(dim=-1).values + absmax = absmax.clamp_min(1e-10).unsqueeze(-1) + scale_x = absmax / 127 + x_q = x.mul(127 / absmax) + x_q = torch.round(x_q).to(torch.int8) + + return x_q, scale_x + + +def convert_weight(weight, scale_block_size, A_dtype): + N, K = weight.size() + fp8_max = 448.0 + scale_block_size_N, scale_block_size_K = scale_block_size # (128, 128) + + pad_N = (scale_block_size_N - (N % scale_block_size_N)) % scale_block_size_N + pad_K = (scale_block_size_K - (K % scale_block_size_K)) % scale_block_size_K + + if pad_N > 0 or pad_K > 0: + weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N)) + + weight_blocks = weight.view( + math.ceil(N / scale_block_size_N), + scale_block_size_N, + math.ceil(K / scale_block_size_K), + scale_block_size_K, + ) # (8, 128, 8, 128) + weight_blocks = weight_blocks.permute(0, 2, 1, 3).contiguous() # (8, 8, 128, 128) + + # Step 2: compute per-block max abs values → scale + abs_max = weight_blocks.abs().amax(dim=(-2, -1), keepdim=True) # (8, 8, 1, 1) + scales = abs_max / fp8_max + scales = torch.where( + scales == 0, torch.ones_like(scales), scales + ) # avoid division by zero + + q_fp8 = (weight_blocks / scales).to(torch.float8_e4m3fn) + q_fp8_reshape = q_fp8.permute(0, 2, 1, 3).contiguous() + + if pad_N > 0 or pad_K > 0: + q_fp8_reshape = q_fp8_reshape.view(N + pad_N, K + pad_K) + q_fp8_reshape = q_fp8_reshape[:N, :K].contiguous() + else: + q_fp8_reshape = q_fp8_reshape.view(N, K) + + dq_weight = q_fp8.float() * scales + dq_weight = dq_weight.permute(0, 2, 1, 3).contiguous() # (8, 128, 8, 128) + + if pad_N > 0 or pad_K > 0: + w_dq = dq_weight.view(N + pad_N, K + pad_K).to(A_dtype) + w_dq = w_dq[:N, :K].contiguous() + else: + w_dq = dq_weight.view(N, K).to(A_dtype) + + scales = scales.view( + math.ceil(N / scale_block_size_N), math.ceil(K / scale_block_size_K) + ) + + return q_fp8_reshape, scales, w_dq + + +def native_w8a8_per_token_matmul(A, B, As, Bs, bias, output_dtype=torch.bfloat16): + """Matrix multiplication function that supports per-token input quantization and per-column weight quantization""" + A = A.to(torch.float32) + B = B.to(torch.float32) + + assert A.shape[-1] == B.shape[-1], "Dimension mismatch" + assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" + + # Reshape input + M = A.numel() // A.shape[-1] + B = B.t() # Transpose weight matrix + N, K 
= B.shape  # note: B was transposed above, so N here is the original K and K is the original N
+    origin_C_shape = A.shape[:-1] + (K,)
+    A = A.reshape(M, N)
+
+    # As is per-token [M, 1], Bs is per-column [1, K]
+    C = torch.matmul(A, B)  # [M, K]
+    C = As * C * Bs.view(1, -1)  # Broadcast per-column scale
+
+    if bias is not None:
+        C.add_(bias.view(1, -1))
+
+    return C.reshape(origin_C_shape).to(output_dtype)
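
Usage note (illustrative sketch, not part of the patch): the snippet below mirrors the fp8 path exercised by test/srt/cpu/test_gemm.py above. It assumes the patched sgl-kernel CPU extension is built and installed, so that sgl_kernel.common_ops exposes convert_weight_packed and fp8_scaled_mm_cpu, and that the convert_weight helper from test/srt/cpu/utils.py is importable.

import torch
from sgl_kernel.common_ops import convert_weight_packed, fp8_scaled_mm_cpu
from utils import convert_weight  # test helper added by this patch (test/srt/cpu/utils.py)

M, N, K = 11, 224, 576
block_size = [64, 128]  # [block_size_N, block_size_K]; block_size_K must equal BLOCK_K (128)

x = torch.randn(M, K, dtype=torch.bfloat16)  # activations, bf16 (or fp16)
w = torch.randn(N, K)                        # weight in (N, K) layout, float32

# Block-quantize the weight to fp8_e4m3 with per-(64, 128)-block float32 scales,
# keeping the dequantized copy for a reference matmul.
w_fp8, scales, w_dq = convert_weight(w, block_size, torch.bfloat16)

# Pre-pack the fp8 weight once (VNNI layout), then pass is_vnni=True at call time.
w_packed = convert_weight_packed(w_fp8)

out = fp8_scaled_mm_cpu(x, w_packed, scales, block_size, None, x.dtype, True)
ref = torch.matmul(x, w_dq.T)  # reference against the dequantized weight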