Sync from v0.13
This commit is contained in:
671
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
Normal file
671
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
Normal file
@@ -0,0 +1,671 @@
|
||||
#pragma once
|
||||
#include <hip/hip_fp8.h>
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_bfloat16.h>
|
||||
|
||||
#include "../../../../attention/attention_dtypes.h"
|
||||
|
||||
namespace vllm {
|
||||
#ifdef USE_ROCM
|
||||
|
||||
namespace fp8 {
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
template <typename fp8_type>
|
||||
__device__ __forceinline__ fp8_type cvt_c10(float const r) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
|
||||
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
|
||||
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
|
||||
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
|
||||
// the new HW cvt with something reasonable that doesn't rely on the
|
||||
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
|
||||
template <>
|
||||
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
|
||||
#if HIP_FP8_TYPE_OCP
|
||||
return c10::Float8_e4m3fn(
|
||||
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
|
||||
__hip_fp8_e4m3::__default_interpret),
|
||||
c10::Float8_e4m3fn::from_bits());
|
||||
#else
|
||||
// Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
|
||||
// HW cvt above is faster when it is available (ROCm 6.3 or newer).
|
||||
return static_cast<c10::Float8_e4m3fn>(r);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
|
||||
return c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
|
||||
__hip_fp8_e4m3_fnuz::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
|
||||
const float scale) {
|
||||
return x;
|
||||
}
|
||||
|
||||
#if HIP_FP8_TYPE_OCP
|
||||
using fp8_type = __hip_fp8_e4m3;
|
||||
using fp8x2_type = __hip_fp8x2_e4m3;
|
||||
#else
|
||||
using fp8_type = __hip_fp8_e4m3_fnuz;
|
||||
using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
|
||||
#endif
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
|
||||
return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return __float2bfloat16(static_cast<float>(f8));
|
||||
}
|
||||
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t
|
||||
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return static_cast<float>(f8);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2
|
||||
vec_conversion<float2, uint16_t>(const uint16_t& a) {
|
||||
fp8x2_type f8x2;
|
||||
f8x2.__x = a;
|
||||
return static_cast<float2>(f8x2);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_
|
||||
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
vec_conversion<float4, uint32_t>(const uint32_t& a) {
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
|
||||
union {
|
||||
uint32_t ui32;
|
||||
__half2_raw h2r;
|
||||
} tmp;
|
||||
tmp.ui32 = a;
|
||||
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
|
||||
return __hip_cvt_float_to_fp8(__bfloat162float(a),
|
||||
fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
|
||||
return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// float2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
vec_conversion<uint32_t, float2>(const float2& a) {
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
// Float4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
// float2 -> bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
|
||||
__nv_bfloat162 b = __float22bfloat162_rn(a);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float4 -> bfloat162x2
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t
|
||||
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
|
||||
bf16_4_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float8 -> bfloat162x4
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t
|
||||
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
|
||||
bf16_8_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
b.z = __float22bfloat162_rn(a.z);
|
||||
b.w = __float22bfloat162_rn(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
/* Scaled and vectorized conversions, for data exchange between high and low
|
||||
precision domains
|
||||
|
||||
Convention of the scale in API, e.g: FP8_data = Quantization(
|
||||
High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
|
||||
scale => HP
|
||||
|
||||
*/
|
||||
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return __float2bfloat16(static_cast<float>(f8) * scale);
|
||||
}
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
|
||||
float scale) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
|
||||
res.y =
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t
|
||||
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
|
||||
bf16_4_t res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
||||
scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t
|
||||
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
||||
const uint8_t& a, float scale) {
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return static_cast<float>(f8) * scale;
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2
|
||||
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
|
||||
fp8x2_type f8x2;
|
||||
f8x2.__x = a;
|
||||
return static_cast<float2>(f8x2) * scale;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_
|
||||
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
|
||||
Float4_ res;
|
||||
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
|
||||
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
||||
return {res.x.x, res.x.y, res.y.x, res.y.y};
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_
|
||||
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
|
||||
__half_raw res;
|
||||
res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
||||
tmp.h2r.x.data *= scale;
|
||||
tmp.h2r.y.data *= scale;
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2
|
||||
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
||||
tmp.u32[1] =
|
||||
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
|
||||
float scale) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
||||
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
tmp.data /= scale;
|
||||
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// halfx2 -> fp8x2
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
|
||||
union {
|
||||
uint32_t ui32;
|
||||
__half2_raw h2r;
|
||||
} tmp;
|
||||
tmp.ui32 = a;
|
||||
tmp.h2r.x.data /= scale;
|
||||
tmp.h2r.y.data /= scale;
|
||||
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// half2x2 -> fp8x4
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
|
||||
union {
|
||||
uint16_t ui16[2];
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
|
||||
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// half2x4 -> fp8x8
|
||||
template <>
|
||||
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
|
||||
float scale) {
|
||||
union {
|
||||
uint2 ui2[2];
|
||||
uint4 ui4;
|
||||
} tmp;
|
||||
tmp.ui4 = a;
|
||||
uint2 res;
|
||||
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
|
||||
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
const __nv_bfloat16& a, float scale) {
|
||||
return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
|
||||
fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// bf16x2 -> fp8x2
|
||||
template <>
|
||||
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
|
||||
const __nv_bfloat162& a, float scale) {
|
||||
union {
|
||||
uint8_t ui8[2];
|
||||
uint16_t ui16;
|
||||
} tmp;
|
||||
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
|
||||
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
|
||||
return tmp.ui16;
|
||||
}
|
||||
|
||||
// bf16x4 -> fp8x4
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
|
||||
union {
|
||||
uint16_t ui16[2];
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
|
||||
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// bf16x8 -> fp8x8
|
||||
template <>
|
||||
__inline__ __device__ uint2
|
||||
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
|
||||
uint2 res;
|
||||
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
|
||||
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
|
||||
return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// floatx2 -> fp8x2
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
|
||||
return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// floatx4 -> fp8x4
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
|
||||
union {
|
||||
uint16_t ui16[2];
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
|
||||
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
|
||||
return tmp.ui32;
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__inline__ __device__ Tout convert(const Tin& x) {
|
||||
#ifdef ENABLE_FP8
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||
return vec_conversion<Tout, Tin>(x);
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
return {}; // Squash missing return statement warning
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
||||
#ifdef ENABLE_FP8
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||
return scaled_vec_conversion<Tout, Tin>(x, scale);
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
return {}; // Squash missing return statement warning
|
||||
}
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on
|
||||
// the data type of the key and value cache. The FN is a macro that calls a
|
||||
// function with template<typename scalar_t, typename cache_t,
|
||||
// Fp8KVCacheDataType kv_dt>.
|
||||
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
|
||||
if (KV_DTYPE == "auto") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
||||
} \
|
||||
}
|
||||
|
||||
} // namespace fp8
|
||||
#endif // USE_ROCM
|
||||
} // namespace vllm
|
||||
247
csrc/quantization/w8a8/fp8/common.cu
Normal file
247
csrc/quantization/w8a8/fp8/common.cu
Normal file
@@ -0,0 +1,247 @@
|
||||
#include "common.cuh"
|
||||
#include "dispatch_utils.h"
|
||||
#include "cub_helpers.h"
|
||||
#include "quantization/vectorization_utils.cuh"
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void scaled_fp8_quant_kernel_strided(
|
||||
fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
|
||||
int64_t out_row_stride) {
|
||||
const int64_t token_idx = blockIdx.x; // one token per block
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const scalar_t* token_in = input + token_idx * in_row_stride;
|
||||
fp8_type* token_out = out + token_idx * out_row_stride;
|
||||
|
||||
const float inv_scale = 1.0f / (*scale);
|
||||
|
||||
vectorize_with_alignment<16>(
|
||||
token_in, token_out, hidden_size, tid, blockDim.x,
|
||||
[=] __device__(fp8_type & dst, const scalar_t& src) {
|
||||
dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
|
||||
inv_scale);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void segmented_max_reduction_strided(
|
||||
float* __restrict__ scale, const scalar_t* __restrict__ input,
|
||||
int hidden_size, int64_t in_row_stride, int64_t num_tokens) {
|
||||
__shared__ float cache[256];
|
||||
const int tid = threadIdx.x;
|
||||
int64_t token_idx = blockIdx.x;
|
||||
|
||||
// one block per token. Guard in case gridDim.x > num_tokens.
|
||||
if (token_idx >= num_tokens) {
|
||||
return;
|
||||
}
|
||||
|
||||
const scalar_t* row_ptr = input + token_idx * in_row_stride;
|
||||
|
||||
// each thread scans elements of the row in a strided fashion.
|
||||
float thread_max = 0.0f;
|
||||
for (int e = tid; e < hidden_size; e += blockDim.x) {
|
||||
float v = fabsf(static_cast<float>(row_ptr[e]));
|
||||
thread_max = fmaxf(thread_max, v);
|
||||
}
|
||||
|
||||
cache[tid] = thread_max;
|
||||
__syncthreads();
|
||||
|
||||
// parallel reduction to find row max.
|
||||
for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
|
||||
if (tid < offset) {
|
||||
cache[tid] = fmaxf(cache[tid], cache[tid + offset]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// thread 0 updates global scale (per-tensor) atomically.
|
||||
if (tid == 0) {
|
||||
atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void scaled_fp8_quant_kernel_strided_dynamic(
|
||||
fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
|
||||
int64_t out_row_stride) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const scalar_t* token_in = input + token_idx * in_row_stride;
|
||||
fp8_type* token_out = out + token_idx * out_row_stride;
|
||||
|
||||
const float reciprocal_scale = 1.0f / (*scale);
|
||||
vectorize_with_alignment<16>(
|
||||
token_in, token_out, hidden_size, tid, blockDim.x,
|
||||
[=] __device__(fp8_type & dst, const scalar_t& src) {
|
||||
dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
|
||||
reciprocal_scale);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
|
||||
fp8_type* __restrict__ out, float* __restrict__ scale,
|
||||
const scalar_t* __restrict__ input, const float* __restrict__ scale_ub,
|
||||
int hidden_size, int64_t in_row_stride, int64_t out_row_stride) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
// Use int64 to avoid overflowing an int32 when calculating this offset
|
||||
int64_t in_offset = static_cast<int64_t>(token_idx) * in_row_stride;
|
||||
int64_t out_offset = static_cast<int64_t>(token_idx) * out_row_stride;
|
||||
const scalar_t* token_in = input + in_offset;
|
||||
fp8_type* token_out = out + out_offset;
|
||||
|
||||
// 1) per-token absmax
|
||||
float absmax_val = 0.f;
|
||||
vectorize_read_with_alignment<16>(
|
||||
token_in, hidden_size, tid, blockDim.x, [&] __device__(scalar_t v) {
|
||||
absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(v)));
|
||||
});
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
const float block_max =
|
||||
BlockReduce(tmp).Reduce(absmax_val, CubMaxOp{}, blockDim.x);
|
||||
|
||||
__shared__ float token_scale;
|
||||
if (tid == 0) {
|
||||
token_scale = scale_ub ? fminf(block_max, *scale_ub) : block_max;
|
||||
token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
|
||||
min_scaling_factor<fp8_type>::val());
|
||||
scale[token_idx] = token_scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 2) quantize
|
||||
vectorize_with_alignment<16>(
|
||||
token_in, token_out, hidden_size, tid, blockDim.x,
|
||||
[=] __device__(fp8_type & dst, const scalar_t& src) {
|
||||
dst = scaled_fp8_conversion<false, fp8_type>(static_cast<float>(src),
|
||||
token_scale);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor const& scale) // [1]
|
||||
{
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
const int block_size = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(block_size);
|
||||
|
||||
const int64_t in_row_stride = input.stride(-2);
|
||||
const int64_t out_row_stride = out.stride(-2);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), hidden_size, in_row_stride,
|
||||
out_row_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
{
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
const int block_size = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(block_size);
|
||||
|
||||
const int64_t in_row_stride = input.stride(-2);
|
||||
const int64_t out_row_stride = out.stride(-2);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// scale tensor should be initialised to <=0 before reduction
|
||||
AT_CUDA_CHECK(
|
||||
cudaMemsetAsync(scale.data_ptr<float>(), 0, sizeof(float), stream));
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::segmented_max_reduction_strided<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
|
||||
hidden_size, in_row_stride,
|
||||
static_cast<int64_t>(num_tokens));
|
||||
|
||||
vllm::scaled_fp8_quant_kernel_strided_dynamic<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), hidden_size, in_row_stride,
|
||||
out_row_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_per_token_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
const int block_size = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, block_size));
|
||||
|
||||
const int64_t in_row_stride = input.stride(-2);
|
||||
const int64_t out_row_stride = out.stride(-2);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(),
|
||||
"dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<
|
||||
scalar_t, fp8_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
hidden_size, in_row_stride, out_row_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
62
csrc/quantization/w8a8/fp8/common.cuh
Normal file
62
csrc/quantization/w8a8/fp8/common.cuh
Normal file
@@ -0,0 +1,62 @@
|
||||
#pragma once
|
||||
|
||||
#include "quantization/vectorization.cuh"
|
||||
#include "quantization/utils.cuh"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include "nvidia/quant_utils.cuh"
|
||||
#else
|
||||
#include "amd/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
// Determines the preferred FP8 type for the current platform.
|
||||
// Note that for CUDA this just returns true,
|
||||
// but on ROCm it will check device props.
|
||||
static bool is_fp8_ocp() {
|
||||
#ifndef USE_ROCM
|
||||
return true;
|
||||
#else
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
std::string device_arch = dprops->gcnArchName;
|
||||
size_t substring = device_arch.find("gfx94");
|
||||
return substring == std::string::npos;
|
||||
#endif
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
float old;
|
||||
old = (value >= 0)
|
||||
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
|
||||
: __uint_as_float(
|
||||
atomicMin((unsigned int*)addr, __float_as_uint(value)));
|
||||
|
||||
return old;
|
||||
}
|
||||
|
||||
template <bool is_scale_inverted, typename fp8_type>
|
||||
__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
|
||||
float const scale) {
|
||||
float x = 0.0f;
|
||||
if constexpr (is_scale_inverted) {
|
||||
x = val * scale;
|
||||
} else {
|
||||
x = val / scale;
|
||||
}
|
||||
|
||||
float r =
|
||||
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
|
||||
#ifndef USE_ROCM
|
||||
// Use hardware cvt instruction for fp8 on nvidia
|
||||
// Currently only support fp8_type = c10::Float8_e4m3fn
|
||||
return fp8::vec_conversion<fp8_type, float>(r);
|
||||
#else
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
return fp8::cvt_c10<fp8_type>(r);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
597
csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh
Normal file
597
csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh
Normal file
@@ -0,0 +1,597 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../../../attention/attention_dtypes.h"
|
||||
#include <assert.h>
|
||||
#include <float.h>
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
|
||||
namespace vllm {
|
||||
#ifndef USE_ROCM
|
||||
|
||||
namespace fp8 {
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(
|
||||
const Tin& x, const __nv_fp8_interpretation_t fp8_type = __NV_E4M3) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// float -> c10::Float8_e4m3fn
|
||||
template <>
|
||||
__inline__ __device__ c10::Float8_e4m3fn
|
||||
vec_conversion<c10::Float8_e4m3fn, float>(
|
||||
const float& a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return static_cast<c10::Float8_e4m3fn>(a);
|
||||
#else
|
||||
return c10::Float8_e4m3fn(__nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type),
|
||||
c10::Float8_e4m3fn::from_bits());
|
||||
#endif
|
||||
}
|
||||
|
||||
#if 0 // Disable the following code to reduce the binary size.
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
|
||||
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
|
||||
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
|
||||
tmp.u16[0] = res.x;
|
||||
tmp.u16[1] = res.y;
|
||||
return tmp.u32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
|
||||
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
|
||||
tmp.u32[1] =
|
||||
vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
|
||||
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
|
||||
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
// Note there is no direct convert function from fp8 to bf16.
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
// half -> float -> bf16
|
||||
float tmp = half_to_float(res.x);
|
||||
return __float2bfloat16(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
|
||||
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
|
||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
|
||||
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
|
||||
res.y =
|
||||
vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
|
||||
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float
|
||||
vec_conversion<float, uint8_t>(const uint8_t &a,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
// fp8 -> half
|
||||
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
|
||||
// half -> float
|
||||
return half_to_float(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
|
||||
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
// fp8x2 -> half2
|
||||
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
|
||||
// half2 -> float2
|
||||
return half2_to_float2(tmp);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
|
||||
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
|
||||
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
|
||||
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
|
||||
__nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
|
||||
const float &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(
|
||||
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
|
||||
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
|
||||
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
|
||||
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(
|
||||
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
|
||||
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
|
||||
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_bfloat162 b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
|
||||
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
|
||||
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_8_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
#endif
|
||||
|
||||
/* Scaled and vectorized conversions, for data exchange between high and low
|
||||
precision domains Convention of the scale in API, e.g: FP8_data =
|
||||
Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8
|
||||
Dequant(FP8) * scale => HP
|
||||
*/
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout scaled_vec_conversion(
|
||||
const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
|
||||
const uint8_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
__half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
return float_to_half(half_to_float(tmp.x) * scale);
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
|
||||
const uint16_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
|
||||
tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
|
||||
tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
|
||||
return tmp.u32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
|
||||
const uint32_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] =
|
||||
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
|
||||
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
|
||||
scale, fp8_type);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4
|
||||
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
|
||||
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
|
||||
const uint8_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
// Note there is no direct convert function from fp8 to bf16.
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
// half -> float -> bf16
|
||||
float tmp = half_to_float(res.x);
|
||||
return __float2bfloat16(tmp * scale);
|
||||
}
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
|
||||
const uint16_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
|
||||
fp8_type);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
|
||||
scale, fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
||||
const uint32_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
|
||||
fp8_type);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
||||
scale, fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
|
||||
const uint2& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
|
||||
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
||||
const uint8_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
uint16_t tmp = res.x;
|
||||
|
||||
// half -> float
|
||||
return half_to_float(tmp) * scale;
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
|
||||
const uint16_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
// fp8x2 -> half2
|
||||
uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
|
||||
// half2 -> float2
|
||||
return half2_to_float2(tmp);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
|
||||
const uint32_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ res;
|
||||
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
|
||||
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
|
||||
fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
|
||||
const uint2& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
|
||||
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
|
||||
const uint16_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
const __nv_bfloat16& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
|
||||
__NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
|
||||
const float& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
|
||||
const uint32_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__inline__ __device__ Tout convert(const Tin& x) {
|
||||
#if 0 // Disable the following code to reduce the binary size.
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||
return vec_conversion<Tout, Tin>(x, __NV_E4M3);
|
||||
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
|
||||
return vec_conversion<Tout, Tin>(x, __NV_E5M2);
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
||||
#ifdef ENABLE_FP8
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
|
||||
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
|
||||
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on
|
||||
// the data type of the key and value cache. The FN is a macro that calls a
|
||||
// function with template<typename scalar_t, typename cache_t,
|
||||
// Fp8KVCacheDataType kv_dt>.
|
||||
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
|
||||
if (KV_DTYPE == "auto") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else if (KV_DTYPE == "fp8_e5m2") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else if (KV_DTYPE == "fp8_ds_mla") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
||||
} \
|
||||
}
|
||||
|
||||
} // namespace fp8
|
||||
#endif // not USE_ROCM
|
||||
} // namespace vllm
|
||||
385
csrc/quantization/w8a8/fp8/per_token_group_quant.cu
Normal file
385
csrc/quantization/w8a8/fp8/per_token_group_quant.cu
Normal file
@@ -0,0 +1,385 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "quantization/w8a8/per_token_group_quant_8bit.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "quantization/vectorization.cuh"
|
||||
#include "quantization/vectorization_utils.cuh"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
__device__ __forceinline__ float GroupReduceMax(float val) {
|
||||
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
|
||||
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T, bool SCALE_UE8M0>
|
||||
__device__ __forceinline__ float ComputeGroupScale(
|
||||
const T* __restrict__ group_input, T* __restrict__ smem_group,
|
||||
const int group_size, const int lane_id, const int threads_per_group,
|
||||
const float eps, const float max_8bit) {
|
||||
float local_absmax = eps;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
|
||||
// copy global -> shared & compute absmax
|
||||
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
|
||||
float abs_v = fabsf(static_cast<float>(src));
|
||||
local_absmax = fmaxf(local_absmax, abs_v);
|
||||
dst = src;
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
group_input, // in
|
||||
smem_group, // out (shared)
|
||||
group_size, // elements per group
|
||||
lane_id, // thread id
|
||||
threads_per_group, // stride in group
|
||||
scalar_op_cache); // scalar handler
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax);
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
|
||||
}
|
||||
|
||||
return y_s;
|
||||
}
|
||||
|
||||
template <typename T, typename DST_DTYPE>
|
||||
__device__ __forceinline__ void QuantizeGroup(
|
||||
const T* __restrict__ smem_group, DST_DTYPE* __restrict__ group_output,
|
||||
const int group_size, const int lane_id, const int threads_per_group,
|
||||
const float y_s, const float min_8bit, const float max_8bit) {
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
|
||||
// quantize shared -> global 8-bit
|
||||
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
|
||||
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
|
||||
dst = DST_DTYPE(q);
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
smem_group, // in (shared)
|
||||
group_output, // out (global quant tensor)
|
||||
group_size, // elements
|
||||
lane_id, // tid
|
||||
threads_per_group, // stride
|
||||
scalar_op_quant); // scalar handler
|
||||
}
|
||||
|
||||
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
|
||||
bool SCALE_UE8M0 = false, typename scale_packed_t = float>
|
||||
__global__ void per_token_group_quant_8bit_kernel(
|
||||
const T* __restrict__ input, void* __restrict__ output_q,
|
||||
scale_packed_t* __restrict__ output_s, const int group_size,
|
||||
const int num_groups, const int groups_per_block, const float eps,
|
||||
const float min_8bit, const float max_8bit, const int scale_num_rows = 0,
|
||||
const int scale_stride = 0) {
|
||||
const int threads_per_group = 16;
|
||||
const int64_t local_group_id = threadIdx.x / threads_per_group;
|
||||
const int lane_id = threadIdx.x % threads_per_group;
|
||||
|
||||
const int64_t block_group_id = blockIdx.x * groups_per_block;
|
||||
const int64_t global_group_id = block_group_id + local_group_id;
|
||||
const int64_t block_group_offset = global_group_id * group_size;
|
||||
|
||||
using scale_element_t = float;
|
||||
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
|
||||
|
||||
const T* group_input = input + block_group_offset;
|
||||
DST_DTYPE* group_output =
|
||||
static_cast<DST_DTYPE*>(output_q) + block_group_offset;
|
||||
scale_element_t* scale_output;
|
||||
|
||||
if constexpr (IS_COLUMN_MAJOR) {
|
||||
const int num_elems_per_pack =
|
||||
static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
|
||||
const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
|
||||
const int row_idx = global_group_id / scale_num_rows_element;
|
||||
const int col_idx_raw = global_group_id % scale_num_rows_element;
|
||||
const int col_idx = col_idx_raw / num_elems_per_pack;
|
||||
const int pack_idx = col_idx_raw % num_elems_per_pack;
|
||||
scale_output = reinterpret_cast<scale_element_t*>(output_s) +
|
||||
(col_idx * scale_stride * num_elems_per_pack +
|
||||
row_idx * num_elems_per_pack + pack_idx);
|
||||
} else {
|
||||
scale_output = output_s + global_group_id;
|
||||
}
|
||||
|
||||
// shared memory to cache each group's data to avoid double DRAM reads.
|
||||
extern __shared__ __align__(16) char smem_raw[];
|
||||
T* smem = reinterpret_cast<T*>(smem_raw);
|
||||
T* smem_group = smem + local_group_id * group_size;
|
||||
|
||||
const float y_s = ComputeGroupScale<T, SCALE_UE8M0>(
|
||||
group_input, smem_group, group_size, lane_id, threads_per_group, eps,
|
||||
max_8bit);
|
||||
|
||||
scale_element_t y_s_quant = y_s;
|
||||
|
||||
if (lane_id == 0) {
|
||||
*scale_output = y_s_quant;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id,
|
||||
threads_per_group, y_s, min_8bit, max_8bit);
|
||||
}
|
||||
|
||||
inline int GetGroupsPerBlock(int64_t num_groups) {
|
||||
if (num_groups % 16 == 0) {
|
||||
return 16;
|
||||
}
|
||||
if (num_groups % 8 == 0) {
|
||||
return 8;
|
||||
}
|
||||
if (num_groups % 4 == 0) {
|
||||
return 4;
|
||||
}
|
||||
if (num_groups % 2 == 0) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
void per_token_group_quant_8bit(const torch::Tensor& input,
|
||||
torch::Tensor& output_q,
|
||||
torch::Tensor& output_s, int64_t group_size,
|
||||
double eps, double min_8bit, double max_8bit,
|
||||
bool scale_ue8m0) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(output_q.is_contiguous());
|
||||
|
||||
const int num_groups = input.numel() / group_size;
|
||||
|
||||
TORCH_CHECK(input.numel() % group_size == 0);
|
||||
TORCH_CHECK(output_s.dim() == 2);
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
constexpr int THREADS_PER_GROUP = 16;
|
||||
|
||||
const int groups_per_block = GetGroupsPerBlock(num_groups);
|
||||
|
||||
auto dst_type = output_q.scalar_type();
|
||||
const int num_blocks = num_groups / groups_per_block;
|
||||
const int num_threads = groups_per_block * THREADS_PER_GROUP;
|
||||
|
||||
const bool is_column_major = output_s.stride(0) < output_s.stride(1);
|
||||
const int scale_num_rows = output_s.size(1);
|
||||
const int scale_stride = output_s.stride(1);
|
||||
|
||||
#define LAUNCH_KERNEL(T, DST_DTYPE) \
|
||||
do { \
|
||||
dim3 grid(num_blocks); \
|
||||
dim3 block(num_threads); \
|
||||
size_t smem_bytes = \
|
||||
static_cast<size_t>(groups_per_block) * group_size * sizeof(T); \
|
||||
if (is_column_major) { \
|
||||
if (scale_ue8m0) { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true> \
|
||||
<<<grid, block, smem_bytes, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), group_size, \
|
||||
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
|
||||
(float)max_8bit, scale_num_rows, scale_stride); \
|
||||
} else { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false> \
|
||||
<<<grid, block, smem_bytes, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), group_size, \
|
||||
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
|
||||
(float)max_8bit, scale_num_rows, scale_stride); \
|
||||
} \
|
||||
} else { \
|
||||
if (scale_ue8m0) { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, true> \
|
||||
<<<grid, block, smem_bytes, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), group_size, \
|
||||
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
|
||||
(float)max_8bit); \
|
||||
} else { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, false> \
|
||||
<<<grid, block, smem_bytes, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), group_size, \
|
||||
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
|
||||
(float)max_8bit); \
|
||||
} \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
|
||||
if (dst_type == at::ScalarType::Float8_e4m3fn) {
|
||||
LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
|
||||
} else if (dst_type == at::ScalarType::Char) {
|
||||
LAUNCH_KERNEL(scalar_t, int8_t);
|
||||
}
|
||||
}));
|
||||
|
||||
#undef LAUNCH_KERNEL
|
||||
}
|
||||
|
||||
template <typename T, typename DST_DTYPE>
|
||||
__global__ void per_token_group_quant_8bit_packed_kernel(
|
||||
const T* __restrict__ input, void* __restrict__ output_q,
|
||||
unsigned int* __restrict__ output_s_packed, const int group_size,
|
||||
const int num_groups, const int groups_per_block, const int groups_per_row,
|
||||
const int mn, const int tma_aligned_mn, const float eps,
|
||||
const float min_8bit, const float max_8bit) {
|
||||
const int threads_per_group = 16;
|
||||
const int64_t local_group_id = threadIdx.x / threads_per_group;
|
||||
const int lane_id = threadIdx.x % threads_per_group;
|
||||
|
||||
const int64_t block_group_id = blockIdx.x * groups_per_block;
|
||||
const int64_t global_group_id = block_group_id + local_group_id;
|
||||
if (global_group_id >= num_groups) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t block_group_offset = global_group_id * group_size;
|
||||
|
||||
const T* group_input = input + block_group_offset;
|
||||
DST_DTYPE* group_output =
|
||||
static_cast<DST_DTYPE*>(output_q) + block_group_offset;
|
||||
|
||||
// shared memory to cache each group's data to avoid double DRAM reads.
|
||||
extern __shared__ __align__(16) char smem_raw[];
|
||||
T* smem = reinterpret_cast<T*>(smem_raw);
|
||||
T* smem_group = smem + local_group_id * group_size;
|
||||
const float y_s =
|
||||
ComputeGroupScale<T, true>(group_input, smem_group, group_size, lane_id,
|
||||
threads_per_group, eps, max_8bit);
|
||||
|
||||
// pack 4 scales into a uint32
|
||||
if (lane_id == 0) {
|
||||
// map flat group id to 2D indices (mn_idx, sf_k_idx)
|
||||
const int sf_k_idx = static_cast<int>(global_group_id % groups_per_row);
|
||||
const int mn_idx = static_cast<int>(global_group_id / groups_per_row);
|
||||
|
||||
if (mn_idx < mn) {
|
||||
// each uint32 in output_s_packed stores 4 packed scales
|
||||
const int sf_k_pack_idx = sf_k_idx / 4;
|
||||
const int pos = sf_k_idx % 4;
|
||||
|
||||
// reinterpret the UE8M0 scale y_s as IEEE bits, extract the 8-bit
|
||||
// exponent, and place it into the correct byte of the 32-bit word.
|
||||
const unsigned int bits = __float_as_uint(y_s);
|
||||
const unsigned int exponent = (bits >> 23u) & 0xffu;
|
||||
const unsigned int contrib = exponent << (pos * 8u);
|
||||
|
||||
const int out_idx = sf_k_pack_idx * tma_aligned_mn + mn_idx;
|
||||
// atomically OR 8-bit exponent into the packed scales buffer
|
||||
atomicOr(output_s_packed + out_idx, contrib);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id,
|
||||
threads_per_group, y_s, min_8bit, max_8bit);
|
||||
}
|
||||
|
||||
void per_token_group_quant_8bit_packed(const torch::Tensor& input,
|
||||
torch::Tensor& output_q,
|
||||
torch::Tensor& output_s_packed,
|
||||
int64_t group_size, double eps,
|
||||
double min_8bit, double max_8bit) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(output_q.is_contiguous());
|
||||
|
||||
const int64_t k = input.size(-1);
|
||||
TORCH_CHECK(k % group_size == 0, "Last dimension (", k,
|
||||
") must be divisible by group_size (", group_size, ").");
|
||||
|
||||
const int64_t mn = input.numel() / k;
|
||||
const int64_t groups_per_row = k / group_size;
|
||||
const int64_t num_groups = mn * groups_per_row;
|
||||
|
||||
TORCH_CHECK(output_s_packed.dim() == 2,
|
||||
"output_s_packed must be 2D, got dim=", output_s_packed.dim(),
|
||||
".");
|
||||
|
||||
const int64_t k_num_packed_sfk = (groups_per_row + 3) / 4;
|
||||
const int64_t tma_aligned_mn = ((mn + 3) / 4) * 4;
|
||||
|
||||
TORCH_CHECK(output_s_packed.scalar_type() == at::ScalarType::Int,
|
||||
"output_s_packed must have dtype int32 for UE8M0-packed scales.");
|
||||
// DeepGEMM expects SFA scales in MN-major form with shape
|
||||
// [mn, ceil_div(K, 128 * 4)] and TMA-aligned stride on the last
|
||||
// dimension.
|
||||
TORCH_CHECK(output_s_packed.size(0) == mn &&
|
||||
output_s_packed.size(1) == k_num_packed_sfk,
|
||||
"output_s_packed shape must be [", mn, ", ", k_num_packed_sfk,
|
||||
"], but got [", output_s_packed.size(0), ", ",
|
||||
output_s_packed.size(1), "].");
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
constexpr int THREADS_PER_GROUP = 16;
|
||||
|
||||
const int groups_per_block = GetGroupsPerBlock(num_groups);
|
||||
|
||||
auto dst_type = output_q.scalar_type();
|
||||
const int num_blocks = num_groups / groups_per_block;
|
||||
const int num_threads = groups_per_block * THREADS_PER_GROUP;
|
||||
|
||||
// zero-initialize packed scales, since we use atomicOr to accumulate
|
||||
// exponents from different groups.
|
||||
output_s_packed.zero_();
|
||||
|
||||
#define LAUNCH_PACKED_KERNEL(T, DST_DTYPE) \
|
||||
do { \
|
||||
dim3 grid(num_blocks); \
|
||||
dim3 block(num_threads); \
|
||||
size_t smem_bytes = \
|
||||
static_cast<size_t>(groups_per_block) * group_size * sizeof(T); \
|
||||
per_token_group_quant_8bit_packed_kernel<T, DST_DTYPE> \
|
||||
<<<grid, block, smem_bytes, stream>>>( \
|
||||
static_cast<const T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()), \
|
||||
static_cast<int>(group_size), static_cast<int>(num_groups), \
|
||||
groups_per_block, static_cast<int>(groups_per_row), \
|
||||
static_cast<int>(mn), static_cast<int>(tma_aligned_mn), \
|
||||
static_cast<float>(eps), static_cast<float>(min_8bit), \
|
||||
static_cast<float>(max_8bit)); \
|
||||
} while (0)
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] {
|
||||
if (dst_type == at::ScalarType::Float8_e4m3fn) {
|
||||
LAUNCH_PACKED_KERNEL(scalar_t, __nv_fp8_e4m3);
|
||||
} else if (dst_type == at::ScalarType::Char) {
|
||||
LAUNCH_PACKED_KERNEL(scalar_t, int8_t);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"per_token_group_quant_8bit_packed only supports FP8/INT8 "
|
||||
"outputs.");
|
||||
}
|
||||
}));
|
||||
|
||||
#undef LAUNCH_PACKED_KERNEL
|
||||
}
|
||||
|
||||
void per_token_group_quant_fp8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q, torch::Tensor& output_s,
|
||||
int64_t group_size, double eps, double fp8_min,
|
||||
double fp8_max, bool scale_ue8m0) {
|
||||
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
|
||||
fp8_min, fp8_max, scale_ue8m0);
|
||||
}
|
||||
Reference in New Issue
Block a user