adapt to ds3.2

This commit is contained in:
maxiao
2025-09-30 17:44:54 +08:00
parent 1237aa19ce
commit 8f7453e3af
9 changed files with 199 additions and 49 deletions

View File

@@ -358,25 +358,25 @@ __device__ __forceinline__ dstDtype castFromFloat(float val) {
#endif
// add FP8 support
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else // USE_ROCM
#if HIP_FP8_TYPE_FNUZ
#include <c10/util/Float8_e4m3fnuz.h>
using FP8_TYPE = c10::Float8_e4m3fnuz;
constexpr auto FP8_E4M3_MAX = 224.0f;
#else
#if HIP_FP8_TYPE_E4M3
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else
#error "fp8 is not supported in this processor (arch < gfx942)."
#endif // HIP_FP8_TYPE_E4M3
#endif // HIP_FP8_TYPE_FNUZ
#endif // USE_ROCM
// #ifndef USE_ROCM
// #include <c10/util/Float8_e4m3fn.h>
// using FP8_TYPE = c10::Float8_e4m3fn;
// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
// #else // USE_ROCM
// #if HIP_FP8_TYPE_FNUZ
// #include <c10/util/Float8_e4m3fnuz.h>
// using FP8_TYPE = c10::Float8_e4m3fnuz;
// constexpr auto FP8_E4M3_MAX = 224.0f;
// #else
// #if HIP_FP8_TYPE_E4M3
// #include <c10/util/Float8_e4m3fn.h>
// using FP8_TYPE = c10::Float8_e4m3fn;
// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
// #else
// #error "fp8 is not supported in this processor (arch < gfx942)."
// #endif // HIP_FP8_TYPE_E4M3
// #endif // HIP_FP8_TYPE_FNUZ
// #endif // USE_ROCM
#define FULL_MASK 0xffffffff