CUDA: add bf16 and f32 support to cublas_mul_mat_batched (#14361)
* CUDA: add bf16 and f32 support to cublas_mul_mat_batched * Review: add type traits and make function more generic * Review: make check more explicit, add back comments, and fix formatting * Review: fix formatting, remove useless type conversion, fix naming for bools
This commit is contained in:
@@ -22,5 +22,10 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
|
||||
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
|
||||
int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
|
||||
|
||||
typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
|
||||
typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
|
||||
typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
|
||||
|
||||
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
|
||||
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
|
||||
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
|
||||
|
||||
Reference in New Issue
Block a user