Add fp8 fused_experts kernel for CPU in sgl-kernel and add UT (#6404)
@@ -932,6 +932,40 @@ void shared_expert_kernel_impl(
 
 } // anonymous namespace
 
+// common checks
+static inline void check_moe_scales(
+    bool use_int8_w8a8,
+    bool use_fp8_w8a16,
+    const std::optional<at::Tensor>& w1_scale,
+    const std::optional<at::Tensor>& w2_scale,
+    const std::optional<std::vector<int64_t>> block_size,
+    const std::optional<at::Tensor>& a1_scale,
+    const std::optional<at::Tensor>& a2_scale) {
+  if (use_int8_w8a8) {
+    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
+    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
+    TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
+    TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
+  }
+  if (use_fp8_w8a16) {
+    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16.");
+    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16.");
+    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
+    TORCH_CHECK(block_size.value().size() == 2, "expect block_size.size() to be 2.");
+  }
+}
+
+#define CHECK_MOE_SCALES_FP8(DIM0, DIM1)                \
+  auto w1s = w1_scale.value();                          \
+  auto w2s = w2_scale.value();                          \
+  auto block_size_val = block_size.value();             \
+  int64_t block_size_N = block_size_val[0];             \
+  int64_t block_size_K = block_size_val[1];             \
+  TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N);  \
+  TORCH_CHECK(w1s.size(DIM1) == K / block_size_K);      \
+  TORCH_CHECK(w2s.size(DIM0) == K / block_size_N);      \
+  TORCH_CHECK(w2s.size(DIM1) == N / block_size_K)
+
 // hidden_states: [M, K]
 // w1: [E, 2N, K]
 // w2: [E, K, N]
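The macro above encodes the expected block-wise fp8 scale shapes: with w1 of shape [E, 2N, K] and w2 of shape [E, K, N], a [block_size_N, block_size_K] blocking yields per-expert scale grids of [2N / block_size_N, K / block_size_K] and [K / block_size_N, N / block_size_K]. A minimal standalone sketch of that arithmetic follows; the dimension values are illustrative, not taken from this commit.

#include <cstdint>
#include <iostream>

int main() {
  // Illustrative MoE dimensions; real values come from the packed weights.
  int64_t N = 2048, K = 7168;
  // Illustrative block size, i.e. block_size = {block_size_N, block_size_K}.
  int64_t block_size_N = 128, block_size_K = 128;

  // w1: [E, 2N, K] -> per-expert w1_scale grid is [2N / bN, K / bK]
  int64_t w1s_dim0 = 2 * N / block_size_N;
  int64_t w1s_dim1 = K / block_size_K;
  // w2: [E, K, N]  -> per-expert w2_scale grid is [K / bN, N / bK]
  int64_t w2s_dim0 = K / block_size_N;
  int64_t w2s_dim1 = N / block_size_K;

  std::cout << "w1_scale (per expert): [" << w1s_dim0 << ", " << w1s_dim1 << "]\n"
            << "w2_scale (per expert): [" << w2s_dim0 << ", " << w2s_dim1 << "]\n";
  return 0;
}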
@@ -946,8 +980,10 @@ at::Tensor fused_experts_cpu(
     at::Tensor& topk_ids,
     bool inplace,
     bool use_int8_w8a8,
+    bool use_fp8_w8a16,
     const std::optional<at::Tensor>& w1_scale,
     const std::optional<at::Tensor>& w2_scale,
+    const std::optional<std::vector<int64_t>> block_size,
     const std::optional<at::Tensor>& a1_scale,
     const std::optional<at::Tensor>& a2_scale,
     bool is_vnni) {
@@ -990,12 +1026,8 @@ at::Tensor fused_experts_cpu(
   CHECK_EQ(packed_w1.size(2), packed_K);
   CHECK_EQ(packed_w2.size(2), packed_N);
 
-  if (use_int8_w8a8) {
-    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
-    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
-    TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
-    TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
-  }
+  // check scales
+  check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale);
 
   at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
 
@@ -1047,6 +1079,9 @@ at::Tensor fused_experts_cpu(
   // 5. Aq_tmp : [M, K] or [M * topk, N]
   // 6. As_tmp : [M * topk]
   //
+  // for fp8 w8a16:
+  // 7. intermediate_cache1 : [M * topk, 2N]
+  //
   int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 +
       num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +
       num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
@@ -1054,6 +1089,9 @@ at::Tensor fused_experts_cpu(
   if (use_int8_w8a8) {
     buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
   }
+  if (use_fp8_w8a16) {
+    buffer_size_nbytes += M * topk * 2 * N * 2;
+  }
 
   auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
 
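The extra term reserved for the fp8 w8a16 path matches buffer item 7 in the plan above: intermediate_cache1 of shape [M * topk, 2N], stored in a 2-byte activation type (bf16/fp16), hence M * topk * 2 * N * 2 bytes. A small worked example of that sizing; the dimensions are made up for illustration.

#include <cstdint>
#include <iostream>

int main() {
  // Illustrative sizes, not from this commit: 4 tokens, top-2 routing, N = 2048.
  int64_t M = 4, topk = 2, N = 2048;
  const int64_t elem_bytes = 2;  // bf16 / fp16 activations

  // Extra scratch for the fp8 w8a16 path: [M * topk, 2N] elements.
  int64_t fp8_extra_nbytes = M * topk * 2 * N * elem_bytes;
  std::cout << "extra fp8 scratch: " << fp8_extra_nbytes << " bytes\n";  // 65536
  return 0;
}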
@@ -1095,6 +1133,35 @@ at::Tensor fused_experts_cpu(
         E,
         topk,
         num_tokens_post_pad);
+  } else if (use_fp8_w8a16) {
+    // here we just ignore C_tmp as it is not used
+    scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K));
+    scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(A_tmp + num_threads * BLOCK_M * K));
+
+    CHECK_MOE_SCALES_FP8(1, 2);
+    fused_experts_fp8_kernel_impl(
+        out_hidden_states.data_ptr<scalar_t>(),
+        intermediate_cache0,
+        intermediate_cache1,
+        intermediate_cache2,
+        A_tmp,
+        hidden_states.data_ptr<scalar_t>(),
+        packed_w1.data_ptr<at::Float8_e4m3fn>(),
+        packed_w2.data_ptr<at::Float8_e4m3fn>(),
+        w1s.data_ptr<float>(),
+        w2s.data_ptr<float>(),
+        block_size_N,
+        block_size_K,
+        topk_weights.data_ptr<float>(),
+        sorted_ids,
+        expert_ids,
+        offsets,
+        M,
+        N,
+        K,
+        E,
+        topk,
+        num_tokens_post_pad);
   } else {
     scalar_t* __restrict__ A_tmp = intermediate_cache2 + M * topk * K;
     float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K));
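In the new fp8 branch the per-thread staging buffers are carved out of one flat scratch allocation by pointer arithmetic: A_tmp sits right after intermediate_cache2, and intermediate_cache0 starts where the bf16 path would place C_tmp (which the fp8 path does not use). A minimal sketch of that carving pattern; the sizes are invented, and the real kernel derives them from M, topk, K, num_threads and BLOCK_M and carves them out of buffer2.

#include <cstdint>
#include <vector>

int main() {
  using scalar_t = uint16_t;  // stand-in for the 2-byte activation type (bf16/fp16)

  // Invented sizes for illustration only.
  int64_t M = 4, topk = 2, K = 16, num_threads = 2, BLOCK_M = 4;
  int64_t staging_elems = 64;  // placeholder size for the fp8 staging region

  // One flat scratch region, carved into typed sub-buffers back to back.
  std::vector<scalar_t> scratch(M * topk * K + num_threads * BLOCK_M * K + staging_elems);

  scalar_t* intermediate_cache2 = scratch.data();                     // [M * topk, K]
  scalar_t* A_tmp = intermediate_cache2 + M * topk * K;               // per-thread A panels
  scalar_t* intermediate_cache0 = A_tmp + num_threads * BLOCK_M * K;  // fp8 staging (starts where C_tmp would)
  (void)intermediate_cache0;
  return 0;
}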
@@ -1176,17 +1243,8 @@ at::Tensor shared_expert_cpu(
   CHECK_EQ(packed_w1.size(1), packed_K);
   CHECK_EQ(packed_w2.size(1), packed_N);
 
-  if (use_int8_w8a8) {
-    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
-    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
-    TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
-    TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
-  }
-  if (use_fp8_w8a16) {
-    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16.");
-    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16.");
-    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
-  }
+  // check scales
+  check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale);
 
   at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
 
@@ -1244,17 +1302,7 @@ at::Tensor shared_expert_cpu(
   } else if (use_fp8_w8a16) {
     scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
 
-    auto w1s = w1_scale.value();
-    auto w2s = w2_scale.value();
-    auto block_size_val = block_size.value();
-    TORCH_CHECK(block_size_val.size() == 2, "shared_expert: expect block_size.size() to be 2.");
-    int64_t block_size_N = block_size_val[0];
-    int64_t block_size_K = block_size_val[1];
-    TORCH_CHECK(w1s.size(0) == 2 * N / block_size_N);
-    TORCH_CHECK(w1s.size(1) == K / block_size_K);
-    TORCH_CHECK(w2s.size(0) == K / block_size_N);
-    TORCH_CHECK(w2s.size(1) == N / block_size_K);
-
+    CHECK_MOE_SCALES_FP8(0, 1);
     shared_expert_fp8_kernel_impl<scalar_t>(
         out_hidden_states.data_ptr<scalar_t>(),
         intermediate_cache0,
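The two instantiations of CHECK_MOE_SCALES_FP8 differ only in which tensor dimensions hold the block grid: in fused_experts_cpu the scales keep a leading dimension (presumably the expert count E), so the grid sits at dims 1 and 2, while the shared-expert scales are 2-D and use dims 0 and 1. A small illustration with made-up dimensions; it requires libtorch, and E, N, K and the block size below are not from this commit.

#include <torch/torch.h>

int main() {
  // Made-up dimensions for illustration.
  int64_t E = 8, N = 256, K = 512, bN = 128, bK = 128;

  // fused_experts_cpu: per-expert scale grids with a leading expert dim
  // -> CHECK_MOE_SCALES_FP8(1, 2) inspects sizes 1 and 2.
  auto w1s_experts = torch::zeros({E, 2 * N / bN, K / bK});
  TORCH_CHECK(w1s_experts.size(1) == 2 * N / bN);
  TORCH_CHECK(w1s_experts.size(2) == K / bK);

  // shared_expert_cpu: a single expert, 2-D scales
  // -> CHECK_MOE_SCALES_FP8(0, 1) inspects sizes 0 and 1.
  auto w1s_shared = torch::zeros({2 * N / bN, K / bK});
  TORCH_CHECK(w1s_shared.size(0) == 2 * N / bN);
  TORCH_CHECK(w1s_shared.size(1) == K / bK);
  return 0;
}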