CPU: map changes from developing branch in sgl-kernel (#6833)
Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
@@ -1080,7 +1080,8 @@ at::Tensor fused_experts_cpu(
|
||||
// 6. As_tmp : [M * topk]
|
||||
//
|
||||
// for fp8 w8a16:
|
||||
// 7. intermediate_cache1 : [M * topk, 2N]
|
||||
// 7. intermediate_cache0 : [M * topk, 2N]
|
||||
// 8. B_tmp : [T, BLOCK_N, std::max(K, N)]
|
||||
//
|
||||
int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 +
|
||||
num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +
|
||||
@@ -1090,7 +1091,7 @@ at::Tensor fused_experts_cpu(
|
||||
buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
|
||||
}
|
||||
if (use_fp8_w8a16) {
|
||||
buffer_size_nbytes += M * topk * 2 * N * 2;
|
||||
buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2;
|
||||
}
|
||||
|
||||
auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
|
||||
@@ -1136,7 +1137,9 @@ at::Tensor fused_experts_cpu(
|
||||
} else if (use_fp8_w8a16) {
|
||||
// C_tmp is placed right after A_tmp in the buffer and is passed to the fp8 kernel below
|
||||
scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K));
|
||||
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(A_tmp + num_threads * BLOCK_M * K));
|
||||
float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K));
|
||||
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
|
||||
scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * topk * 2 * N));
|
||||
|
||||
CHECK_MOE_SCALES_FP8(1, 2);
|
||||
fused_experts_fp8_kernel_impl(
|
||||
@@ -1145,6 +1148,8 @@ at::Tensor fused_experts_cpu(
|
||||
intermediate_cache1,
|
||||
intermediate_cache2,
|
||||
A_tmp,
|
||||
B_tmp,
|
||||
C_tmp,
|
||||
hidden_states.data_ptr<scalar_t>(),
|
||||
packed_w1.data_ptr<at::Float8_e4m3fn>(),
|
||||
packed_w2.data_ptr<at::Float8_e4m3fn>(),
|
||||
@@ -1258,6 +1263,7 @@ at::Tensor shared_expert_cpu(
|
||||
//
|
||||
// for fp8 w8a16:
|
||||
// 5. intermediate_cache0 : [M, 2N]
|
||||
// 6. B_tmp: [T, BLOCK_M, max(K, N)]
|
||||
//
|
||||
int num_threads = at::get_num_threads();
|
||||
int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
|
||||
@@ -1266,7 +1272,7 @@ at::Tensor shared_expert_cpu(
|
||||
buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
|
||||
}
|
||||
if (use_fp8_w8a16) {
|
||||
buffer_size_nbytes += M * 2 * N * 2;
|
||||
buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2;
|
||||
}
|
||||
|
||||
auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
|
||||
@@ -1301,12 +1307,15 @@ at::Tensor shared_expert_cpu(
|
||||
K);
|
||||
} else if (use_fp8_w8a16) {
|
||||
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
|
||||
scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * 2 * N));
|
||||
|
||||
CHECK_MOE_SCALES_FP8(0, 1);
|
||||
shared_expert_fp8_kernel_impl<scalar_t>(
|
||||
out_hidden_states.data_ptr<scalar_t>(),
|
||||
intermediate_cache0,
|
||||
intermediate_cache1,
|
||||
B_tmp,
|
||||
C_tmp,
|
||||
hidden_states.data_ptr<scalar_t>(),
|
||||
packed_w1.data_ptr<at::Float8_e4m3fn>(),
|
||||
packed_w2.data_ptr<at::Float8_e4m3fn>(),
|
||||
|
||||
Reference in New Issue
Block a user