HIP: Cleanup hipification header (#15285)
Add explicit conversion operator to support older versions of ROCm. Switch over to hip_bf16 from legacy hip_bfloat16. Simplify RDNA3 define. Lower the switch-over point for the new hipBLAS API to ROCm 6.5, as this version is used for ROCm 7.0 previews. --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#include "getrows.cuh"
|
||||
#include "dequantize.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static __global__ void k_get_rows(
|
||||
@@ -34,8 +35,8 @@ static __global__ void k_get_rows(
|
||||
dfloat2 v;
|
||||
dequantize_kernel(src0_row, ib, iqs, v);
|
||||
|
||||
dst_row[iybs + iqs + 0] = float(v.x);
|
||||
dst_row[iybs + iqs + y_offset] = float(v.y);
|
||||
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
|
||||
dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
|
||||
}
|
||||
|
||||
template<typename src0_t, typename dst_t>
|
||||
@@ -62,7 +63,7 @@ static __global__ void k_get_rows_float(
|
||||
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
||||
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
|
||||
|
||||
dst_row[i00] = float(src0_row[i00]);
|
||||
dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
|
||||
}
|
||||
|
||||
template<typename grad_t, typename dst_t>
|
||||
|
||||
Reference in New Issue
Block a user