Add CPU optimized kernels for topk and rope fusions (#6456)

Author: jianan-gu
Date: 2025-06-03 08:37:34 +08:00
Committed by: GitHub
Parent: ff91474825
Commit: ff00895c46
7 changed files with 829 additions and 98 deletions


@@ -4,6 +4,67 @@
namespace {

// NB: avoid using `at::vec::map<>` on bfloat16 or half

// Llama4TextL2Norm: per-row RMS normalization without a learned weight.
template <typename scalar_t>
void l2norm_kernel_impl(
    scalar_t* __restrict__ output,
    const scalar_t* __restrict__ input,
    int64_t batch_size,
    int64_t hidden_size,
    float eps = 1e-5) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();

  at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // per-row pointers
      scalar_t* __restrict__ out_ptr = output + i * hidden_size;
      const scalar_t* __restrict__ input_ptr = input + i * hidden_size;

      // Pass 1: accumulate the sum of squares in float precision,
      // vectorized over full lanes with a scalar tail loop.
      fVec sum_fvec = fVec(float(0));
      float sum_val = float(0);

      int64_t d;
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
        sum_fvec += x_fvec0 * x_fvec0;
        sum_fvec += x_fvec1 * x_fvec1;
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        sum_val += x_val * x_val;
      }
      sum_val += vec_reduce_sum(sum_fvec);

      // Scale factor: 1 / sqrt(mean(x^2) + eps).
      float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
      const fVec scale_fvec = fVec(rsqrt_var);

      // Pass 2: write the scaled row, again vectorized with a scalar tail.
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
        x_fvec0 = x_fvec0 * scale_fvec;
        x_fvec1 = x_fvec1 * scale_fvec;
        bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        out_bvec.store(out_ptr + d);
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var);
      }
    }
  });
}
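
// A minimal scalar reference for the kernel above (illustrative only, not
// part of this diff; `l2norm_ref` is a hypothetical name). It makes the
// semantics explicit: each row is scaled by 1 / sqrt(mean(x^2) + eps).
template <typename scalar_t>
void l2norm_ref(scalar_t* out, const scalar_t* in, int64_t hidden_size, float eps) {
  float sum = 0.f;
  for (int64_t d = 0; d < hidden_size; ++d) {
    float x = static_cast<float>(in[d]);
    sum += x * x;
  }
  const float scale = 1.f / std::sqrt(sum / hidden_size + eps);
  for (int64_t d = 0; d < hidden_size; ++d) {
    out[d] = static_cast<scalar_t>(static_cast<float>(in[d]) * scale);
  }
}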
template <typename scalar_t>
void rmsnorm_kernel_impl(
    scalar_t* __restrict__ output,
@@ -160,6 +221,22 @@ void fused_add_rmsnorm_kernel_impl(
} // anonymous namespace
// input : {batch_size, hidden_size}
at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
  RECORD_FUNCTION("sgl-kernel::l2norm_cpu", std::vector<c10::IValue>({input}));

  CHECK_INPUT(input);
  CHECK_DIM(2, input);

  int64_t batch_size = input.size(0);
  int64_t hidden_size = input.size(1);
  at::Tensor output = at::empty_like(input);

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "l2norm_kernel", [&] {
    l2norm_kernel_impl<scalar_t>(
        output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), batch_size, hidden_size, eps);
  });
  return output;
}
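
// Usage sketch (illustrative; not part of this commit). Assumes a contiguous
// 2-D bfloat16/half CPU tensor, matching the CHECK_INPUT/CHECK_DIM and the
// reduced-floating-point dispatch above:
//
//   at::Tensor x = at::randn({8, 4096}, at::dtype(at::kBFloat16));
//   at::Tensor y = l2norm_cpu(x, /*eps=*/1e-6);
//   // each row of y equals x_row / sqrt(mean(x_row^2) + eps)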
// input : {batch_size, hidden_size}
// weight: {hidden_size}
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {