[CPU][sgl-kernel] biased_grouped_topk: fix correction_bias dtype to float32 (#8212)

Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
This commit is contained in:
Chunyuan WU
2025-08-05 09:28:31 +08:00
committed by GitHub
parent d4bf5a8524
commit 08f8f49016
4 changed files with 94 additions and 28 deletions

View File

@@ -16,6 +16,25 @@ inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, c
return at::vec::convert_from_float<scalar_t>(a, b);
}
// allow f16, bf16
// Load one reduced-precision vector (f16/bf16) and widen it into a pair of
// float vectors. Only participates in overload resolution for reduced
// floating-point element types.
template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 1>
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const scalar_t* __restrict__ data) {
  using bVec = at::vec::Vectorized<scalar_t>;
  // One bVec of scalar_t widens to two float vectors (lo/hi halves).
  bVec raw = bVec::loadu(data);
  auto [lo, hi] = at::vec::convert_to_float(raw);
  return std::make_tuple(lo, hi);
}
// float specialization: no widening conversion is needed, so simply load two
// consecutive float vectors from `data` (data must hold at least
// 2 * fVec::size() elements).
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const float* __restrict__ data) {
  using fVec = at::vec::Vectorized<float>;
  return std::make_tuple(fVec::loadu(data), fVec::loadu(data + fVec::size()));
}
#if defined(CPU_CAPABILITY_AVX512)
// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics