Fp4 MOE quant kernel optimization (#8777)
Co-authored-by: Rain Jiang <96632942+rainj-me@users.noreply.github.com>
This commit is contained in:
@@ -240,12 +240,113 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Use UE4M3 by default.
|
// Use UE4M3 by default.
|
||||||
template <class Type, bool UE8M0_SF = false>
|
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
||||||
__global__ void
|
__global__ void
|
||||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||||
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
|
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
|
||||||
#else
|
#else
|
||||||
cvt_fp16_to_fp4(
|
cvt_fp16_to_fp4(
|
||||||
|
#endif
|
||||||
|
int32_t numRows,
|
||||||
|
int32_t numCols,
|
||||||
|
Type const* in,
|
||||||
|
float const* SFScale,
|
||||||
|
uint32_t* out,
|
||||||
|
uint32_t* SFout,
|
||||||
|
uint32_t* input_offset_by_experts,
|
||||||
|
uint32_t* output_scale_offset_by_experts,
|
||||||
|
int n_experts,
|
||||||
|
bool low_latency) {
|
||||||
|
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||||
|
using PackedVec = PackedVec<Type>;
|
||||||
|
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||||
|
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
|
||||||
|
|
||||||
|
// Input tensor row/col loops.
|
||||||
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
||||||
|
|
||||||
|
// Each global thread processes one element
|
||||||
|
for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {
|
||||||
|
// Calculate which row and column this global thread should process
|
||||||
|
int rowIdx = globalIdx / colsPerRow;
|
||||||
|
int colIdx = globalIdx % colsPerRow;
|
||||||
|
|
||||||
|
int64_t inOffset = rowIdx * colsPerRow + colIdx;
|
||||||
|
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||||
|
// Get the output tensor offset.
|
||||||
|
// Same as inOffset because 8 elements are packed into one uint32_t.
|
||||||
|
int64_t outOffset = inOffset;
|
||||||
|
auto& out_pos = out[outOffset];
|
||||||
|
|
||||||
|
// Find index within the experts using different strategies based on expert
|
||||||
|
// count
|
||||||
|
int rowIdx_in_expert = 0;
|
||||||
|
int expert_idx = 0;
|
||||||
|
|
||||||
|
if constexpr (SMALL_NUM_EXPERTS) {
|
||||||
|
for (int i = 0; i < n_experts; i++) {
|
||||||
|
uint32_t current_offset = __ldca(&input_offset_by_experts[i]);
|
||||||
|
uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]);
|
||||||
|
if (rowIdx >= current_offset && rowIdx < next_offset) {
|
||||||
|
rowIdx_in_expert = rowIdx - current_offset;
|
||||||
|
expert_idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Load input offsets into registers first, then do the computation.
|
||||||
|
// Local array size set to 17 because of register limit.
|
||||||
|
uint32_t local_offsets[17];
|
||||||
|
for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) {
|
||||||
|
*reinterpret_cast<int4*>(local_offsets) =
|
||||||
|
__ldca(reinterpret_cast<const int4*>(&input_offset_by_experts[chunk_start]));
|
||||||
|
*reinterpret_cast<int4*>(local_offsets + 4) =
|
||||||
|
__ldca(reinterpret_cast<const int4*>(&input_offset_by_experts[chunk_start + 4]));
|
||||||
|
*reinterpret_cast<int4*>(local_offsets + 8) =
|
||||||
|
__ldca(reinterpret_cast<const int4*>(&input_offset_by_experts[chunk_start + 8]));
|
||||||
|
*reinterpret_cast<int4*>(local_offsets + 12) =
|
||||||
|
__ldca(reinterpret_cast<const int4*>(&input_offset_by_experts[chunk_start + 12]));
|
||||||
|
local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]);
|
||||||
|
|
||||||
|
// Check against the 16 loaded offsets
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 16; i++) {
|
||||||
|
if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {
|
||||||
|
rowIdx_in_expert = rowIdx - local_offsets[i];
|
||||||
|
expert_idx = chunk_start + i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the global scaling factor, which will be applied to the SF.
|
||||||
|
// Note SFScale is the same as next GEMM's alpha, which is
|
||||||
|
// (448.f / (Alpha_A / 6.f)).
|
||||||
|
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
||||||
|
|
||||||
|
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||||
|
// The actual output_scales dim is computed from the padded numCols.
|
||||||
|
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
||||||
|
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
||||||
|
uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
|
||||||
|
|
||||||
|
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
|
||||||
|
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
||||||
|
|
||||||
|
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
|
||||||
|
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
||||||
|
__global__ void
|
||||||
|
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||||
|
__launch_bounds__(1024, 4) cvt_fp16_to_fp4(
|
||||||
|
#else
|
||||||
|
cvt_fp16_to_fp4(
|
||||||
#endif
|
#endif
|
||||||
int32_t numRows,
|
int32_t numRows,
|
||||||
int32_t numCols,
|
int32_t numCols,
|
||||||
@@ -260,44 +361,75 @@ cvt_fp16_to_fp4(
|
|||||||
using PackedVec = PackedVec<Type>;
|
using PackedVec = PackedVec<Type>;
|
||||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
|
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
|
||||||
|
extern __shared__ uint32_t shared_input_offsets[];
|
||||||
|
|
||||||
// Input tensor row/col loops.
|
// Load input offsets into shared memory.
|
||||||
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
|
// If n_experts is larger than 4, use vectorized int4 to save instructions.
|
||||||
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) {
|
// If n_experts is smaller than 4, read directly.
|
||||||
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
|
if constexpr (SMALL_NUM_EXPERTS) {
|
||||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) {
|
||||||
// Get the output tensor offset.
|
shared_input_offsets[i] = input_offset_by_experts[i];
|
||||||
// Same as inOffset because 8 elements are packed into one uint32_t.
|
|
||||||
int64_t outOffset = inOffset;
|
|
||||||
auto& out_pos = out[outOffset];
|
|
||||||
|
|
||||||
// Find index within the experts.
|
|
||||||
int rowIdx_in_expert = 0;
|
|
||||||
int expert_idx = 0;
|
|
||||||
for (int i = 0; i < n_experts; i++) {
|
|
||||||
if (rowIdx >= input_offset_by_experts[i] && rowIdx < input_offset_by_experts[i + 1]) {
|
|
||||||
rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
|
|
||||||
expert_idx = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the global scaling factor, which will be applied to the SF.
|
|
||||||
// Note SFScale is the same as next GEMM's alpha, which is
|
|
||||||
// (448.f / (Alpha_A / 6.f)).
|
|
||||||
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
|
||||||
|
|
||||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
|
||||||
// The actual output_scales dim is computed from the padded numCols.
|
|
||||||
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
|
||||||
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
|
||||||
uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
|
|
||||||
|
|
||||||
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
|
|
||||||
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
|
||||||
|
|
||||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {
|
||||||
|
*reinterpret_cast<int4*>(&shared_input_offsets[i]) = *reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
|
||||||
|
}
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
||||||
|
|
||||||
|
// Each global thread processes one element
|
||||||
|
for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {
|
||||||
|
// Calculate which row and column this global thread should process
|
||||||
|
int rowIdx = globalIdx / colsPerRow;
|
||||||
|
int colIdx = globalIdx % colsPerRow;
|
||||||
|
|
||||||
|
int64_t inOffset = rowIdx * colsPerRow + colIdx;
|
||||||
|
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||||
|
int64_t outOffset = inOffset;
|
||||||
|
auto& out_pos = out[outOffset];
|
||||||
|
|
||||||
|
// Find expert using binary search for better performance with large m_topk
|
||||||
|
int rowIdx_in_expert = 0;
|
||||||
|
int expert_idx = 0;
|
||||||
|
|
||||||
|
// Binary search through experts using shared memory
|
||||||
|
int left = 0, right = n_experts - 1;
|
||||||
|
while (left <= right) {
|
||||||
|
int mid = (left + right) / 2;
|
||||||
|
// Get offsets: shared_input_offsets[i] corresponds to
|
||||||
|
// input_offset_by_experts[i]
|
||||||
|
uint32_t mid_offset = shared_input_offsets[mid];
|
||||||
|
uint32_t next_offset = shared_input_offsets[mid + 1];
|
||||||
|
|
||||||
|
if (rowIdx >= mid_offset && rowIdx < next_offset) {
|
||||||
|
rowIdx_in_expert = rowIdx - mid_offset;
|
||||||
|
expert_idx = mid;
|
||||||
|
break;
|
||||||
|
} else if (rowIdx < mid_offset) {
|
||||||
|
right = mid - 1;
|
||||||
|
} else {
|
||||||
|
left = mid + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
||||||
|
|
||||||
|
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||||
|
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
||||||
|
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
||||||
|
uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
|
||||||
|
|
||||||
|
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
|
||||||
|
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
||||||
|
|
||||||
|
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
@@ -322,21 +454,70 @@ void quant_impl(
|
|||||||
|
|
||||||
// Grid, Block size.
|
// Grid, Block size.
|
||||||
// Each thread converts 8 values.
|
// Each thread converts 8 values.
|
||||||
dim3 block(std::min(int(k / ELTS_PER_THREAD), 512));
|
int const workSizePerRow = k / ELTS_PER_THREAD;
|
||||||
|
int const totalWorkSize = m_topk * workSizePerRow;
|
||||||
|
dim3 block(std::min(workSizePerRow, 512));
|
||||||
// Get number of blocks per SM (assume we can fully utilize the SM).
|
// Get number of blocks per SM (assume we can fully utilize the SM).
|
||||||
int const numBlocksPerSM = 2048 / block.x;
|
int const numBlocksPerSM = 2048 / block.x;
|
||||||
dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM));
|
dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x), multiProcessorCount * numBlocksPerSM));
|
||||||
|
while (grid.x <= multiProcessorCount && block.x > 64) {
|
||||||
|
grid.x *= 2;
|
||||||
|
block.x = (block.x + 1) / 2;
|
||||||
|
}
|
||||||
|
|
||||||
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
|
int const blockRepeat = (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
|
||||||
m_topk,
|
if (blockRepeat > 1) {
|
||||||
k,
|
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
|
||||||
reinterpret_cast<T*>(input),
|
if (n_experts >= 4) {
|
||||||
reinterpret_cast<float*>(input_global_scale),
|
cvt_fp16_to_fp4<T, false, false><<<grid, block, shared_mem_size, stream>>>(
|
||||||
reinterpret_cast<uint32_t*>(output),
|
m_topk,
|
||||||
reinterpret_cast<uint32_t*>(output_scale),
|
k,
|
||||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
reinterpret_cast<T*>(input),
|
||||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
reinterpret_cast<float*>(input_global_scale),
|
||||||
n_experts);
|
reinterpret_cast<uint32_t*>(output),
|
||||||
|
reinterpret_cast<uint32_t*>(output_scale),
|
||||||
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||||
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||||
|
n_experts);
|
||||||
|
} else {
|
||||||
|
cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
|
||||||
|
m_topk,
|
||||||
|
k,
|
||||||
|
reinterpret_cast<T*>(input),
|
||||||
|
reinterpret_cast<float*>(input_global_scale),
|
||||||
|
reinterpret_cast<uint32_t*>(output),
|
||||||
|
reinterpret_cast<uint32_t*>(output_scale),
|
||||||
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||||
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||||
|
n_experts);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (n_experts >= 16) {
|
||||||
|
cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>(
|
||||||
|
m_topk,
|
||||||
|
k,
|
||||||
|
reinterpret_cast<T*>(input),
|
||||||
|
reinterpret_cast<float*>(input_global_scale),
|
||||||
|
reinterpret_cast<uint32_t*>(output),
|
||||||
|
reinterpret_cast<uint32_t*>(output_scale),
|
||||||
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||||
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||||
|
n_experts,
|
||||||
|
/* bool low_latency */ true);
|
||||||
|
} else {
|
||||||
|
cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>(
|
||||||
|
m_topk,
|
||||||
|
k,
|
||||||
|
reinterpret_cast<T*>(input),
|
||||||
|
reinterpret_cast<float*>(input_global_scale),
|
||||||
|
reinterpret_cast<uint32_t*>(output),
|
||||||
|
reinterpret_cast<uint32_t*>(output_scale),
|
||||||
|
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||||
|
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||||
|
n_experts,
|
||||||
|
/* bool low_latency */ true);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*Quantization entry for fp4 experts quantization*/
|
/*Quantization entry for fp4 experts quantization*/
|
||||||
|
|||||||
Reference in New Issue
Block a user