Add fp8 shared_expert kernel for CPU in sgl-kernel and add UT (#6339)

Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
Co-authored-by: mingfeima <mingfei.ma@intel.com>
Author: Chunyuan WU
Date: 2025-05-19 03:42:15 +08:00 (committed by GitHub)
Commit: 5dd62c3a6f (parent: f11481b921)
8 changed files with 603 additions and 32 deletions


@@ -1137,8 +1137,10 @@ at::Tensor shared_expert_cpu(
     double routed_scaling_factor,
     bool inplace,
     bool use_int8_w8a8,
+    bool use_fp8_w8a16,
     std::optional<at::Tensor>& w1_scale,
     std::optional<at::Tensor>& w2_scale,
+    std::optional<std::vector<int64_t>> block_size,
     std::optional<at::Tensor>& a1_scale,
     std::optional<at::Tensor>& a2_scale,
     bool is_vnni) {
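
The two new parameters drive the fp8 path: use_fp8_w8a16 selects fp8 weights with 16-bit activations, and block_size carries the {block_size_N, block_size_K} tile over which each float weight scale applies. A minimal caller-side sketch of building matching scale tensors (shapes follow the checks added further down; the sizes and libtorch setup here are illustrative, not part of this commit):

#include <torch/torch.h>
#include <optional>
#include <vector>

int main() {
  // Illustrative sizes: hidden size K, intermediate size N.
  const int64_t K = 256, N = 512;
  const std::vector<int64_t> block_size = {128, 128};  // {block_size_N, block_size_K}

  // One float scale per (block_size_N x block_size_K) tile of each fp8 weight:
  // w1 is the fused gate/up projection [2N, K], w2 is the down projection [K, N].
  std::optional<at::Tensor> w1_scale =
      torch::ones({2 * N / block_size[0], K / block_size[1]}, torch::kFloat);
  std::optional<at::Tensor> w2_scale =
      torch::ones({K / block_size[0], N / block_size[1]}, torch::kFloat);
  // These would be passed alongside use_fp8_w8a16=true and block_size
  // to shared_expert_cpu.
  return 0;
}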
@@ -1180,6 +1182,11 @@ at::Tensor shared_expert_cpu(
     TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
     TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
   }
+  if (use_fp8_w8a16) {
+    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16.");
+    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16.");
+    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
+  }
   at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
@@ -1191,12 +1198,18 @@ at::Tensor shared_expert_cpu(
   // 3. Aq_tmp : [M, K] or [M, N]
   // 4. As_tmp : [M]
   //
+  // for fp8 w8a16:
+  // 5. intermediate_cache0 : [M, 2N]
+  //
   int num_threads = at::get_num_threads();
   int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
   if (use_int8_w8a8) {
     buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
   }
+  if (use_fp8_w8a16) {
+    buffer_size_nbytes += M * 2 * N * 2;
+  }
   auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
   AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "share_experts_kernel_impl", [&] {
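
For concreteness, the byte accounting works out as follows: intermediate_cache1 is [M, N] in a 2-byte reduced-float type, C_tmp holds two BLOCK_M x BLOCK_N float tiles per thread, and the fp8 path appends intermediate_cache0 as [M, 2N] 2-byte elements. A standalone sketch of the arithmetic (BLOCK_M/BLOCK_N are not shown in this hunk; 32 is a placeholder, as are the problem sizes):

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed tile sizes and problem shape, for illustration only.
  const int64_t BLOCK_M = 32, BLOCK_N = 32;
  const int64_t M = 4, N = 512;
  const int64_t num_threads = 8;

  // intermediate_cache1 [M, N] in 2-byte bf16/fp16, plus per-thread float
  // accumulation tiles C_tmp: num_threads * 2 * BLOCK_M * BLOCK_N floats.
  int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);

  // fp8 w8a16 adds intermediate_cache0 [M, 2N] of the same 2-byte type.
  buffer_size_nbytes += M * 2 * N * 2;

  std::printf("buffer_size_nbytes = %lld\n", (long long)buffer_size_nbytes);
  return 0;
}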
@@ -1228,6 +1241,36 @@ at::Tensor shared_expert_cpu(
           M,
           N,
           K);
+    } else if (use_fp8_w8a16) {
+      scalar_t* __restrict__ intermediate_cache0 =
+          (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
+      auto w1s = w1_scale.value();
+      auto w2s = w2_scale.value();
+      auto block_size_val = block_size.value();
+      TORCH_CHECK(block_size_val.size() == 2, "shared_expert: expect block_size.size() to be 2.");
+      int64_t block_size_N = block_size_val[0];
+      int64_t block_size_K = block_size_val[1];
+      TORCH_CHECK(w1s.size(0) == 2 * N / block_size_N);
+      TORCH_CHECK(w1s.size(1) == K / block_size_K);
+      TORCH_CHECK(w2s.size(0) == K / block_size_N);
+      TORCH_CHECK(w2s.size(1) == N / block_size_K);
+      shared_expert_fp8_kernel_impl<scalar_t>(
+          out_hidden_states.data_ptr<scalar_t>(),
+          intermediate_cache0,
+          intermediate_cache1,
+          hidden_states.data_ptr<scalar_t>(),
+          packed_w1.data_ptr<at::Float8_e4m3fn>(),
+          packed_w2.data_ptr<at::Float8_e4m3fn>(),
+          w1s.data_ptr<float>(),
+          w2s.data_ptr<float>(),
+          block_size_N,
+          block_size_K,
+          fused_experts_out.data_ptr<scalar_t>(),
+          routed_scaling_factor,
+          M,
+          N,
+          K);
     } else {
       shared_expert_kernel_impl<scalar_t>(
           out_hidden_states.data_ptr<scalar_t>(),
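
The scale-shape checks above pin down the block-wise scheme: w1 is [2N, K] and w2 is [K, N], each with one float scale per block_size_N x block_size_K tile, indexed as scale[n / block_size_N][k / block_size_K]. A standalone sketch of that w8a16 index math (plain floats stand in for the fp8 e4m3 payload; names and sizes are illustrative):

#include <cstdint>
#include <vector>

// Each (block_N x block_K) tile of the [rows, cols] weight shares one scale;
// w8a16 dequantizes the weight while the activation stays in 16-bit float.
float dequant_at(const std::vector<float>& w_fp8,   // fp8 values widened to float
                 const std::vector<float>& scales,  // [rows / block_N, cols / block_K]
                 int64_t cols, int64_t block_N, int64_t block_K,
                 int64_t n, int64_t k) {
  const int64_t scale_cols = cols / block_K;
  const float s = scales[(n / block_N) * scale_cols + k / block_K];
  return w_fp8[n * cols + k] * s;
}

int main() {
  const int64_t rows = 8, cols = 8, block_N = 4, block_K = 4;
  std::vector<float> w(rows * cols, 1.0f);
  std::vector<float> s((rows / block_N) * (cols / block_K), 0.5f);
  // Element (5, 6) falls in scale tile (1, 1), so it is scaled by 0.5f.
  return dequant_at(w, s, cols, block_N, block_K, 5, 6) == 0.5f ? 0 : 1;
}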