CPU: map changes from developing branch in sgl-kernel (#6833)

Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
YanbingJiang
2025-06-10 16:08:15 +08:00
committed by GitHub
parent 81372f3bef
commit fcde67b016
20 changed files with 1321 additions and 321 deletions

View File

@@ -1080,7 +1080,8 @@ at::Tensor fused_experts_cpu(
// 6. As_tmp : [M * topk]
//
// for fp8 w8a16:
// 7. intermediate_cache1 : [M * topk, 2N]
// 7. intermediate_cache0 : [M * topk, 2N]
// 8. B_tmp : [T, BLOCK_N, std::max(K, N)]
//
int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 +
num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +
@@ -1090,7 +1091,7 @@ at::Tensor fused_experts_cpu(
buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
}
if (use_fp8_w8a16) {
buffer_size_nbytes += M * topk * 2 * N * 2;
buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2;
}
auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
@@ -1136,7 +1137,9 @@ at::Tensor fused_experts_cpu(
} else if (use_fp8_w8a16) {
// here we just ignore C_tmp as it is not used
scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K));
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(A_tmp + num_threads * BLOCK_M * K));
float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K));
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * topk * 2 * N));
CHECK_MOE_SCALES_FP8(1, 2);
fused_experts_fp8_kernel_impl(
@@ -1145,6 +1148,8 @@ at::Tensor fused_experts_cpu(
intermediate_cache1,
intermediate_cache2,
A_tmp,
B_tmp,
C_tmp,
hidden_states.data_ptr<scalar_t>(),
packed_w1.data_ptr<at::Float8_e4m3fn>(),
packed_w2.data_ptr<at::Float8_e4m3fn>(),
@@ -1258,6 +1263,7 @@ at::Tensor shared_expert_cpu(
//
// for fp8 w8a16:
// 5. intermediate_cache0 : [M, 2N]
// 6. B_tmp: [T, BLOCK_M, max(K, N)]
//
int num_threads = at::get_num_threads();
int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
@@ -1266,7 +1272,7 @@ at::Tensor shared_expert_cpu(
buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
}
if (use_fp8_w8a16) {
buffer_size_nbytes += M * 2 * N * 2;
buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2;
}
auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
@@ -1301,12 +1307,15 @@ at::Tensor shared_expert_cpu(
K);
} else if (use_fp8_w8a16) {
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * 2 * N));
CHECK_MOE_SCALES_FP8(0, 1);
shared_expert_fp8_kernel_impl<scalar_t>(
out_hidden_states.data_ptr<scalar_t>(),
intermediate_cache0,
intermediate_cache1,
B_tmp,
C_tmp,
hidden_states.data_ptr<scalar_t>(),
packed_w1.data_ptr<at::Float8_e4m3fn>(),
packed_w2.data_ptr<at::Float8_e4m3fn>(),