refine quant kernel code style (#4211)
This commit is contained in:
@@ -37,18 +37,7 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
|
||||
max_value = fmaxf(max_value, fabsf(val));
|
||||
}
|
||||
|
||||
static __shared__ float warpLevelMaxs[WARP_SIZE];
|
||||
const int laneId = threadIdx.x % WARP_SIZE;
|
||||
const int warpId = threadIdx.x / WARP_SIZE;
|
||||
|
||||
max_value = warpReduceMax(max_value);
|
||||
|
||||
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
|
||||
__syncthreads();
|
||||
|
||||
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
|
||||
|
||||
if (warpId == 0) max_value = warpReduceMax(max_value);
|
||||
max_value = blockReduceMax(max_value);
|
||||
|
||||
if (tid == 0) {
|
||||
atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);
|
||||
|
||||
@@ -30,19 +30,7 @@ __global__ void per_token_quant_fp8_kernel(
|
||||
max_value = fmaxf(max_value, fabsf(val));
|
||||
}
|
||||
|
||||
max_value = warpReduceMax(max_value);
|
||||
|
||||
static __shared__ float warpLevelMaxs[WARP_SIZE];
|
||||
const int laneId = threadIdx.x % WARP_SIZE;
|
||||
const int warpId = threadIdx.x / WARP_SIZE;
|
||||
|
||||
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
|
||||
__syncthreads();
|
||||
|
||||
if (warpId == 0) {
|
||||
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
|
||||
max_value = warpReduceMax(max_value);
|
||||
}
|
||||
max_value = blockReduceMax(max_value);
|
||||
|
||||
__shared__ float block_max;
|
||||
if (tid == 0) {
|
||||
|
||||
@@ -124,4 +124,20 @@ __device__ __forceinline__ float warpReduceMax(float max_value) {
|
||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
|
||||
return max_value;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float blockReduceMax(float max_value) {
|
||||
static __shared__ float warpLevelMaxs[WARP_SIZE];
|
||||
const int laneId = threadIdx.x % WARP_SIZE;
|
||||
const int warpId = threadIdx.x / WARP_SIZE;
|
||||
|
||||
max_value = warpReduceMax(max_value);
|
||||
|
||||
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
|
||||
__syncthreads();
|
||||
|
||||
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
|
||||
if (warpId == 0) max_value = warpReduceMax(max_value);
|
||||
|
||||
return max_value;
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user