Add fp8 fused_experts kernel for CPU in sgl-kernel and add UT (#6404)

This commit is contained in:
Chunyuan WU
2025-05-23 17:01:55 +08:00
committed by GitHub
parent 4ba1eea83f
commit 3ded6235c9
7 changed files with 752 additions and 157 deletions

View File

@@ -932,6 +932,40 @@ void shared_expert_kernel_impl(
} // anonymous namespace
// common checks
// Validate the quantization-scale arguments shared by the MoE entry points
// (fused_experts_cpu / shared_expert_cpu).
//
// Args:
//   use_int8_w8a8 : int8 weight + int8 activation (dynamic) quantization path.
//   use_fp8_w8a16 : fp8 weight, 16-bit activation path (block-wise scales).
//   w1_scale/w2_scale : per-weight scale tensors; required by both quant paths.
//   block_size    : [block_N, block_K] for the fp8 block-wise scales; required
//                   (and must have exactly 2 entries) on the fp8 path.
//   a1_scale/a2_scale : static activation scales; must be absent on the int8
//                   path, which only supports dynamic activation quantization.
//
// NOTE: block_size is taken by const reference — the previous by-value
// signature copied the underlying std::vector on every call for no benefit.
static inline void check_moe_scales(
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    const std::optional<at::Tensor>& w1_scale,
    const std::optional<at::Tensor>& w2_scale,
    const std::optional<std::vector<int64_t>>& block_size,
    const std::optional<at::Tensor>& a1_scale,
    const std::optional<at::Tensor>& a2_scale) {
  if (use_int8_w8a8) {
    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
    // int8 path quantizes activations dynamically; static scales are rejected.
    TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
    TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
  }
  if (use_fp8_w8a16) {
    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16.");
    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16.");
    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
    TORCH_CHECK(block_size.value().size() == 2, "expect block_size.size() to be 2.");
  }
}
// Bind and shape-check the fp8 block-wise scales at the expansion site.
// Introduces into the enclosing scope: w1s, w2s (the scale tensors),
// block_size_val, and block_size_N / block_size_K. Reads w1_scale, w2_scale,
// block_size, N and K from the enclosing scope; DIM0/DIM1 select which tensor
// dimensions hold the block-N / block-K axes (e.g. (1, 2) for the per-expert
// 3-D scales in fused_experts_cpu, (0, 1) for the 2-D ones in shared_expert_cpu).
// NOTE(review): must be a macro (not a function) precisely because it declares
// these local variables for the caller; comments cannot go on continuation lines.
#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \
auto w1s = w1_scale.value(); \
auto w2s = w2_scale.value(); \
auto block_size_val = block_size.value(); \
int64_t block_size_N = block_size_val[0]; \
int64_t block_size_K = block_size_val[1]; \
TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N); \
TORCH_CHECK(w1s.size(DIM1) == K / block_size_K); \
TORCH_CHECK(w2s.size(DIM0) == K / block_size_N); \
TORCH_CHECK(w2s.size(DIM1) == N / block_size_K)
// hidden_states: [M, K]
// w1: [E, 2N, K]
// w2: [E, K, N]
@@ -946,8 +980,10 @@ at::Tensor fused_experts_cpu(
at::Tensor& topk_ids,
bool inplace,
bool use_int8_w8a8,
bool use_fp8_w8a16,
const std::optional<at::Tensor>& w1_scale,
const std::optional<at::Tensor>& w2_scale,
const std::optional<std::vector<int64_t>> block_size,
const std::optional<at::Tensor>& a1_scale,
const std::optional<at::Tensor>& a2_scale,
bool is_vnni) {
@@ -990,12 +1026,8 @@ at::Tensor fused_experts_cpu(
CHECK_EQ(packed_w1.size(2), packed_K);
CHECK_EQ(packed_w2.size(2), packed_N);
if (use_int8_w8a8) {
TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
}
// check scales
check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale);
at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
@@ -1047,6 +1079,9 @@ at::Tensor fused_experts_cpu(
// 5. Aq_tmp : [M, K] or [M * topk, N]
// 6. As_tmp : [M * topk]
//
// for fp8 w8a16:
// 7. intermediate_cache1 : [M * topk, 2N]
//
int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 +
num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +
num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
@@ -1054,6 +1089,9 @@ at::Tensor fused_experts_cpu(
if (use_int8_w8a8) {
buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
}
if (use_fp8_w8a16) {
buffer_size_nbytes += M * topk * 2 * N * 2;
}
auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
@@ -1095,6 +1133,35 @@ at::Tensor fused_experts_cpu(
E,
topk,
num_tokens_post_pad);
} else if (use_fp8_w8a16) {
// here we just ignore C_tmp as it is not used
scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K));
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(A_tmp + num_threads * BLOCK_M * K));
CHECK_MOE_SCALES_FP8(1, 2);
fused_experts_fp8_kernel_impl(
out_hidden_states.data_ptr<scalar_t>(),
intermediate_cache0,
intermediate_cache1,
intermediate_cache2,
A_tmp,
hidden_states.data_ptr<scalar_t>(),
packed_w1.data_ptr<at::Float8_e4m3fn>(),
packed_w2.data_ptr<at::Float8_e4m3fn>(),
w1s.data_ptr<float>(),
w2s.data_ptr<float>(),
block_size_N,
block_size_K,
topk_weights.data_ptr<float>(),
sorted_ids,
expert_ids,
offsets,
M,
N,
K,
E,
topk,
num_tokens_post_pad);
} else {
scalar_t* __restrict__ A_tmp = intermediate_cache2 + M * topk * K;
float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K));
@@ -1176,17 +1243,8 @@ at::Tensor shared_expert_cpu(
CHECK_EQ(packed_w1.size(1), packed_K);
CHECK_EQ(packed_w2.size(1), packed_N);
if (use_int8_w8a8) {
TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
}
if (use_fp8_w8a16) {
TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16.");
TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16.");
TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
}
// check scales
check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale);
at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
@@ -1244,17 +1302,7 @@ at::Tensor shared_expert_cpu(
} else if (use_fp8_w8a16) {
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
auto w1s = w1_scale.value();
auto w2s = w2_scale.value();
auto block_size_val = block_size.value();
TORCH_CHECK(block_size_val.size() == 2, "shared_expert: expect block_size.size() to be 2.");
int64_t block_size_N = block_size_val[0];
int64_t block_size_K = block_size_val[1];
TORCH_CHECK(w1s.size(0) == 2 * N / block_size_N);
TORCH_CHECK(w1s.size(1) == K / block_size_K);
TORCH_CHECK(w2s.size(0) == K / block_size_N);
TORCH_CHECK(w2s.size(1) == N / block_size_K);
CHECK_MOE_SCALES_FP8(0, 1);
shared_expert_fp8_kernel_impl<scalar_t>(
out_hidden_states.data_ptr<scalar_t>(),
intermediate_cache0,