[CPU][sgl-kernel] biased_grouped_topk: fix correction_bias dtype to float32 (#8212)

Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
This commit is contained in:
Chunyuan WU
2025-08-05 09:28:31 +08:00
committed by GitHub
parent d4bf5a8524
commit 08f8f49016
4 changed files with 94 additions and 28 deletions

View File

@@ -47,6 +47,45 @@ namespace {
} \
}()
// dispatch with mixed dtypes (TYPE1, TYPE2):
// TYPE1: the primary dtype (input, output, weight);
// TYPE2: the secondary dtype (bias, etc.).
#define CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(TYPE1, TYPE2, ...) \
[&] { \
if (TYPE2 == at::kFloat) { \
switch (TYPE1) { \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
using param_t = float; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
using param_t = float; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
} \
} else { \
TORCH_CHECK(TYPE1 == TYPE2); \
switch (TYPE1) { \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
using param_t = at::BFloat16; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
using param_t = at::Half; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
} \
} \
}()
#define UNUSED(x) (void)(x)
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")