[NVIDIA] [2/N] Optimize silu_and_mul_scaled_fp4_grouped_quant perf (#9556)
This commit is contained in:
@@ -347,7 +347,7 @@ cvt_fp16_to_fp4(
|
||||
}
|
||||
}
|
||||
|
||||
// Eerly exit when using masks.
|
||||
// Early exit when using masks.
|
||||
if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
|
||||
continue;
|
||||
}
|
||||
@@ -383,6 +383,107 @@ cvt_fp16_to_fp4(
|
||||
#endif
|
||||
}
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__global__ void
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
__launch_bounds__(512, 4) cvt_fp16_to_fp4_expert(
|
||||
#else
|
||||
cvt_fp16_to_fp4_expert(
|
||||
#endif
|
||||
int32_t numRows,
|
||||
int32_t numCols,
|
||||
Type const* in,
|
||||
float const* SFScale,
|
||||
uint32_t* out,
|
||||
uint32_t* SFout,
|
||||
int32_t* mask,
|
||||
bool use_silu_and_mul,
|
||||
int n_experts) {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
|
||||
using PackedVec = PackedVec<Type>;
|
||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
|
||||
|
||||
// Input tensor row/col loops.
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = (gridDim.x * blockDim.x) / n_experts;
|
||||
int remainder = (gridDim.x * blockDim.x) % n_experts;
|
||||
int expert_idx;
|
||||
int tid_in_expert;
|
||||
int actual_stride;
|
||||
if (remainder > 0) {
|
||||
int bound = remainder * (stride + 1);
|
||||
if (tid < bound) {
|
||||
expert_idx = tid / (stride + 1);
|
||||
tid_in_expert = tid % (stride + 1);
|
||||
actual_stride = stride + 1;
|
||||
} else {
|
||||
expert_idx = remainder + (tid - bound) / stride;
|
||||
tid_in_expert = (tid - bound) % stride;
|
||||
actual_stride = stride;
|
||||
}
|
||||
} else {
|
||||
expert_idx = tid / stride;
|
||||
tid_in_expert = tid % stride;
|
||||
actual_stride = stride;
|
||||
}
|
||||
int m = numRows / n_experts;
|
||||
int padded_m = (m + (128 - 1)) / 128 * 128;
|
||||
|
||||
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
||||
// TODO(kaixih@nvidia): For now, we assume mask is used together with
|
||||
// silu_and_mal. Maybe we want a more general behavior of mask later. In the
|
||||
// silu case, the input last dim doubles.
|
||||
bool use_mask = mask != nullptr;
|
||||
int actualColsPerRow = use_silu_and_mul ? colsPerRow * 2 : colsPerRow;
|
||||
|
||||
// Each global thread processes one element
|
||||
for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow; globalIdx < (expert_idx + 1) * m * colsPerRow;
|
||||
globalIdx += actual_stride) {
|
||||
// Calculate which row and column this global thread should process
|
||||
int rowIdx = globalIdx / colsPerRow;
|
||||
int colIdx = globalIdx % colsPerRow;
|
||||
|
||||
// Find index within the experts
|
||||
int rowIdx_in_expert = rowIdx - expert_idx * m;
|
||||
|
||||
// Early exit when using masks.
|
||||
if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
|
||||
break;
|
||||
}
|
||||
|
||||
int64_t inOffset = rowIdx * actualColsPerRow + colIdx;
|
||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||
if (use_silu_and_mul) {
|
||||
PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
|
||||
silu_and_mul(in_vec, in_vec_mul);
|
||||
}
|
||||
|
||||
// Get the output tensor offset.
|
||||
// Same as inOffset because 8 elements are packed into one uint32_t.
|
||||
int64_t outOffset = rowIdx * colsPerRow + colIdx;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
// Get the global scaling factor, which will be applied to the SF.
|
||||
// Note SFScale is the same as next GEMM's alpha, which is
|
||||
// (448.f / (Alpha_A / 6.f)).
|
||||
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
||||
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
// The actual output_scales dim is computed from the padded numCols.
|
||||
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
||||
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
||||
uint32_t* SFout_in_expert = SFout + expert_idx * padded_m * numCols_SFout;
|
||||
|
||||
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
||||
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
|
||||
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
||||
__global__ void
|
||||
@@ -499,6 +600,7 @@ void quant_impl(
|
||||
void* input_offset_by_experts,
|
||||
void* output_scale_offset_by_experts,
|
||||
void* mask,
|
||||
bool use_silu_and_mul,
|
||||
int m_topk,
|
||||
int k,
|
||||
int n_experts,
|
||||
@@ -522,6 +624,22 @@ void quant_impl(
|
||||
block.x = (block.x + 1) / 2;
|
||||
}
|
||||
|
||||
// TODO(kaixih@nvidia): Should relax this to allow any grid size.
|
||||
if (mask != nullptr) {
|
||||
grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
|
||||
cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
|
||||
m_topk,
|
||||
k,
|
||||
reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<int32_t*>(mask),
|
||||
use_silu_and_mul,
|
||||
n_experts);
|
||||
return;
|
||||
}
|
||||
|
||||
int const blockRepeat = (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
|
||||
if (blockRepeat > 1) {
|
||||
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
|
||||
@@ -652,6 +770,7 @@ void scaled_fp4_experts_quant_sm100a(
|
||||
input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(),
|
||||
nullptr, // mask
|
||||
false, // use_silu_and_mul
|
||||
m_topk,
|
||||
k,
|
||||
n_experts,
|
||||
@@ -665,6 +784,7 @@ void scaled_fp4_experts_quant_sm100a(
|
||||
input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(),
|
||||
nullptr, // mask
|
||||
false, // use_silu_and_mul
|
||||
m_topk,
|
||||
k,
|
||||
n_experts,
|
||||
@@ -679,28 +799,21 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts,
|
||||
torch::Tensor const& mask) {
|
||||
torch::Tensor const& mask,
|
||||
bool use_silu_and_mul) {
|
||||
CHECK_INPUT(output, "output must be a CUDA tensor");
|
||||
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
|
||||
CHECK_INPUT(input, "input must be a CUDA tensor");
|
||||
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
|
||||
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor");
|
||||
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor");
|
||||
CHECK_INPUT(mask, "mask must be a CUDA tensor");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2);
|
||||
TORCH_CHECK(output_scale.dim() == 2);
|
||||
TORCH_CHECK(input.dim() == 2);
|
||||
TORCH_CHECK(input_global_scale.dim() == 1);
|
||||
TORCH_CHECK(input_offset_by_experts.dim() == 1);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
|
||||
|
||||
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
|
||||
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
|
||||
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
|
||||
TORCH_CHECK(mask.scalar_type() == INT);
|
||||
// output is uint8 (two nvfp4 values are packed into one uint8)
|
||||
// output_scale is int32 (four fp8 values are packed into one int32)
|
||||
@@ -710,12 +823,12 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
|
||||
const int BLOCK_SIZE = 16;
|
||||
auto m_topk = input.size(0);
|
||||
auto k_by_2 = input.size(1);
|
||||
TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2");
|
||||
auto k = k_by_2 / 2;
|
||||
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
|
||||
auto k = k_by_2;
|
||||
if (use_silu_and_mul) {
|
||||
TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2");
|
||||
k = k_by_2 / 2;
|
||||
}
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
|
||||
TORCH_CHECK(mask.size(0) == n_experts);
|
||||
TORCH_CHECK(output.size(0) == m_topk);
|
||||
TORCH_CHECK(output.size(1) == k / 2);
|
||||
@@ -734,9 +847,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
|
||||
output_scale.data_ptr(),
|
||||
input.data_ptr(),
|
||||
input_global_scale.data_ptr(),
|
||||
input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(),
|
||||
nullptr, // input_offset_by_experts
|
||||
nullptr, // output_scale_offset_by_experts
|
||||
mask.data_ptr(),
|
||||
use_silu_and_mul,
|
||||
m_topk,
|
||||
k,
|
||||
n_experts,
|
||||
@@ -747,9 +861,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
|
||||
output_scale.data_ptr(),
|
||||
input.data_ptr(),
|
||||
input_global_scale.data_ptr(),
|
||||
input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(),
|
||||
nullptr, // input_offset_by_experts
|
||||
nullptr, // output_scale_offset_by_experts
|
||||
mask.data_ptr(),
|
||||
use_silu_and_mul,
|
||||
m_topk,
|
||||
k,
|
||||
n_experts,
|
||||
|
||||
@@ -32,9 +32,8 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts,
|
||||
torch::Tensor const& mask);
|
||||
torch::Tensor const& mask,
|
||||
bool use_silu_and_mul);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -65,12 +64,11 @@ void silu_and_mul_scaled_fp4_experts_quant(
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts,
|
||||
torch::Tensor const& mask) {
|
||||
torch::Tensor const& mask,
|
||||
bool use_silu_and_mul) {
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
return silu_and_mul_scaled_fp4_experts_quant_sm100a(
|
||||
output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts, mask);
|
||||
output, output_scale, input, input_global_scale, mask, use_silu_and_mul);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user