diff --git a/sgl-kernel/csrc/cpu/gemm.h b/sgl-kernel/csrc/cpu/gemm.h index 5c3ff26bb..901939f11 100644 --- a/sgl-kernel/csrc/cpu/gemm.h +++ b/sgl-kernel/csrc/cpu/gemm.h @@ -85,6 +85,32 @@ void fused_experts_int8_kernel_impl( int64_t topk, int64_t num_tokens_post_pad); +// moe implementations for fp8 w8a16 +template +void fused_experts_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + // shared expert implementation for int8 w8a8 template void shared_expert_int8_kernel_impl( diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp index e1a9a9f85..ea6b0cc2c 100644 --- a/sgl-kernel/csrc/cpu/moe.cpp +++ b/sgl-kernel/csrc/cpu/moe.cpp @@ -932,6 +932,40 @@ void shared_expert_kernel_impl( } // anonymous namespace +// common checks +static inline void check_moe_scales( + bool use_int8_w8a8, + bool use_fp8_w8a16, + const std::optional& w1_scale, + const std::optional& w2_scale, + const std::optional> block_size, + const std::optional& a1_scale, + const std::optional& a2_scale) { + if (use_int8_w8a8) { + TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8."); + TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8."); + TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported."); + TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported."); + } + if (use_fp8_w8a16) { + TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16."); + TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16."); + TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16."); + TORCH_CHECK(block_size.value().size() == 2, "expect block_size.size() to be 2."); + } +} + +#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \ + auto w1s = w1_scale.value(); \ + auto w2s = w2_scale.value(); \ + auto block_size_val = block_size.value(); \ + int64_t block_size_N = block_size_val[0]; \ + int64_t block_size_K = block_size_val[1]; \ + TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N); \ + TORCH_CHECK(w1s.size(DIM1) == K / block_size_K); \ + TORCH_CHECK(w2s.size(DIM0) == K / block_size_N); \ + TORCH_CHECK(w2s.size(DIM1) == N / block_size_K) + // hidden_states: [M, K] // w1: [E, 2N, K] // w2: [E, K, N] @@ -946,8 +980,10 @@ at::Tensor fused_experts_cpu( at::Tensor& topk_ids, bool inplace, bool use_int8_w8a8, + bool use_fp8_w8a16, const std::optional& w1_scale, const std::optional& w2_scale, + const std::optional> block_size, const std::optional& a1_scale, const std::optional& a2_scale, bool is_vnni) { @@ -990,12 +1026,8 @@ at::Tensor fused_experts_cpu( CHECK_EQ(packed_w1.size(2), packed_K); CHECK_EQ(packed_w2.size(2), packed_N); - if (use_int8_w8a8) { - TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8."); - TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8."); - TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported."); - 
TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported."); - } + // check scales + check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale); at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states); @@ -1047,6 +1079,9 @@ at::Tensor fused_experts_cpu( // 5. Aq_tmp : [M, K] or [M * topk, N] // 6. As_tmp : [M * topk] // + // for fp8 w8a16: + // 7. intermediate_cache1 : [M * topk, 2N] + // int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 + num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); @@ -1054,6 +1089,9 @@ at::Tensor fused_experts_cpu( if (use_int8_w8a8) { buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float); } + if (use_fp8_w8a16) { + buffer_size_nbytes += M * topk * 2 * N * 2; + } auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); @@ -1095,6 +1133,35 @@ at::Tensor fused_experts_cpu( E, topk, num_tokens_post_pad); + } else if (use_fp8_w8a16) { + // here we just ignore C_tmp as it is not used + scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K)); + scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + + CHECK_MOE_SCALES_FP8(1, 2); + fused_experts_fp8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache0, + intermediate_cache1, + intermediate_cache2, + A_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + block_size_N, + block_size_K, + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); } else { scalar_t* __restrict__ A_tmp = intermediate_cache2 + M * topk * K; float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); @@ -1176,17 +1243,8 @@ at::Tensor shared_expert_cpu( CHECK_EQ(packed_w1.size(1), packed_K); CHECK_EQ(packed_w2.size(1), packed_N); - if (use_int8_w8a8) { - TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8."); - TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8."); - TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported."); - TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported."); - } - if (use_fp8_w8a16) { - TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16."); - TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16."); - TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16."); - } + // check scales + check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale); at::Tensor out_hidden_states = inplace ? 
hidden_states : at::empty_like(hidden_states); @@ -1244,17 +1302,7 @@ at::Tensor shared_expert_cpu( } else if (use_fp8_w8a16) { scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); - auto w1s = w1_scale.value(); - auto w2s = w2_scale.value(); - auto block_size_val = block_size.value(); - TORCH_CHECK(block_size_val.size() == 2, "shared_expert: expect block_size.size() to be 2."); - int64_t block_size_N = block_size_val[0]; - int64_t block_size_K = block_size_val[1]; - TORCH_CHECK(w1s.size(0) == 2 * N / block_size_N); - TORCH_CHECK(w1s.size(1) == K / block_size_K); - TORCH_CHECK(w2s.size(0) == K / block_size_N); - TORCH_CHECK(w2s.size(1) == N / block_size_K); - + CHECK_MOE_SCALES_FP8(0, 1); shared_expert_fp8_kernel_impl( out_hidden_states.data_ptr(), intermediate_cache0, diff --git a/sgl-kernel/csrc/cpu/moe_fp8.cpp b/sgl-kernel/csrc/cpu/moe_fp8.cpp index 77bf5fbb2..3aaddacf2 100644 --- a/sgl-kernel/csrc/cpu/moe_fp8.cpp +++ b/sgl-kernel/csrc/cpu/moe_fp8.cpp @@ -4,6 +4,76 @@ namespace { +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; +// no remainder +#pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + bVec x = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x); + x0 = x0 * weight_vec; + x1 = x1 * weight_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + // out = input + input2 * scale template inline void add_mul_stub( @@ -65,6 +135,215 @@ inline void silu_and_mul_stub( } // anonymous namespace +template +void fused_experts_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const 
float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache0 = hidden_states @ w1 + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(2 * N, BLOCK_N); + int64_t scale_size_N = div_up(2 * N, block_size_N); + int64_t scale_size_K = div_up(K, block_size_K); + int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const int64_t stride_e = 2 * N * K; + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + + alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K]; + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n; + const float* __restrict__ Bs = + w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, input + index * K, K); + } + + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ ic0 + offset * 2 * N + nb * BLOCK_N, + /* Btmp */ Btmp, + /* Ctmp */ Ctmp, + /* scale */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ 2 * N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 1.5: intermediate_cache1 = silu(intermediate_cache0) + at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + silu_and_mul_stub(ic1 + m * N, ic0 + m * 2 * N, ic0 + m * 2 * N + N, N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + scale_size_N = div_up(K, block_size_N); + scale_size_K = div_up(N, block_size_K); + const int64_t stride_e2 = OC * IC; + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + alignas(64) scalar_t Btmp[BLOCK_K * BLOCK_N]; + alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; + alignas(64) float Ctmp[BLOCK_M * BLOCK_K]; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - 
offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = + w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; + + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* Btmp */ Btmp, + /* Ctmp */ Ctmp, + /* scale */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE) \ + template void fused_experts_fp8_kernel_impl( \ + TYPE* __restrict__ output, \ + TYPE* __restrict__ ic0, \ + TYPE* __restrict__ ic1, \ + TYPE* __restrict__ ic2, \ + TYPE* __restrict__ A_tmp, \ + const TYPE* __restrict__ input, \ + const at::Float8_e4m3fn* __restrict__ packed_w1, \ + const at::Float8_e4m3fn* __restrict__ packed_w2, \ + const float* __restrict__ w1s, \ + const float* __restrict__ w2s, \ + int64_t block_size_N, \ + int64_t block_size_K, \ + const float* __restrict__ topk_weights, \ + const int32_t* __restrict__ sorted_ids, \ + const int32_t* __restrict__ expert_ids, \ + const int32_t* __restrict__ offsets, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t E, \ + int64_t topk, \ + int64_t num_tokens_post_pad) + +INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16); +INSTANTIATE_MOE_FP8_TEMPLATE(at::Half); + template void shared_expert_fp8_kernel_impl( scalar_t* __restrict__ output, @@ -100,8 +379,8 @@ void shared_expert_fp8_kernel_impl( for (int64_t i = begin; i < end; ++i) { int64_t mb = i / NB; int64_t nb = i % NB; - int64_t mb_size = std::min(M - mb * BLOCK_M, BLOCK_M); - int64_t nb_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); tinygemm_kernel( /* A */ input + mb * BLOCK_M * K, @@ -110,11 +389,11 @@ void shared_expert_fp8_kernel_impl( /* Btmp */ Btmp, /* Ctmp */ Ctmp, /* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K, - /* M */ mb_size, - /* N */ nb_size, + /* M */ m_size, + /* N */ n_size, /* K */ K, /* lda */ K, - /* ldb */ nb_size, + /* ldb */ n_size, /* ldc */ 2 * N, /* brg */ use_brgemm, /* block_size_K */ block_size_K); @@ -149,8 +428,8 @@ void shared_expert_fp8_kernel_impl( for (int64_t i = begin; i < end; ++i) { int64_t mb = i / NB2; int64_t nb = i % NB2; - int64_t mb_size = std::min(M - mb * BLOCK_M, BLOCK_M); - int64_t 
nb_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); // 2.a gemm: C = A @ B tinygemm_kernel( @@ -160,11 +439,11 @@ void shared_expert_fp8_kernel_impl( /* Btmp */ Btmp, /* Ctmp */ Ctmp, /* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K, - /* M */ mb_size, - /* N */ nb_size, + /* M */ m_size, + /* N */ n_size, /* K */ IC, /* lda */ IC, - /* ldb */ nb_size, + /* ldb */ n_size, /* ldc */ BLOCK_N, /* brg */ use_brgemm, /* block_size_K */ block_size_K); @@ -172,8 +451,8 @@ void shared_expert_fp8_kernel_impl( // 2.b copy from C to output and add fused_experts_out scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; - for (int64_t m = 0; m < mb_size; ++m) { - add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, nb_size); + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); } } }); diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index f8e9a4559..018f8efb8 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -130,8 +130,10 @@ at::Tensor fused_experts_cpu( at::Tensor& topk_ids, bool inplace, bool use_int8_w8a8, + bool use_fp8_w8a16, const std::optional& w1_scale, const std::optional& w2_scale, + const std::optional> block_size, const std::optional& a1_scale, const std::optional& a2_scale, bool is_vnni); @@ -260,7 +262,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { // moe m.def( "fused_experts_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool " - "inplace, bool use_int8_w8a8, Tensor? w1_scale, Tensor? w2_scale, Tensor? a1_scale, Tensor? a2_scale, bool " + "inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, int[]? block_size, Tensor? " + "a1_scale, Tensor? a2_scale, bool " "is_vnni) -> Tensor"); m.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu); diff --git a/sgl-kernel/setup_cpu.py b/sgl-kernel/setup_cpu.py deleted file mode 100644 index 9fc07700b..000000000 --- a/sgl-kernel/setup_cpu.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2025 SGLang Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -import os -import platform -import shutil -import sys -from pathlib import Path - -import torch -from setuptools import find_packages, setup -from setuptools.command.build_py import build_py -from torch.utils.cpp_extension import BuildExtension, CppExtension - -root = Path(__file__).parent.resolve() -arch = platform.machine().lower() - -if arch in ("x86_64", "amd64"): - plat_name = "manylinux2014_x86_64" -elif arch in ("aarch64", "arm64"): - plat_name = "manylinux2014_aarch64" -elif arch.startswith("ppc"): - plat_name = "manylinux2014_ppc64le" -else: - plat_name = f"manylinux2014_{arch}" - -if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv: - sys.argv.extend(["--plat-name", plat_name]) - - -def _get_version(): - with open(root / "pyproject.toml") as f: - for line in f: - if line.startswith("version"): - return line.split("=")[1].strip().strip('"') - - -cpu_fp8_ftz = os.getenv("SGLANG_CPU_FP8_CVT_FTZ", "1") == "1" - -operator_namespace = "sgl_kernel" -include_dirs = [ - "../../include", -] - -sources = [ - "csrc/cpu/activation.cpp", - "csrc/cpu/bmm.cpp", - "csrc/cpu/decode.cpp", - "csrc/cpu/extend.cpp", - "csrc/cpu/gemm.cpp", - "csrc/cpu/gemm_fp8.cpp", - "csrc/cpu/gemm_int8.cpp", - "csrc/cpu/moe.cpp", - "csrc/cpu/moe_fp8.cpp", - "csrc/cpu/moe_int8.cpp", - "csrc/cpu/norm.cpp", - "csrc/cpu/qkv_proj.cpp", - "csrc/cpu/topk.cpp", - "csrc/cpu/interface.cpp", - "csrc/cpu/shm.cpp", - "csrc/cpu/rope.cpp", - "csrc/cpu/torch_extension_cpu.cpp", -] - -extra_compile_args = { - "cxx": [ - "-O3", - "-Wno-unknown-pragmas", - "-march=native", - "-fopenmp", - ] -} -if cpu_fp8_ftz: - extra_compile_args["cxx"].append("-DSGLANG_CPU_FP8_CVT_FTZ") - -libraries = ["c10", "torch", "torch_python"] -cmdclass = { - "build_ext": BuildExtension.with_options(use_ninja=True), -} -Extension = CppExtension - -extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"] - -ext_modules = [ - Extension( - name="sgl_kernel.common_ops", - sources=sources, - include_dirs=include_dirs, - extra_compile_args=extra_compile_args, - libraries=libraries, - extra_link_args=extra_link_args, - py_limited_api=False, - ), -] - -setup( - name="sgl-kernel", - version=_get_version(), - packages=find_packages(where="python"), - package_dir={"": "python"}, - ext_modules=ext_modules, - cmdclass=cmdclass, - options={"bdist_wheel": {"py_limited_api": "cp39"}}, -) diff --git a/test/srt/cpu/test_moe.py b/test/srt/cpu/test_moe.py new file mode 100644 index 000000000..098f72cf1 --- /dev/null +++ b/test/srt/cpu/test_moe.py @@ -0,0 +1,259 @@ +import itertools +import math +import unittest + +# TODO: use interface in cpu.py +import sgl_kernel +import torch + +kernel = torch.ops.sgl_kernel + +from utils import ( + BLOCK_K, + BLOCK_N, + factor_for_scale, + fp8_max, + fp8_min, + native_fp8_fused_moe, + precision, + scaled_weight, + torch_naive_fused_moe, + torch_w8a8_per_column_fused_moe, +) + +from sglang.test.test_utils import CustomTestCase + + +def fused_moe(a, w1, w2, score, topk, renormalize, prepack): + + G = 1 + topk_group = 1 + + B, D = a.shape + topk_weights = torch.empty(B, topk, dtype=torch.float32) + topk_ids = torch.empty(B, topk, dtype=torch.int32) + topk_weights, topk_ids = kernel.grouped_topk_cpu( + a, score, topk, renormalize, G, topk_group + ) + + packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1 + packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2 + + inplace = True + return 
kernel.fused_experts_cpu( + a, + packed_w1, + packed_w2, + topk_weights, + topk_ids, + inplace, + False, + False, + None, + None, + None, + None, + None, + prepack, + ) + + +class TestFusedExperts(CustomTestCase): + M = [2, 114] + N = [32] + K = [32] + E = [4] + topk = [2] + renormalize = [False, True] + + M_int8 = [1, 39] + N_int8 = [128] + K_int8 = [256] + E_int8 = [8] + topk_int8 = [3] + + M_fp8 = [2, 121] + N_fp8 = [512] + K_fp8 = [256] + E_fp8 = [8] + topk_fp8 = [4] + + def _bf16_moe(self, m, n, k, e, topk, renormalize): + dtype = torch.bfloat16 + prepack = True + + a = torch.randn((m, k), device="cpu", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cpu", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cpu", dtype=dtype) / 10 + score = torch.randn((m, e), device="cpu", dtype=dtype) + + torch_output = torch_naive_fused_moe(a, w1, w2, score, topk, renormalize) + fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack) + + atol = rtol = precision[torch_output.dtype] + self.assertTrue( + torch.allclose(torch_output, fused_output, atol=atol, rtol=rtol) + ) + + def test_bf16_moe(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.E, + self.topk, + self.renormalize, + ): + with self.subTest( + m=params[0], + n=params[1], + k=params[2], + e=params[3], + topk=params[4], + renormalize=params[5], + ): + self._bf16_moe(*params) + + def _int8_moe(self, M, N, K, E, topk): + dtype = torch.bfloat16 + prepack = True + + # Initialize int8 quantization parameters + int8_factor_for_scale = 1e-2 + int8_max = 127 + int8_min = -128 + + # Input tensor + # M * K + a = torch.randn((M, K), dtype=dtype) / math.sqrt(K) + + # Generate int8 weights + w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 + w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8) + + w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 + w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8) + + # Generate scale for each column (per-column quantization) + w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * int8_factor_for_scale + w2_s = torch.rand(E, K, device=w2_fp32.device) * int8_factor_for_scale + + # Calculate routing + score = torch.randn((M, E), dtype=dtype) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + + ref_out = torch_w8a8_per_column_fused_moe( + a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk + ) + + inplace = True + packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1 + packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2 + out = kernel.fused_experts_cpu( + a, + packed_w1, + packed_w2, + topk_weight, + topk_ids.to(torch.int32), + inplace, + True, + False, + w1_s, + w2_s, + None, + None, + None, + prepack, + ) + + atol = rtol = precision[ref_out.dtype] + # Increase the tolerance for large input shapes + if M > 35: + atol = rtol = 0.02 + self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol)) + + def test_int8_moe(self): + for params in itertools.product( + self.M_int8, + self.N_int8, + self.K_int8, + self.E_int8, + self.topk_int8, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + E=params[3], + topk=params[4], + ): + self._int8_moe(*params) + + def _fp8_moe(self, M, N, K, E, topk): + dtype = torch.bfloat16 + + a = torch.randn(M, K, dtype=dtype) / math.sqrt(K) + + w1_fp32 = torch.randn(E, 2 * N, K) + w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, 
max=fp8_max).to(torch.float8_e4m3fn) + + w2_fp32 = torch.randn(E, K, N) + w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + w1s = torch.randn(E, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale + w2s = torch.randn(E, K // BLOCK_N, N // BLOCK_K) * factor_for_scale + + w1_scaled = scaled_weight(w1, w1s) + w2_scaled = scaled_weight(w2, w2s) + + score = torch.randn((M, E), dtype=dtype) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + + w1 = kernel.convert_weight_packed(w1) + w2 = kernel.convert_weight_packed(w2) + + ref_out = native_fp8_fused_moe( + a, w1_scaled, w2_scaled, topk_weight, topk_ids, topk + ) + out = kernel.fused_experts_cpu( + a, + w1, + w2, + topk_weight, + topk_ids.to(torch.int32), + False, + False, + True, + w1s, + w2s, + [BLOCK_N, BLOCK_K], + None, + None, + True, + ) + + atol = rtol = precision[dtype] + self.assertTrue(torch.allclose(ref_out.bfloat16(), out, atol=atol, rtol=rtol)) + + def test_fp8_moe(self): + for params in itertools.product( + self.M_fp8, + self.N_fp8, + self.K_fp8, + self.E_fp8, + self.topk_fp8, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + E=params[3], + topk=params[4], + ): + self._fp8_moe(*params) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/cpu/utils.py b/test/srt/cpu/utils.py index 4665eb2cf..1716782fe 100644 --- a/test/srt/cpu/utils.py +++ b/test/srt/cpu/utils.py @@ -148,3 +148,99 @@ def scaled_weight(weight, scales): .contiguous() .view(E, N, K) ) + + +def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + + if renormalize: + topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True) + + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose( + 0, 1 + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +def torch_w8a8_per_column_fused_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk): + """This function performs fused moe with per-column int8 quantization using native torch.""" + + B, D = a.shape + # Perform per-token quantization + a_q, a_s = per_token_quant_int8(a) + # Repeat tokens to match topk + a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + # Also repeat the scale + a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1] + + out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device) + + # Calculate routing + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + # Process each expert + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + # First MLP layer: note that a_s is now per-token + inter_out = native_w8a8_per_token_matmul( + a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + bias=None, + output_dtype=torch.float32, + ) + # Activation function + act_out = SiluAndMul(inter_out) + # Quantize activation output with per-token + act_out_q, act_out_s = per_token_quant_int8(act_out) + # Second MLP layer + out[mask] = native_w8a8_per_token_matmul( + act_out_q, + w2[i], + act_out_s, + w2_s[i], + bias=None, + 
+                output_dtype=torch.float32,
+            )
+    # Apply routing weights and sum
+    return (
+        (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
+        .sum(dim=1)
+        .to(a.dtype)
+    )
+
+
+def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk):
+    B, D = a.shape
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D).float()
+    out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device)
+
+    # Calculate routing
+    topk_weight = topk_weight.view(-1)
+    topk_ids = topk_ids.view(-1)
+
+    for i in range(w1.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            ic0 = torch.matmul(a[mask], w1[i].transpose(0, 1))
+            ic1 = SiluAndMul(ic0)
+            out[mask] = torch.matmul(ic1, w2[i].transpose(0, 1))
+
+    return (
+        (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
+        .sum(dim=1)
+        .to(a.dtype)
+    )
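
Note: the fp8 w8a16 path dequantizes weights block-wise, which is what the new block_size argument and the shape checks in CHECK_MOE_SCALES_FP8 encode. Below is a minimal sketch of that relationship (illustrative only, not part of the patch): E, N, K follow the _fp8_moe test above, the 128x128 block size is an assumed value, and the scale expansion mirrors what scaled_weight in test/srt/cpu/utils.py does.

import torch

E, N, K = 8, 512, 256        # experts, intermediate size, hidden size (from the _fp8_moe test)
BLOCK_N, BLOCK_K = 128, 128  # assumed block_size = [BLOCK_N, BLOCK_K]

# fp8 weights with one float32 scale per (BLOCK_N x BLOCK_K) block,
# shaped the way the kernel-side checks expect
w1 = torch.randn(E, 2 * N, K).to(torch.float8_e4m3fn)  # [E, 2N, K]
w1s = torch.rand(E, 2 * N // BLOCK_N, K // BLOCK_K)     # size(1) == 2N / block_size_N, size(2) == K / block_size_K
w2 = torch.randn(E, K, N).to(torch.float8_e4m3fn)       # [E, K, N]
w2s = torch.rand(E, K // BLOCK_N, N // BLOCK_K)          # size(1) == K / block_size_N, size(2) == N / block_size_K

# The effective high-precision weight is the fp8 value times the scale of the
# block it falls in; expanding each scale over its block makes that explicit.
w1_ref = w1.float() * w1s.repeat_interleave(BLOCK_N, dim=1).repeat_interleave(BLOCK_K, dim=2)
w2_ref = w2.float() * w2s.repeat_interleave(BLOCK_N, dim=1).repeat_interleave(BLOCK_K, dim=2)

In fused_experts_cpu the scale tensors carry a leading expert dimension ([E, ...]), so the macro is invoked as CHECK_MOE_SCALES_FP8(1, 2); shared_expert_cpu has a single expert with 2-D scales, hence CHECK_MOE_SCALES_FP8(0, 1).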