CPU: map changes from developing branch in sgl-kernel (#6833)

Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
YanbingJiang
2025-06-10 16:08:15 +08:00
committed by GitHub
parent 81372f3bef
commit fcde67b016
20 changed files with 1321 additions and 321 deletions

View File

@@ -142,6 +142,8 @@ void fused_experts_fp8_kernel_impl(
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
scalar_t* __restrict__ A_tmp,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
@@ -178,9 +180,6 @@ void fused_experts_fp8_kernel_impl(
int tid = at::get_thread_num();
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
bool is_brgemm_used = false;
for (int64_t i = begin; i < end; ++i) {
@@ -212,8 +211,8 @@ void fused_experts_fp8_kernel_impl(
/* A */ A,
/* B */ B,
/* C */ ic0 + offset * 2 * N + nb * BLOCK_N,
/* Btmp */ Btmp,
/* Ctmp */ Ctmp,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ Bs,
/* M */ m_size,
/* N */ n_size,
@@ -250,9 +249,8 @@ void fused_experts_fp8_kernel_impl(
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
alignas(64) scalar_t Btmp[BLOCK_K * BLOCK_N];
int tid = at::get_thread_num();
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
alignas(64) float Ctmp[BLOCK_M * BLOCK_K];
bool is_brgemm_used = false;
@@ -281,8 +279,8 @@ void fused_experts_fp8_kernel_impl(
/* A */ A,
/* B */ B,
/* C */ C,
/* Btmp */ Btmp,
/* Ctmp */ Ctmp,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ Bs,
/* M */ m_size,
/* N */ n_size,
@@ -323,6 +321,8 @@ void fused_experts_fp8_kernel_impl(
TYPE* __restrict__ ic1, \
TYPE* __restrict__ ic2, \
TYPE* __restrict__ A_tmp, \
TYPE* __restrict__ B_tmp, \
float* __restrict__ C_tmp, \
const TYPE* __restrict__ input, \
const at::Float8_e4m3fn* __restrict__ packed_w1, \
const at::Float8_e4m3fn* __restrict__ packed_w2, \
@@ -349,6 +349,8 @@ void shared_expert_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
@@ -373,8 +375,7 @@ void shared_expert_fp8_kernel_impl(
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
int tid = at::get_thread_num();
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
@@ -386,8 +387,8 @@ void shared_expert_fp8_kernel_impl(
/* A */ input + mb * BLOCK_M * K,
/* B */ packed_w1 + nb * BLOCK_N * K,
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
/* Btmp */ Btmp,
/* Ctmp */ Ctmp,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ m_size,
/* N */ n_size,
@@ -421,9 +422,8 @@ void shared_expert_fp8_kernel_impl(
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
alignas(64) scalar_t Btmp[BLOCK_K * BLOCK_N];
int tid = at::get_thread_num();
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
alignas(64) float Ctmp[BLOCK_M * BLOCK_K];
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
@@ -436,8 +436,8 @@ void shared_expert_fp8_kernel_impl(
/* A */ ic1 + mb * BLOCK_M * N,
/* B */ packed_w2 + nb * BLOCK_N * N,
/* C */ C,
/* Btmp */ Btmp,
/* Ctmp */ Ctmp,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ m_size,
/* N */ n_size,
@@ -467,6 +467,8 @@ void shared_expert_fp8_kernel_impl(
TYPE* __restrict__ output, \
TYPE* __restrict__ ic0, \
TYPE* __restrict__ ic1, \
TYPE* __restrict__ B_tmp, \
float* __restrict__ C_tmp, \
const TYPE* __restrict__ input, \
const at::Float8_e4m3fn* __restrict__ packed_w1, \
const at::Float8_e4m3fn* __restrict__ packed_w2, \