[CPU][sgl-kernel] biased_grouped_topk: fix correction_bias dtype to float32 (#8212)
Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
This commit is contained in:
@@ -16,6 +16,25 @@ inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, c
|
||||
return at::vec::convert_from_float<scalar_t>(a, b);
|
||||
}
|
||||
|
||||
// allow f16, bf16
|
||||
template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 1>
|
||||
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const scalar_t* __restrict__ data) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
bVec x_vec = bVec::loadu(data);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x_vec);
|
||||
return std::make_tuple(x0, x1);
|
||||
}
|
||||
|
||||
// allow f32
|
||||
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const float* __restrict__ data) {
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
fVec x0 = fVec::loadu(data);
|
||||
fVec x1 = fVec::loadu(data + fVec::size());
|
||||
return std::make_tuple(x0, x1);
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics
|
||||
|
||||
Reference in New Issue
Block a user