@@ -57,9 +57,13 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void per_tensor_quant_fp8_kernel(
|
__global__ void per_tensor_quant_fp8_kernel(
|
||||||
const T* __restrict__ input, FP8_TYPE* __restrict__ output, const float scale_val, const int64_t num_elements) {
|
const T* __restrict__ input,
|
||||||
|
FP8_TYPE* __restrict__ output,
|
||||||
|
const float* __restrict__ scale,
|
||||||
|
const int64_t num_elements) {
|
||||||
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
|
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const int grid_size = blockDim.x * gridDim.x;
|
const int grid_size = blockDim.x * gridDim.x;
|
||||||
|
const float scale_val = 1.0f / (*scale);
|
||||||
|
|
||||||
constexpr uint32_t vec_size = 16 / sizeof(T);
|
constexpr uint32_t vec_size = 16 / sizeof(T);
|
||||||
using vec_t = flashinfer::vec_t<T, vec_size>;
|
using vec_t = flashinfer::vec_t<T, vec_size>;
|
||||||
@@ -121,9 +125,12 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
|
|||||||
per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
per_tensor_absmax_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
|
static_cast<scalar_t*>(input.data_ptr()), static_cast<float*>(output_s.data_ptr()), num_elements);
|
||||||
}
|
}
|
||||||
float scale_val = 1.0f / (*static_cast<float*>(output_s.data_ptr()));
|
|
||||||
per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
per_tensor_quant_fp8_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
static_cast<scalar_t*>(input.data_ptr()), static_cast<FP8_TYPE*>(output_q.data_ptr()), scale_val, num_elements);
|
static_cast<scalar_t*>(input.data_ptr()),
|
||||||
|
static_cast<FP8_TYPE*>(output_q.data_ptr()),
|
||||||
|
static_cast<float*>(output_s.data_ptr()),
|
||||||
|
num_elements);
|
||||||
return true;
|
return true;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user