Add fp8 shared_expert kernel for CPU in sgl-kernel and add UT (#6339)
Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
Co-authored-by: mingfeima <mingfei.ma@intel.com>
@@ -1137,8 +1137,10 @@ at::Tensor shared_expert_cpu(
     double routed_scaling_factor,
     bool inplace,
     bool use_int8_w8a8,
+    bool use_fp8_w8a16,
     std::optional<at::Tensor>& w1_scale,
     std::optional<at::Tensor>& w2_scale,
+    std::optional<std::vector<int64_t>> block_size,
     std::optional<at::Tensor>& a1_scale,
     std::optional<at::Tensor>& a2_scale,
     bool is_vnni) {
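The two added parameters, use_fp8_w8a16 and block_size, are all a caller needs to switch on the new path. Below is a minimal sketch of argument construction that satisfies the checks added later in this diff; block_size = {128, 128} and the concrete M/N/K values are illustrative assumptions, and the leading parameters of shared_expert_cpu are outside this hunk, so the call itself is left commented.

#include <ATen/ATen.h>
#include <optional>
#include <vector>

// Sketch only: shapes follow the TORCH_CHECKs added in this diff for the
// fp8 w8a16 path; block_size = {128, 128} is an assumption.
int main() {
  const int64_t M = 4, N = 256, K = 512;  // tokens / intermediate / hidden
  const int64_t bN = 128, bK = 128;

  // Activations stay 16-bit in the w8a16 path; only the weights are fp8.
  at::Tensor hidden_states = at::randn({M, K}, at::kBFloat16);

  // Per-block fp32 weight scales:
  //   w1_scale : [2N / bN, K / bK]   (gate + up projection, fused)
  //   w2_scale : [K / bN,  N / bK]   (down projection)
  std::optional<at::Tensor> w1_scale = at::ones({2 * N / bN, K / bK}, at::kFloat);
  std::optional<at::Tensor> w2_scale = at::ones({K / bN, N / bK}, at::kFloat);
  std::optional<std::vector<int64_t>> block_size = std::vector<int64_t>{bN, bK};

  // Activation scales stay empty: w8a16 keeps activations unquantized.
  std::optional<at::Tensor> a1_scale, a2_scale;

  // shared_expert_cpu(/*leading args elided by this hunk*/, 1.0,
  //     /*inplace=*/false, /*use_int8_w8a8=*/false, /*use_fp8_w8a16=*/true,
  //     w1_scale, w2_scale, block_size, a1_scale, a2_scale, /*is_vnni=*/true);
  return 0;
}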
@@ -1180,6 +1182,11 @@ at::Tensor shared_expert_cpu(
     TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
     TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
   }
+  if (use_fp8_w8a16) {
+    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16.");
+    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16.");
+    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
+  }
 
   at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
 
@@ -1191,12 +1198,18 @@ at::Tensor shared_expert_cpu(
   // 3. Aq_tmp : [M, K] or [M, N]
   // 4. As_tmp : [M]
   //
+  // for fp8 w8a16:
+  // 5. intermediate_cache0 : [M, 2N]
+  //
   int num_threads = at::get_num_threads();
   int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
 
   if (use_int8_w8a8) {
     buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
   }
+  if (use_fp8_w8a16) {
+    buffer_size_nbytes += M * 2 * N * 2;
+  }
 
   auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
   AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "share_experts_kernel_impl", [&] {
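The sizing arithmetic allots 2 bytes per element for the reduced-float caches (bf16/fp16), so the fp8 path adds M * 2N elements, i.e. M * 2 * N * 2 bytes, for intermediate_cache0. A worked example of the computation follows; the BLOCK_M/BLOCK_N tile sizes and thread count are illustrative assumptions, not the constants used by sgl-kernel.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t M = 4, N = 256;
  const int64_t BLOCK_M = 32, BLOCK_N = 32;  // assumed tile sizes
  const int64_t num_threads = 8;             // assumed thread count

  // intermediate_cache1 [M, N] at 2 bytes/elem + per-thread fp32 C_tmp tiles.
  int64_t nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);

  // fp8 w8a16 adds intermediate_cache0 [M, 2N] at 2 bytes/elem.
  nbytes += M * 2 * N * 2;

  // 2048 + 65536 + 4096 = 71680 bytes
  std::printf("%lld\n", (long long)nbytes);
  return 0;
}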
@@ -1228,6 +1241,36 @@ at::Tensor shared_expert_cpu(
           M,
           N,
           K);
+    } else if (use_fp8_w8a16) {
+      scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
+
+      auto w1s = w1_scale.value();
+      auto w2s = w2_scale.value();
+      auto block_size_val = block_size.value();
+      TORCH_CHECK(block_size_val.size() == 2, "shared_expert: expect block_size.size() to be 2.");
+      int64_t block_size_N = block_size_val[0];
+      int64_t block_size_K = block_size_val[1];
+      TORCH_CHECK(w1s.size(0) == 2 * N / block_size_N);
+      TORCH_CHECK(w1s.size(1) == K / block_size_K);
+      TORCH_CHECK(w2s.size(0) == K / block_size_N);
+      TORCH_CHECK(w2s.size(1) == N / block_size_K);
+
+      shared_expert_fp8_kernel_impl<scalar_t>(
+          out_hidden_states.data_ptr<scalar_t>(),
+          intermediate_cache0,
+          intermediate_cache1,
+          hidden_states.data_ptr<scalar_t>(),
+          packed_w1.data_ptr<at::Float8_e4m3fn>(),
+          packed_w2.data_ptr<at::Float8_e4m3fn>(),
+          w1s.data_ptr<float>(),
+          w2s.data_ptr<float>(),
+          block_size_N,
+          block_size_K,
+          fused_experts_out.data_ptr<scalar_t>(),
+          routed_scaling_factor,
+          M,
+          N,
+          K);
     } else {
       shared_expert_kernel_impl<scalar_t>(
           out_hidden_states.data_ptr<scalar_t>(),
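The shape checks in this hunk pin down the blockwise quantization layout: the first GEMM fuses the gate and up projections, so its weight is [2N, K] with scales [2N/block_size_N, K/block_size_K] (hence intermediate_cache0 being [M, 2N]), and the second GEMM is [K, N] with scales [K/block_size_N, N/block_size_K]. Below is a naive ATen reference for what one such GEMM computes, assuming the usual "dequant = fp8_value * block_scale" convention; the actual shared_expert_fp8_kernel_impl is fused and vectorized, so this sketch is only for semantics.

#include <ATen/ATen.h>

// Naive reference for one blockwise fp8 w8a16 GEMM, matching the scale
// shapes checked above: weight [OC, IC] fp8, scale [OC / bN, IC / bK] fp32.
at::Tensor fp8_w8a16_gemm_ref(
    const at::Tensor& act,     // [M, IC], bf16
    const at::Tensor& weight,  // [OC, IC], fp8 e4m3
    const at::Tensor& scale,   // [OC / bN, IC / bK], fp32
    int64_t bN,
    int64_t bK) {
  // Broadcast each block's scale over its [bN, bK] tile, dequantize the
  // weight, matmul in fp32, then cast back to the 16-bit activation dtype.
  at::Tensor s = scale.repeat_interleave(bN, 0).repeat_interleave(bK, 1);
  at::Tensor w = weight.to(at::kFloat) * s;
  return at::matmul(act.to(at::kFloat), w.t()).to(act.scalar_type());
}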