refine quant kernel code style (#4211)

This commit is contained in:
Xiaoyu Zhang
2025-03-08 21:47:35 +08:00
committed by GitHub
parent 2cadd51d11
commit b3251e9f40
3 changed files with 18 additions and 25 deletions

View File

@@ -37,18 +37,7 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
max_value = fmaxf(max_value, fabsf(val));
}
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
max_value = warpReduceMax(max_value);
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
__syncthreads();
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) max_value = warpReduceMax(max_value);
max_value = blockReduceMax(max_value);
if (tid == 0) {
atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);

View File

@@ -30,19 +30,7 @@ __global__ void per_token_quant_fp8_kernel(
max_value = fmaxf(max_value, fabsf(val));
}
max_value = warpReduceMax(max_value);
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
__syncthreads();
if (warpId == 0) {
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
max_value = warpReduceMax(max_value);
}
max_value = blockReduceMax(max_value);
__shared__ float block_max;
if (tid == 0) {

View File

@@ -124,4 +124,20 @@ __device__ __forceinline__ float warpReduceMax(float max_value) {
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
return max_value;
}
__device__ __forceinline__ float blockReduceMax(float max_value) {
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
max_value = warpReduceMax(max_value);
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
__syncthreads();
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) max_value = warpReduceMax(max_value);
return max_value;
}
#endif