diff --git a/sgl-kernel/csrc/cpu/common.h b/sgl-kernel/csrc/cpu/common.h
index 1bf45ee4b..0fb132607 100644
--- a/sgl-kernel/csrc/cpu/common.h
+++ b/sgl-kernel/csrc/cpu/common.h
@@ -105,7 +105,19 @@ namespace {
 
 #define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
 
-// parallel routines
+// [NB] Parallel Routines
+//
+// * at::parallel_for - covers most generic use cases; in the default torch
+//     release it is compiled against OpenMP.
+//
+// * parallel_for - same functionality as above, but partitions the payload
+//     with balance211.
+//
+// * parallel_2d - parallelizes over 2 dimensions, used in GEMM, etc.;
+//     balances the payload across both dimensions.
+//
+
+// grain size for each thread
 constexpr int GRAIN_SIZE = 1024;
 
 template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
@@ -113,6 +125,17 @@ inline T div_up(T x, T y) {
   return (x + y - 1) / y;
 }
 
+// at::get_thread_num() is only valid inside at::parallel_for(), as it is
+// lazily initialized; elsewhere it always returns 0.
+inline int get_thread_num() {
+#if defined(_OPENMP)
+  return omp_get_thread_num();
+#else
+  return 0;
+#endif
+}
+
+// balance the payload across threads
 template <typename T>
 inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
 #if 0
@@ -153,6 +176,100 @@ inline void parallel_for(int n, const func_t& f) {
 #endif
 }
 
+// for 1d parallel, use `actual_nth`
+// for 2d parallel, use an even nth, e.g. 43 -> 42
+inline int adjust_num_threads(int m) {
+  int actual_nth = at::get_num_threads();
+  if (m == 1) {
+    return actual_nth;
+  }
+  return std::max(1, (actual_nth >> 1) * 2);
+}
+
+template <typename func_t>
+inline void parallel_2d(int m, int n, const func_t& f) {
+  // make sure we have an even num_threads
+  int nth = adjust_num_threads(m);
+
+  // [NOTE] thread blocking:
+  //
+  //   1) prefer a square block per thread
+  //   2) use an even number of CPU cores
+  //   3) use all `num_threads` cores
+  //
+  //   we have:
+  //     TM * TN = T
+  //     BM / TM = BN / TN
+  //   then:
+  //     TM = ((BM / BN) * T) ^ 0.5
+  //
+  float r = float(m) / n;
+  int nth_m = std::ceil(std::sqrt(r * nth));
+  int nth_n = 1;
+  for (; nth_m > 0; --nth_m) {
+    nth_n = nth / nth_m;
+    if (nth_m * nth_n == nth) {
+      break;
+    }
+  }
+
+#if defined(_OPENMP)
+#pragma omp parallel num_threads(nth)
+  {
+    int ith = omp_get_thread_num();
+    int ith_m = ith / nth_n;
+    int ith_n = ith % nth_n;
+
+    int thread_block_m = div_up(m, nth_m);
+    int thread_block_n = div_up(n, nth_n);
+
+    int begin_m = ith_m * thread_block_m;
+    int end_m = std::min(m, begin_m + thread_block_m);
+    int begin_n = ith_n * thread_block_n;
+    int end_n = std::min(n, begin_n + thread_block_n);
+
+    f(begin_m, end_m, begin_n, end_n);
+  }
+#else
+  f(0, m, 0, n);
+#endif
+}
+
+// limit the max number of cache blocks
+// when we need to pre-unpack weights, e.g. fp8
+#define MAX_CACHE_BLOCK_SIZE 4
+
+template <typename T>
+inline int get_cache_blocks(int chunk_size) {
+  // assume a 2MB L2 and use 50% of it
+  const int L2_size = 2048 * 1024 >> 1;
+  return std::max(1, int(L2_size / (chunk_size * sizeof(T))));
+}
+
+template <>
+inline int get_cache_blocks<at::Float8_e4m3fn>(int chunk_size) {
+  // fp8 uses bf16 as the accumulate type
+  int cache_block_size = get_cache_blocks<at::BFloat16>(chunk_size);
+  return std::min(MAX_CACHE_BLOCK_SIZE, cache_block_size);
+}
+
+// 2d sequential loop over the range [mb0, mb1) x [nb0, nb1)
+template <typename T, typename func_t>
+inline void loop_2d(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1, int64_t chunk_size, const func_t& f) {
+  // number of blocks that fit in L2 for the innermost loop
+  int64_t cache_blocks_nb = get_cache_blocks<T>(chunk_size);
+
+  // loop order: [NB / cache_blocks_nb, MB, cache_blocks_nb]
+  // TODO: implement the reverse order [MB / cache_blocks_mb, NB, cache_blocks_mb]
+  for (int64_t nbb = nb0; nbb < nb1; nbb += cache_blocks_nb) {
+    for (int64_t mb = mb0; mb < mb1; ++mb) {
+      for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, nb1); ++nb) {
+        f(mb, nb, nb - nbb);
+      }
+    }
+  }
+}
+
 // data indexing for dimension collapse
 template <typename T>
 inline T data_index_init(T offset) {
diff --git a/sgl-kernel/csrc/cpu/gemm.cpp b/sgl-kernel/csrc/cpu/gemm.cpp
index 2cce8ddac..48655b9f7 100644
--- a/sgl-kernel/csrc/cpu/gemm.cpp
+++ b/sgl-kernel/csrc/cpu/gemm.cpp
@@ -254,7 +254,7 @@ void tinygemm_kernel(
     return;
   }
 
-  // pattern: 1-4-16
+  // pattern: 1-4-16, N = 16, 32, 48, 64
   constexpr int64_t BLOCK_M = 4;
   constexpr int64_t BLOCK_N = 64;
   const int64_t MB = div_up(M, BLOCK_M);
@@ -268,35 +268,59 @@ void tinygemm_kernel(
 
       switch (mb_size << 4 | nb_size >> 4) {
         // mb_size = 1
+        case 0x11:
+          LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
+          break;
         case 0x12:
           LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
           break;
+        case 0x13:
+          LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
+          break;
         case 0x14:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
          break;
        // mb_size = 2
+        case 0x21:
+          LAUNCH_TINYGEMM_KERNEL_NN(2, 16);
+          break;
         case 0x22:
           LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
           break;
+        case 0x23:
+          LAUNCH_TINYGEMM_KERNEL_NN(2, 48);
+          break;
         case 0x24:
           LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
           break;
         // mb_size = 3
+        case 0x31:
+          LAUNCH_TINYGEMM_KERNEL_NN(3, 16);
+          break;
         case 0x32:
           LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
           break;
+        case 0x33:
+          LAUNCH_TINYGEMM_KERNEL_NN(3, 48);
+          break;
         case 0x34:
           LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
           break;
         // mb_size = 4
+        case 0x41:
+          LAUNCH_TINYGEMM_KERNEL_NN(4, 16);
+          break;
         case 0x42:
           LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
           break;
+        case 0x43:
+          LAUNCH_TINYGEMM_KERNEL_NN(4, 48);
+          break;
         case 0x44:
           LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
           break;
         default:
-          TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
+          TORCH_CHECK(false, "Unexpected block size, ", mb_size, " x ", nb_size);
       }
     }
   }
@@ -318,20 +342,15 @@ void weight_packed_linear_kernel_impl(
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);
 
-  // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx c) N is small
-  const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>) || (N < 64);
+  const bool use_brgemm = can_use_brgemm<scalar_t>(M);
 
   // parallel on [MB, NB]
   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
-    at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-      int64_t mb{0}, nb{0};
-      data_index_init(begin, mb, MB, nb, NB);
-
+    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
      // for brgemm, use float32 for accumulate
      alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
 
-      for (int64_t i = begin; i < end; ++i) {
-        UNUSED(i);
+
loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) { int64_t mb_start = mb * BLOCK_M; int64_t mb_size = std::min(M - mb_start, BLOCK_M); int64_t nb_start = nb * BLOCK_N; @@ -350,10 +369,7 @@ void weight_packed_linear_kernel_impl( /* ldb */ nb_size, /* ldc */ out_strideM, /* brg */ use_brgemm); - - // move to the next index - data_index_step(mb, MB, nb, NB); - } + }); if (use_brgemm) { at::native::cpublas::brgemm_release(); diff --git a/sgl-kernel/csrc/cpu/gemm.h b/sgl-kernel/csrc/cpu/gemm.h index eabbfb7c8..6a16a2985 100644 --- a/sgl-kernel/csrc/cpu/gemm.h +++ b/sgl-kernel/csrc/cpu/gemm.h @@ -27,10 +27,10 @@ template <> inline bool can_use_brgemm(int M) { return true; } -// TODO: add u8s8 brgemm, this requires PyTorch 2.7 +// this requires PyTorch 2.7 or above template <> inline bool can_use_brgemm(int M) { - return false; + return M > 4; } template <> @@ -198,4 +198,5 @@ void tinygemm_kernel( int64_t ldb, int64_t ldc, bool brg, - int64_t block_size_K); + int64_t block_size_K, + bool do_unpack = true); diff --git a/sgl-kernel/csrc/cpu/gemm_fp8.cpp b/sgl-kernel/csrc/cpu/gemm_fp8.cpp index 3bba40786..008f83298 100644 --- a/sgl-kernel/csrc/cpu/gemm_fp8.cpp +++ b/sgl-kernel/csrc/cpu/gemm_fp8.cpp @@ -2,9 +2,6 @@ #include "gemm.h" #include "vec.h" -// we use 4x32 for BLOCK_M -#define BLOCK_SIZE_M_SCALE 4 - namespace { template @@ -250,7 +247,8 @@ struct brgemm { int K, int lda, int ldb, - int ldc) { + int ldc, + bool do_unpack = true) { TORCH_CHECK(false, "struct brgemm: primary template not implemented!"); } }; @@ -270,17 +268,20 @@ struct brgemm { int K, int lda, int ldb, - int ldc) { + int ldc, + bool do_unpack = true) { constexpr int BLOCK_N = block_size_n(); // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2] const int ldb_tmp = BLOCK_N; - for (int k = 0; k < K; k += BLOCK_K) { - int kb_size = std::min(BLOCK_K, K - k); + if (do_unpack) { + for (int k = 0; k < K; k += BLOCK_K) { + int kb_size = std::min(BLOCK_K, K - k); - int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128 - unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]); + int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128 + unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]); + } } at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp); @@ -312,9 +313,11 @@ void tinygemm_kernel( int64_t ldb, int64_t ldc, bool brg, - int64_t block_size_K) { + int64_t block_size_K, + bool do_unpack = true) { if (brg) { - brgemm::apply(A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc); + brgemm::apply( + A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc, do_unpack); return; } @@ -366,7 +369,7 @@ void fp8_scaled_mm_kernel_impl( int64_t block_size_N, int64_t block_size_K, int64_t buffer_size_per_thread) { - constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; + constexpr int64_t BLOCK_M = block_size_m(); constexpr int64_t BLOCK_N = block_size_n(); const int64_t MB = div_up(M, BLOCK_M); const int64_t NB = div_up(N, BLOCK_N); @@ -378,16 +381,12 @@ void fp8_scaled_mm_kernel_impl( // parallel on [MB, NB] AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { - at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { - int64_t mb{0}, nb{0}; - data_index_init(begin, mb, MB, nb, NB); - - int tid = at::get_thread_num(); + parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { + int tid = get_thread_num(); scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread; - float* 
__restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K)); + float* __restrict__ Ctmp = (float*)((void*)(Btmp + MAX_CACHE_BLOCK_SIZE * BLOCK_N * K)); - for (int64_t i = begin; i < end; ++i) { - UNUSED(i); + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) { const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K; int64_t mb_start = mb * BLOCK_M; @@ -395,11 +394,14 @@ void fp8_scaled_mm_kernel_impl( int64_t nb_start = nb * BLOCK_N; int64_t nb_size = std::min(N - nb_start, BLOCK_N); + // only do unpacking for the first row + bool do_unpack = (mb == mb0); + tinygemm_kernel( /* A */ mat1 + mb_start * mat1_strideM, /* B */ mat2 + nb_start * K, // nb * BLOCK_N * K /* C */ out + mb_start * out_strideM + nb_start, - /* Btmp */ Btmp, + /* Btmp */ Btmp + nb_offset * BLOCK_N * K, /* Ctmp */ Ctmp, /* scale */ scale_ptr, /* bias */ bias + nb_start, @@ -410,11 +412,9 @@ void fp8_scaled_mm_kernel_impl( /* ldb */ nb_size, /* ldc */ out_strideM, /* brg */ use_brgemm, - /* block_size_K */ block_size_K); - - // move to the next index - data_index_step(mb, MB, nb, NB); - } + /* block_size_K */ block_size_K, + /* do_unpack */ do_unpack); + }); if (use_brgemm) { at::native::cpublas::brgemm_release(); @@ -441,8 +441,10 @@ void tinygemm_kernel( int64_t ldb, int64_t ldc, bool brg, - int64_t block_size_K) { - tinygemm_kernel(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K); + int64_t block_size_K, + bool do_unpack) { + tinygemm_kernel( + A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack); } #define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ @@ -460,7 +462,8 @@ void tinygemm_kernel( int64_t ldb, \ int64_t ldc, \ bool brg, \ - int64_t block_size_K) + int64_t block_size_K, \ + bool do_unpack) INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); @@ -495,7 +498,7 @@ at::Tensor fp8_scaled_mm_cpu( int64_t block_size_N = block_size[0]; int64_t block_size_K = block_size[1]; - constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; + constexpr int64_t BLOCK_M = block_size_m(); constexpr int64_t BLOCK_N = block_size_n(); TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N"); TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K"); @@ -523,7 +526,7 @@ at::Tensor fp8_scaled_mm_cpu( // Btmp : [T, BLOCK_N * K] // Ctmp : [T, BLOCK_M * BLOCK_N] int num_threads = at::get_num_threads(); - int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2; + int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2; auto buffer = at::empty({num_threads, size_per_thread}, mat1.options()); AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] { diff --git a/sgl-kernel/csrc/cpu/gemm_int8.cpp b/sgl-kernel/csrc/cpu/gemm_int8.cpp index f0f013cd1..cb6146607 100644 --- a/sgl-kernel/csrc/cpu/gemm_int8.cpp +++ b/sgl-kernel/csrc/cpu/gemm_int8.cpp @@ -4,6 +4,61 @@ namespace { +template +struct scale_C { + static inline void apply( + scalar_t* __restrict__ C, + const int32_t* __restrict__ Ctmp, + const int32_t* __restrict__ Bcomp, + const float* __restrict__ bias, + float As, + const float* __restrict__ Bs) { + TORCH_CHECK(false, "scale_C: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct scale_C { + static inline void apply( + at::BFloat16* __restrict__ C, + const int32_t* __restrict__ Ctmp, + 
const int32_t* __restrict__ Bcomp, + const float* __restrict__ bias, + float As, + const float* __restrict__ Bs) { + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512 vc[COLS]; + __m512 vd0 = _mm512_set1_ps(As); + + auto compute = [&](auto col) { + __m512 vd1 = _mm512_loadu_ps(Bs + col * 16); + __m512i vcomp = _mm512_loadu_si512(Bcomp + col * 16); + __m512i vc32 = _mm512_loadu_si512(Ctmp + col * 16); + vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp)); + if constexpr (has_bias) { + __m512 vbias = _mm512_loadu_ps(bias + col * 16); + vc[col] = _mm512_fmadd_ps(_mm512_mul_ps(vc[col], vd0), vd1, vbias); + } else { + vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vd0), vd1); + } + }; + Unroll{}(compute); + + auto storec = [&](auto col) { + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc[col + 1], vc[col + 0]))); + } + }; + Unroll{}(storec); + } +}; +#endif + template struct tinygemm_kernel_nn { static inline void apply( @@ -169,6 +224,17 @@ void tinygemm_kernel( // B compensation const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * K); + if (brg) { + constexpr int BLOCK_N = block_size_n(); + at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp); + + // apply compensation and scale + for (int64_t m = 0; m < M; ++m) { + scale_C::apply(C + m * ldc, Ctmp + m * BLOCK_N, Bcomp, bias, As[m], Bs); + } + return; + } + // pattern: 1-4-16 constexpr int64_t BLOCK_M = 4; constexpr int64_t BLOCK_N = 64; @@ -233,22 +299,17 @@ void int8_scaled_mm_kernel_impl( const int64_t MB = div_up(M, BLOCK_M); const int64_t NB = div_up(N, BLOCK_N); - // TODO: brgemm u8s8 depends on PyTorch 2.7 release. 
- const bool use_brgemm = false; + const bool use_brgemm = can_use_brgemm(M); // K + 4 after compensation const int64_t packed_row_size = get_row_size(K); AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { - at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { - int64_t mb{0}, nb{0}; - data_index_init(begin, mb, MB, nb, NB); - + parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { // for brgemm, use int32_t for accumulate alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N]; - for (int i = begin; i < end; ++i) { - UNUSED(i); + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) { int mb_start = mb * BLOCK_M; int mb_size = std::min(M - mb_start, BLOCK_M); int nb_start = nb * BLOCK_N; @@ -269,10 +330,7 @@ void int8_scaled_mm_kernel_impl( /* ldb */ nb_size, /* ldc */ N, /* brg */ use_brgemm); - - // move to the next index - data_index_step(mb, MB, nb, NB); - } + }); if (use_brgemm) { at::native::cpublas::brgemm_release(); diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp index 88d84c830..c3d66cec7 100644 --- a/sgl-kernel/csrc/cpu/moe.cpp +++ b/sgl-kernel/csrc/cpu/moe.cpp @@ -579,36 +579,31 @@ void fused_experts_kernel_impl( const int64_t stride_e = 2 * N * K; const int64_t stride_n = K; + int64_t avg_M = std::max(int64_t(1), M * topk / E); + const bool use_brgemm = can_use_brgemm(avg_M); + // here we only parallel on half of 2N to fuse silu_and_mul with gemm - at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { // get local pointers - int tid = at::get_thread_num(); + int tid = get_thread_num(); scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; - bool is_brgemm_used = false; - - for (int64_t i = begin; i < end; ++i) { - int64_t mb = i / NB; - int64_t nb = i % NB; - - // nb0 from top half and nb1 from bottom half - int64_t nb0 = nb, nb1 = nb + NB; - int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) { + // nb_upper from top half and nb_lower from bottom half + int64_t nb_upper = nb, nb_lower = nb + NB; + int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N); // B shape [K, n_size] in vnni format int32_t expert_id = expert_ids[mb]; - const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; - const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; + const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n; + const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n; // 1.a load A const int32_t* A_ids = sorted_ids + mb * BLOCK_M; int64_t m_size = offsets[mb + 1] - offsets[mb]; - const bool use_brgemm = can_use_brgemm(m_size); - is_brgemm_used = is_brgemm_used || use_brgemm; - for (int64_t m = 0; m < m_size; ++m) { int32_t index = A_ids[m] / topk; copy_stub(A + m * K, input + index * K, K); @@ -659,9 +654,9 @@ void fused_experts_kernel_impl( /* ldb */ n_size, /* ldc */ N); } - } + }); - if (is_brgemm_used) { + if (use_brgemm) { at::native::cpublas::brgemm_release(); } }); @@ -676,24 +671,16 @@ void fused_experts_kernel_impl( const int64_t stride_oc = IC; // parallel on [MB2, NB2] - at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + 
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { // get local pointers - int tid = at::get_thread_num(); + int tid = get_thread_num(); // we won't be using C1 for gemm2 float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; - bool is_brgemm_used = false; - - for (int64_t i = begin; i < end; ++i) { - int64_t mb = i / NB2; - int64_t nb = i % NB2; - + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) { int64_t m_size = offsets[mb + 1] - offsets[mb]; int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); - const bool use_brgemm = can_use_brgemm(m_size); - is_brgemm_used = is_brgemm_used || use_brgemm; - // A ptr from ic1 of [M * topk, N] in sorted order // so as to avoid copy A to tmp buffer again const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; @@ -736,9 +723,9 @@ void fused_experts_kernel_impl( float weight = topk_weights[index]; copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); } - } + }); - if (is_brgemm_used) { + if (use_brgemm) { at::native::cpublas::brgemm_release(); } }); @@ -776,36 +763,27 @@ void shared_expert_kernel_impl( TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); const int64_t stride_n = K; + const bool use_brgemm = can_use_brgemm(M); + // here we only parallel on half of 2N to fuse silu_and_mul with gemm - at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { // get local pointers - int tid = at::get_thread_num(); + int tid = get_thread_num(); float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; - bool is_brgemm_used = false; - - for (int64_t i = begin; i < end; ++i) { - int64_t mb = i / NB; - int64_t nb = i % NB; - - // nb0 from top half and nb1 from bottom half - int64_t nb0 = nb, nb1 = nb + NB; - int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) { + // nb_upper from top half and nb_lower from bottom half + int64_t nb_upper = nb, nb_lower = nb + NB; + int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N); int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); - // int64_t mb_start = mb * BLOCK_M; - // int64_t mb_size = std::min(M - mb_start, BLOCK_M); - // A shape [m_size, K] const scalar_t* A = input + mb * BLOCK_M * K; // B shape [K, n_size] in vnni format - const scalar_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; - const scalar_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; - - const bool use_brgemm = can_use_brgemm(m_size); - is_brgemm_used = is_brgemm_used || use_brgemm; + const scalar_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n; + const scalar_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n; if (use_brgemm) { // 1.b gemm: C0 = A @ B0 @@ -850,9 +828,9 @@ void shared_expert_kernel_impl( /* ldb */ n_size, /* ldc */ N); } - } + }); - if (is_brgemm_used) { + if (use_brgemm) { at::native::cpublas::brgemm_release(); } }); @@ -866,24 +844,16 @@ void shared_expert_kernel_impl( const int64_t stride_oc = IC; // parallel on [MB2, NB2] - at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { // get local pointers - int tid = at::get_thread_num(); + int tid = get_thread_num(); // we won't be using C1 for gemm2 float* __restrict__ C = C_tmp + tid * 
2 * BLOCK_M * BLOCK_N; - bool is_brgemm_used = false; - - for (int64_t i = begin; i < end; ++i) { - int64_t mb = i / NB2; - int64_t nb = i % NB2; - + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) { int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); - const bool use_brgemm = can_use_brgemm(m_size); - is_brgemm_used = is_brgemm_used || use_brgemm; - // A shape [m_size, IC] const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N; @@ -922,9 +892,9 @@ void shared_expert_kernel_impl( for (int64_t m = 0; m < m_size; ++m) { add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); } - } + }); - if (is_brgemm_used) { + if (use_brgemm) { at::native::cpublas::brgemm_release(); } }); @@ -1086,7 +1056,7 @@ at::Tensor fused_experts_cpu( // // for fp8 w8a16: // 7. intermediate_cache0 : [M * topk, 2N] - // 8. B_tmp : [T, BLOCK_N, std::max(K, N)] + // 8. B_tmp : [T, MAX_CACHE_BLOCK_SIZE, BLOCK_N, std::max(K, N)] // int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 + num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) + @@ -1096,7 +1066,7 @@ at::Tensor fused_experts_cpu( buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float); } if (use_fp8_w8a16) { - buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2; + buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N) * 2; } auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); @@ -1268,7 +1238,7 @@ at::Tensor shared_expert_cpu( // // for fp8 w8a16: // 5. intermediate_cache0 : [M, 2N] - // 6. B_tmp: [T, BLOCK_M, max(K, N)] + // 6. B_tmp: [T, MAX_CACHE_BLOCK_SIZE, BLOCK_M, max(K, N)] // int num_threads = at::get_num_threads(); int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); @@ -1277,7 +1247,7 @@ at::Tensor shared_expert_cpu( buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float); } if (use_fp8_w8a16) { - buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2; + buffer_size_nbytes += M * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_M * std::max(K, N) * 2; } auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); diff --git a/sgl-kernel/csrc/cpu/moe_fp8.cpp b/sgl-kernel/csrc/cpu/moe_fp8.cpp index cb891fca2..281c00897 100644 --- a/sgl-kernel/csrc/cpu/moe_fp8.cpp +++ b/sgl-kernel/csrc/cpu/moe_fp8.cpp @@ -174,18 +174,18 @@ void fused_experts_fp8_kernel_impl( const int64_t stride_e = 2 * N * K; const int64_t stride_n = K; + int64_t avg_M = std::max(int64_t(1), M * topk / E); + const bool use_brgemm = can_use_brgemm(avg_M); + + int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N); + // here we only parallel on half of 2N to fuse silu_and_mul with gemm - at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { // get local pointers - int tid = at::get_thread_num(); + int tid = get_thread_num(); scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; - bool is_brgemm_used = false; - - for (int64_t i = begin; i < end; ++i) { - int64_t mb = i / NB; - int64_t nb = i % NB; - + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) { int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); // B shape [K, n_size] in vnni 
format @@ -194,13 +194,14 @@ void fused_experts_fp8_kernel_impl( const float* __restrict__ Bs = w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; + // do unpacking for the first row or a new expert + int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1]; + bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id); + // 1.a load A const int32_t* A_ids = sorted_ids + mb * BLOCK_M; int64_t m_size = offsets[mb + 1] - offsets[mb]; - const bool use_brgemm = can_use_brgemm(m_size); - is_brgemm_used = is_brgemm_used || use_brgemm; - for (int64_t m = 0; m < m_size; ++m) { int32_t index = A_ids[m] / topk; copy_stub(A + m * K, input + index * K, K); @@ -211,7 +212,7 @@ void fused_experts_fp8_kernel_impl( /* A */ A, /* B */ B, /* C */ ic0 + offset * 2 * N + nb * BLOCK_N, - /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K, /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, /* scale */ Bs, /* M */ m_size, @@ -221,10 +222,11 @@ void fused_experts_fp8_kernel_impl( /* ldb */ n_size, /* ldc */ 2 * N, /* brg */ use_brgemm, - /* block_size_K */ block_size_K); - } + /* block_size_K */ block_size_K, + /* do_unpack */ do_unpack); + }); - if (is_brgemm_used) { + if (use_brgemm) { at::native::cpublas::brgemm_release(); } }); @@ -248,22 +250,14 @@ void fused_experts_fp8_kernel_impl( const int64_t stride_oc = IC; // parallel on [MB2, NB2] - at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { - int tid = at::get_thread_num(); + parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { + int tid = get_thread_num(); alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; - bool is_brgemm_used = false; - - for (int64_t i = begin; i < end; ++i) { - int64_t mb = i / NB2; - int64_t nb = i % NB2; - + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) { int64_t m_size = offsets[mb + 1] - offsets[mb]; int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); - const bool use_brgemm = can_use_brgemm(m_size); - is_brgemm_used = is_brgemm_used || use_brgemm; - // A ptr from ic1 of [M * topk, N] in sorted order // so as to avoid copy A to tmp buffer again const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; @@ -275,11 +269,15 @@ void fused_experts_fp8_kernel_impl( const float* __restrict__ Bs = w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; + // do unpacking for the first row or a new expert + int32_t pre_expert_id = mb == 0 ? 
-1 : expert_ids[mb - 1]; + bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id); + tinygemm_kernel( /* A */ A, /* B */ B, /* C */ C, - /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC, /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, /* scale */ Bs, /* M */ m_size, @@ -289,7 +287,8 @@ void fused_experts_fp8_kernel_impl( /* ldb */ n_size, /* ldc */ BLOCK_N, /* brg */ use_brgemm, - /* block_size_K */ block_size_K); + /* block_size_K */ block_size_K, + /* do_unpack */ do_unpack); // 2.b copy from C to ic2 in original order // and also mul topk_weights in float32 @@ -298,9 +297,9 @@ void fused_experts_fp8_kernel_impl( float weight = topk_weights[index]; copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); } - } + }); - if (is_brgemm_used) { + if (use_brgemm) { at::native::cpublas::brgemm_release(); } }); @@ -374,20 +373,23 @@ void shared_expert_fp8_kernel_impl( const bool use_brgemm = can_use_brgemm(M); - at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { - int tid = at::get_thread_num(); + int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N); - for (int64_t i = begin; i < end; ++i) { - int64_t mb = i / NB; - int64_t nb = i % NB; + parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { + int tid = get_thread_num(); + + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) { int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); + // do unpacking for the first row + bool do_unpack = (mb == mb0); + tinygemm_kernel( /* A */ input + mb * BLOCK_M * K, /* B */ packed_w1 + nb * BLOCK_N * K, /* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N, - /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K, /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, /* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K, /* M */ m_size, @@ -397,8 +399,9 @@ void shared_expert_fp8_kernel_impl( /* ldb */ n_size, /* ldc */ 2 * N, /* brg */ use_brgemm, - /* block_size_K */ block_size_K); - } + /* block_size_K */ block_size_K, + /* do_unpack */ do_unpack); + }); if (use_brgemm) { at::native::cpublas::brgemm_release(); @@ -421,22 +424,23 @@ void shared_expert_fp8_kernel_impl( scale_size_K = div_up(N, block_size_K); // parallel on [MB2, NB2] - at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { - int tid = at::get_thread_num(); + parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { + int tid = get_thread_num(); alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; - for (int64_t i = begin; i < end; ++i) { - int64_t mb = i / NB2; - int64_t nb = i % NB2; + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) { int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + // do unpacking for the first row + bool do_unpack = (mb == mb0); + // 2.a gemm: C = A @ B tinygemm_kernel( /* A */ ic1 + mb * BLOCK_M * N, /* B */ packed_w2 + nb * BLOCK_N * N, /* C */ C, - /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC, /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, /* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K, /* M */ m_size, @@ -446,7 +450,8 @@ void shared_expert_fp8_kernel_impl( /* ldb */ n_size, /* 
ldc */ BLOCK_N, /* brg */ use_brgemm, - /* block_size_K */ block_size_K); + /* block_size_K */ block_size_K, + /* do_unpack */ do_unpack); // 2.b copy from C to output and add fused_experts_out scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; @@ -454,7 +459,7 @@ void shared_expert_fp8_kernel_impl( for (int64_t m = 0; m < m_size; ++m) { add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); } - } + }); }); if (use_brgemm) { diff --git a/sgl-kernel/csrc/cpu/moe_int8.cpp b/sgl-kernel/csrc/cpu/moe_int8.cpp index e12e5e7cf..8fbac902f 100644 --- a/sgl-kernel/csrc/cpu/moe_int8.cpp +++ b/sgl-kernel/csrc/cpu/moe_int8.cpp @@ -109,6 +109,120 @@ inline void add_mul_stub( } } +template +inline void silu_and_mul( + scalar_t* __restrict__ C, + const int32_t* __restrict__ C0, // x: x0, x1 + const int32_t* __restrict__ C1, // y: y0, y1 + const float* __restrict__ As, + const float* __restrict__ Bs0, + const float* __restrict__ Bs1, + const int32_t* __restrict__ Bcomp0, + const int32_t* __restrict__ Bcomp1, + int64_t m_size, + int64_t N) { +#if defined(CPU_CAPABILITY_AVX512) + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512 vc0[COLS]; + __m512 vc1[COLS]; + __m512i vcomp0[COLS]; + __m512i vcomp1[COLS]; + __m512 vas; + __m512 vbs0[COLS]; + __m512 vbs1[COLS]; + + auto load_scale_and_comp = [&](auto col) { + vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16); + vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16); + vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16); + vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16); + }; + Unroll{}(load_scale_and_comp); + + auto scalec = [&](auto col, int64_t m) { + // update As + vas = _mm512_set1_ps(As[m]); + // C = As * (C - Bcomp) * Bs + __m512i vc32_0 = _mm512_loadu_si512(C0 + m * BLOCK_N + col * 16); + __m512i vc32_1 = _mm512_loadu_si512(C1 + m * BLOCK_N + col * 16); + vc0[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_0, vcomp0[col])); + vc1[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_1, vcomp1[col])); + vc0[col] = _mm512_mul_ps(_mm512_mul_ps(vc0[col], vas), vbs0[col]); + vc1[col] = _mm512_mul_ps(_mm512_mul_ps(vc1[col], vas), vbs1[col]); + }; + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + const fVec one = fVec(1.f); + auto silu_and_mul = [&](auto col) { + fVec x = fVec(vc0[col]); + fVec y = fVec(vc1[col]); + x = x / (one + x.neg().exp_u20()); + vc0[col] = x * y; + }; + + auto storec = [&](auto col, int64_t m) { + if constexpr (col % 2 == 0) { + fVec x0 = fVec(vc0[col + 0]); + fVec x1 = fVec(vc0[col + 1]); + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(C + m * N + col * 16); + } + }; + + for (int64_t m = 0; m < m_size; ++m) { + Unroll{}(scalec, m); + Unroll{}(silu_and_mul); + Unroll{}(storec, m); + } +#else + TORCH_CHECK(false, "silu_and_mul: scalar path not implemented!"); +#endif +} + +template +inline void scale_C( + float* __restrict__ C, + const int32_t* __restrict__ Ctmp, + const float* __restrict__ As, + const float* __restrict__ Bs, + const int32_t* __restrict__ Bcomp, + int64_t m_size) { +#if defined(CPU_CAPABILITY_AVX512) + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512 vc[COLS]; + __m512i vcomp[COLS]; + __m512 vas; + __m512 vbs[COLS]; + + auto load_scale_and_comp = [&](auto col) { + vcomp[col] = _mm512_loadu_si512(Bcomp + col * 16); + vbs[col] = _mm512_loadu_ps(Bs + col * 16); + }; + Unroll{}(load_scale_and_comp); + + auto scalec = [&](auto col, int64_t m) { + // update As + vas = 
_mm512_set1_ps(As[m]); + // C = As * (C - Bcomp) * Bs + __m512i vc32 = _mm512_loadu_si512(Ctmp + m * BLOCK_N + col * 16); + vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp[col])); + vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vas), vbs[col]); + _mm512_storeu_ps(C + m * BLOCK_N + col * 16, vc[col]); + }; + + for (int64_t m = 0; m < m_size; ++m) { + Unroll{}(scalec, m); + } +#else + TORCH_CHECK(false, "scale_C: scalar path not implemented!"); +#endif +} + /// gemm for w13 template struct tinygemm_kernel_vnni { @@ -515,28 +629,31 @@ void fused_experts_int8_kernel_impl( const int64_t stride_e = 2 * N * packed_K; const int64_t stride_n = packed_K; + + int64_t avg_M = std::max(int64_t(1), M * topk / E); + const bool use_brgemm = can_use_brgemm(avg_M); + // here we only parallel on half of 2N to fuse silu_and_mul with gemm - at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { // get local pointers - int tid = at::get_thread_num(); + int tid = get_thread_num(); uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + int32_t* __restrict__ C0 = reinterpret_cast(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N; + int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; alignas(64) float As[BLOCK_M]; - for (int64_t i = begin; i < end; ++i) { - int64_t mb = i / NB; - int64_t nb = i % NB; - - // nb0 from top half and nb1 from bottom half - int64_t nb0 = nb, nb1 = nb + NB; - int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) { + // nb_upper from top half and nb_lower from bottom half + int64_t nb_upper = nb, nb_lower = nb + NB; + int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N); // B shape [K, n_size] in vnni format int32_t expert_id = expert_ids[mb]; - const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; - const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; - const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N; - const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N; + const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n; + const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n; + const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb_upper * BLOCK_N; + const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb_lower * BLOCK_N; // 1.a load A const int32_t* A_ids = sorted_ids + mb * BLOCK_M; @@ -548,22 +665,62 @@ void fused_experts_int8_kernel_impl( As[m] = As_tmp[index]; } - // fused 1.b: silu_and_mul(A @ B0, A @ B1) - const int64_t offset = offsets[mb]; - tinygemm_kernel( - /* A */ A, - /* B0 */ B0, - /* B1 */ B1, - /* C */ ic1 + offset * N + nb * BLOCK_N, - /* As */ As, - /* Bs0 */ Bs0, - /* Bs1 */ Bs1, - /* M */ m_size, - /* N */ n_size, - /* K */ K, - /* lda */ K, - /* ldb */ n_size, - /* ldc */ N); + if (use_brgemm) { + // 1.b gemm: C0 = A @ B0 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B0, + /* C */ C0); + + // 1.c gemm: C1 = A @ B1 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B1, + /* C */ C1); + + const int32_t* Bcomp0 = 
reinterpret_cast(B0 + block_size_n() * K); + const int32_t* Bcomp1 = reinterpret_cast(B1 + block_size_n() * K); + + // 1.d silu and mul + const int64_t offset = offsets[mb]; + silu_and_mul( + ic1 + offset * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N); + } else { + // fused 1.bcd: silu_and_mul(A @ B0, A @ B1) + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + offset * N + nb * BLOCK_N, + /* As */ As, + /* Bs0 */ Bs0, + /* Bs1 */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + }); + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); } }); @@ -584,16 +741,13 @@ void fused_experts_int8_kernel_impl( const int64_t stride_oc = packed_N; // parallel on [MB2, NB2] - at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { // get local pointers - int tid = at::get_thread_num(); - // we won't be using C1 for gemm2 + int tid = get_thread_num(); float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + int32_t* __restrict__ C32 = reinterpret_cast(C + BLOCK_M * BLOCK_N); - for (int64_t i = begin; i < end; ++i) { - int64_t mb = i / NB2; - int64_t nb = i % NB2; - + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) { int64_t m_size = offsets[mb + 1] - offsets[mb]; int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); @@ -609,18 +763,36 @@ void fused_experts_int8_kernel_impl( const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N; // 2.a gemm: C = A @ B - tinygemm_kernel( - /* A */ A, - /* B */ B, - /* C */ C, - /* As */ As, - /* Bs */ Bs, - /* M */ m_size, - /* N */ n_size, - /* K */ IC, - /* lda */ IC, - /* ldb */ n_size, - /* ldc */ BLOCK_N); + if (use_brgemm) { + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B, + /* C */ C32); + + // apply scales + const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * IC); + scale_C(C, C32, As, Bs, Bcomp, m_size); + } else { + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* As */ As, + /* Bs */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + } // 2.b copy from C to ic2 in original order // and also mul topk_weights in float32 @@ -629,6 +801,10 @@ void fused_experts_int8_kernel_impl( float weight = topk_weights[index]; copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); } + }); + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); } }); @@ -708,15 +884,19 @@ void shared_expert_int8_kernel_impl( const int64_t packed_N = get_row_size(N); const int64_t stride_n = packed_K; - // here we only parallel on half of 2N to fuse silu_and_mul with gemm - at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { - for (int64_t i = begin; i < end; ++i) { - int64_t mb = i / NB; - int64_t nb = i % NB; + const bool use_brgemm = can_use_brgemm(M); - // nb0 from top half and nb1 from bottom half - int64_t nb0 = nb, nb1 = nb + NB; - int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { + // get local pointers + int tid = get_thread_num(); + int32_t* __restrict__ C0 = 
reinterpret_cast(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N; + int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; + + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) { + // nb_upper from top half and nb_lower from bottom half + int64_t nb_upper = nb, nb_lower = nb + NB; + int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N); int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); // A shape [m_size, K] @@ -724,26 +904,65 @@ void shared_expert_int8_kernel_impl( const float* As = As_tmp + mb * BLOCK_M; // B shape [K, n_size] in vnni format - const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; - const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; - const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N; - const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N; + const int8_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n; + const int8_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n; + const float* __restrict__ Bs0 = w1s + nb_upper * BLOCK_N; + const float* __restrict__ Bs1 = w1s + nb_lower * BLOCK_N; - // fused 1.b: silu_and_mul(A @ B0, A @ B1) - tinygemm_kernel( - /* A */ A, - /* B0 */ B0, - /* B1 */ B1, - /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, - /* As */ As, - /* Bs0 */ Bs0, - /* Bs1 */ Bs1, - /* M */ m_size, - /* N */ n_size, - /* K */ K, - /* lda */ K, - /* ldb */ n_size, - /* ldc */ N); + if (use_brgemm) { + // 1.b gemm: C0 = A @ B0 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B0, + /* C */ C0); + + // 1.c gemm: C1 = A @ B1 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B1, + /* C */ C1); + + const int32_t* Bcomp0 = reinterpret_cast(B0 + block_size_n() * K); + const int32_t* Bcomp1 = reinterpret_cast(B1 + block_size_n() * K); + + // 1.d silu and mul + silu_and_mul( + ic1 + mb * BLOCK_M * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N); + } else { + // fused 1.bcd: silu_and_mul(A @ B0, A @ B1) + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + /* As */ As, + /* Bs0 */ Bs0, + /* Bs1 */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + }); + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); } }); @@ -763,16 +982,13 @@ void shared_expert_int8_kernel_impl( const int64_t stride_oc = packed_N; // parallel on [MB2, NB2] - at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) { // get local pointers - int tid = at::get_thread_num(); - // we won't be using C1 for gemm2 + int tid = get_thread_num(); float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + int32_t* __restrict__ C32 = reinterpret_cast(C + BLOCK_M * BLOCK_N); - for (int64_t i = begin; i < end; ++i) { - int64_t mb = i / NB2; - int64_t nb = i % NB2; - + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) { int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); @@ -784,19 +1000,37 @@ void shared_expert_int8_kernel_impl( const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc; const float* __restrict__ Bs = w2s + nb * 
BLOCK_N; - // 2.a gemm: C = A @ B - tinygemm_kernel( - /* A */ A, - /* B */ B, - /* C */ C, - /* As */ As, - /* Bs */ Bs, - /* M */ m_size, - /* N */ n_size, - /* K */ IC, - /* lda */ IC, - /* ldb */ n_size, - /* ldc */ BLOCK_N); + if (use_brgemm) { + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B, + /* C */ C32); + + // apply scales + const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * IC); + scale_C(C, C32, As, Bs, Bcomp, m_size); + } else { + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* As */ As, + /* Bs */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + } // 2.b copy from C to output and add fused_experts_out scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; @@ -804,6 +1038,10 @@ void shared_expert_int8_kernel_impl( for (int64_t m = 0; m < m_size; ++m) { add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); } + }); + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); } }); } diff --git a/sgl-kernel/csrc/cpu/qkv_proj.cpp b/sgl-kernel/csrc/cpu/qkv_proj.cpp index 8d663e84a..b3e2072e8 100644 --- a/sgl-kernel/csrc/cpu/qkv_proj.cpp +++ b/sgl-kernel/csrc/cpu/qkv_proj.cpp @@ -100,8 +100,7 @@ void segment_gemm_kernel_impl( const int64_t NB1 = div_up(N1, BLOCK_N); const int64_t NB = NB0 + NB1; - // TODO: brgemm u8s8 depends on PyTorch 2.7 release. - const bool use_brgemm = false; + const bool use_brgemm = can_use_brgemm(M); // K + 4 after compensation const int64_t packed_row_size = get_row_size(K);
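
Usage sketch for the new common.h helpers (illustrative only, not part of the diff above). It assumes the signatures shown in common.h as reconstructed here, in particular that loop_2d's leading template parameter is the element type used for the L2 sizing; example_parallel_2d_usage, its parameters, and the 4 x 64 blocking are illustrative, not code from this patch.

// illustrative sketch; assumes the common.h introduced in this patch is included
template <typename scalar_t>
void example_parallel_2d_usage(int64_t M, int64_t N, int64_t K) {
  constexpr int64_t BLOCK_M = 4;  // illustrative block sizes (the bf16 path uses 4 x 64)
  constexpr int64_t BLOCK_N = 64;
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);

  // parallel_2d hands each thread one rectangular tile of the [MB, NB] grid,
  // preferring near-square tiles (see the [NOTE] thread blocking comment).
  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
    // plain OpenMP region, so the omp-based get_thread_num() is valid here
    int tid = get_thread_num();

    // loop_2d walks the tile in [nb-chunk, mb, nb] order so that a few
    // consecutive N-blocks stay resident in L2; nb_offset selects the
    // per-thread scratch slot reused across the mb loop, which is what the
    // fp8 path exploits via do_unpack.
    loop_2d<scalar_t>(
        mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
          // compute output block [mb, nb] here, e.g. with brgemm or
          // tinygemm_kernel, unpacking weights into slot nb_offset only when
          // this block column is visited for the first time
          (void)tid;
          (void)mb;
          (void)nb;
          (void)nb_offset;
        });
  });
}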