refine quant kernel code style (#4211)
This commit is contained in:
@@ -37,18 +37,7 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
|
|||||||
max_value = fmaxf(max_value, fabsf(val));
|
max_value = fmaxf(max_value, fabsf(val));
|
||||||
}
|
}
|
||||||
|
|
||||||
static __shared__ float warpLevelMaxs[WARP_SIZE];
|
max_value = blockReduceMax(max_value);
|
||||||
const int laneId = threadIdx.x % WARP_SIZE;
|
|
||||||
const int warpId = threadIdx.x / WARP_SIZE;
|
|
||||||
|
|
||||||
max_value = warpReduceMax(max_value);
|
|
||||||
|
|
||||||
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
|
|
||||||
|
|
||||||
if (warpId == 0) max_value = warpReduceMax(max_value);
|
|
||||||
|
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);
|
atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);
|
||||||
|
|||||||
@@ -30,19 +30,7 @@ __global__ void per_token_quant_fp8_kernel(
|
|||||||
max_value = fmaxf(max_value, fabsf(val));
|
max_value = fmaxf(max_value, fabsf(val));
|
||||||
}
|
}
|
||||||
|
|
||||||
max_value = warpReduceMax(max_value);
|
max_value = blockReduceMax(max_value);
|
||||||
|
|
||||||
static __shared__ float warpLevelMaxs[WARP_SIZE];
|
|
||||||
const int laneId = threadIdx.x % WARP_SIZE;
|
|
||||||
const int warpId = threadIdx.x / WARP_SIZE;
|
|
||||||
|
|
||||||
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
if (warpId == 0) {
|
|
||||||
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
|
|
||||||
max_value = warpReduceMax(max_value);
|
|
||||||
}
|
|
||||||
|
|
||||||
__shared__ float block_max;
|
__shared__ float block_max;
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
|
|||||||
@@ -124,4 +124,20 @@ __device__ __forceinline__ float warpReduceMax(float max_value) {
|
|||||||
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
|
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
|
||||||
return max_value;
|
return max_value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Block-wide maximum reduction built on top of warpReduceMax.
 *
 * Stage 1: every warp reduces its own values with warpReduceMax and lane 0 of
 *          each warp publishes that partial maximum to shared memory.
 * Stage 2: after a block barrier, warp 0 loads the per-warp partials (one per
 *          lane) and reduces them again.
 *
 * @param max_value  This thread's candidate value.
 * @return The reduced value. Only warp 0 runs the second stage, so only
 *         threads of warp 0 can hold the block-wide result; callers should
 *         read it from a warp-0 thread (e.g. thread 0).
 *
 * Preconditions / notes:
 *  - Must be reached by all threads of the block (it contains __syncthreads()).
 *  - The scratch array has WARP_SIZE slots, so a block may contain at most
 *    WARP_SIZE warps.
 *  - Lanes without a warp partial are filled with 0.0f, so the result is a
 *    true maximum only for non-negative inputs (the callers visible here
 *    accumulate fabsf values).
 *  - NOTE(review): warpReduceMax shuffles with a full 0xffffffff mask, so
 *    block sizes that are a multiple of WARP_SIZE are presumably expected --
 *    confirm at the launch sites.
 */
__device__ __forceinline__ float blockReduceMax(float max_value) {
  // One slot per warp in the block.
  static __shared__ float warpLevelMaxs[WARP_SIZE];
  const int laneId = threadIdx.x % WARP_SIZE;
  const int warpId = threadIdx.x / WARP_SIZE;
  // Round up so that a trailing partial warp's slot is not dropped when
  // blockDim.x is not a multiple of WARP_SIZE.
  const unsigned int numWarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;

  // Stage 1: per-warp maximum.
  max_value = warpReduceMax(max_value);

  // Lane 0 of each warp publishes its warp's partial result.
  if (laneId == 0) warpLevelMaxs[warpId] = max_value;
  __syncthreads();

  // Stage 2: the first numWarps threads pick up one partial each; every other
  // thread contributes the neutral fill value.
  max_value = (threadIdx.x < numWarps) ? warpLevelMaxs[laneId] : 0.0f;
  if (warpId == 0) max_value = warpReduceMax(max_value);

  return max_value;
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
Reference in New Issue
Block a user