[CPU][sgl-kernel] biased_grouped_topk: fix correction_bias dtype to float32 (#8212)
Co-authored-by: jianan-gu <jianan.gu@intel.com> Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
This commit is contained in:
@@ -47,6 +47,45 @@ namespace {
|
||||
} \
|
||||
}()
|
||||
|
||||
// dispatch with mixed dtypes (TYPE1, TYPE2):
|
||||
// TYPE1: the primary dtype (input, output, weight);
|
||||
// TYPE2: the secondary dtype (bias, etc.).
|
||||
#define CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(TYPE1, TYPE2, ...) \
|
||||
[&] { \
|
||||
if (TYPE2 == at::kFloat) { \
|
||||
switch (TYPE1) { \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t = at::BFloat16; \
|
||||
using param_t = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
using param_t = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
|
||||
} \
|
||||
} else { \
|
||||
TORCH_CHECK(TYPE1 == TYPE2); \
|
||||
switch (TYPE1) { \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t = at::BFloat16; \
|
||||
using param_t = at::BFloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
using param_t = at::Half; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
|
||||
} \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
|
||||
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
|
||||
|
||||
Reference in New Issue
Block a user