[NVIDIA] [1/N] Nvfp4 Masked Gemm: Add quant op for the flashinfer grouped gemm (#9200)

This commit is contained in:
Kaixi Hou
2025-08-22 12:19:45 -07:00
committed by GitHub
parent f556ac8bd8
commit e5638573c1
7 changed files with 420 additions and 13 deletions


@@ -239,6 +239,33 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
#endif
}
__device__ __forceinline__ float silu(const float& val) {
return val / (1.0f + __expf(-val));
}
template <class Type>
inline __device__ void silu_and_mul(PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
float2 x[CVT_FP4_ELTS_PER_THREAD / 2];
float2 y[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
x[i] = __half22float2(x_vec.elts[i]);
y[i] = __half22float2(y_vec.elts[i]);
x[i].x = silu(x[i].x) * y[i].x;
x[i].y = silu(x[i].y) * y[i].y;
x_vec.elts[i] = __float22half2_rn(x[i]);
} else {
x[i] = __bfloat1622float2(x_vec.elts[i]);
y[i] = __bfloat1622float2(y_vec.elts[i]);
x[i].x = silu(x[i].x) * y[i].x;
x[i].y = silu(x[i].y) * y[i].y;
x_vec.elts[i] = __float22bfloat162_rn(x[i]);
}
}
}
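// Scalar reference for the fused op above (illustrative note, not part of the
// kernel): each element pair is updated as x <- silu(x) * y, where
// silu(v) = v / (1 + exp(-v)); e.g. x = 2.0f, y = 3.0f gives
// (2.0f / (1.0f + expf(-2.0f))) * 3.0f ~= 5.28f.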
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
__global__ void
@@ -255,6 +282,7 @@ cvt_fp16_to_fp4(
uint32_t* SFout,
uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts,
int32_t* mask,
int n_experts,
bool low_latency) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
@@ -265,6 +293,11 @@ cvt_fp16_to_fp4(
// Input tensor row/col loops.
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
// TODO(kaixih@nvidia): For now, we assume the mask is used together with
// silu_and_mul. We may want more general mask behavior later. In the silu
// case, the input last dim doubles.
bool use_mask = mask != nullptr;
int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow;
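// Layout sketch (assumption following the TODO above): with a mask, each input
// row stores [x | y], i.e. numCols activation values followed by numCols
// multiplier values, so the per-row stride in PackedVec units is
// actualColsPerRow = 2 * colsPerRow; e.g. numCols = 4096 with 8 fp4 elements
// packed per uint32_t gives colsPerRow = 512 and a row stride of 1024.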
// Each global thread processes one element
for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {
@@ -272,13 +305,6 @@ cvt_fp16_to_fp4(
int rowIdx = globalIdx / colsPerRow;
int colIdx = globalIdx % colsPerRow;
int64_t inOffset = rowIdx * colsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// Find index within the experts using different strategies based on expert
// count
int rowIdx_in_expert = 0;
@@ -321,6 +347,23 @@ cvt_fp16_to_fp4(
}
}
// Early exit when using masks.
if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
continue;
}
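// The check above treats mask[expert_idx] as the number of valid rows for that
// expert: rows at or beyond it are skipped and their output is left untouched.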
int64_t inOffset = rowIdx * actualColsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
if (use_mask) {
PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
silu_and_mul(in_vec, in_vec_mul);
}
// Get the output tensor offset. The output always has colsPerRow packed
// uint32_t per row (8 fp4 values per uint32_t), so it matches inOffset only
// when no mask is used.
int64_t outOffset = rowIdx * colsPerRow + colIdx;
auto& out_pos = out[outOffset];
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
@@ -356,6 +399,7 @@ cvt_fp16_to_fp4(
uint32_t* SFout,
uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts,
int32_t* mask,
int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
@@ -383,6 +427,8 @@ cvt_fp16_to_fp4(
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
bool use_mask = mask != nullptr;
int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow;
// Each global thread processes one element
for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {
@@ -390,11 +436,6 @@ cvt_fp16_to_fp4(
int rowIdx = globalIdx / colsPerRow;
int colIdx = globalIdx % colsPerRow;
int64_t inOffset = rowIdx * colsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// Find expert using binary search for better performance with large m_topk
int rowIdx_in_expert = 0;
int expert_idx = 0;
@@ -419,6 +460,21 @@ cvt_fp16_to_fp4(
}
}
if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
continue;
}
int64_t inOffset = rowIdx * actualColsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
if (use_mask) {
PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
silu_and_mul(in_vec, in_vec_mul);
}
int64_t outOffset = rowIdx * colsPerRow + colIdx;
auto& out_pos = out[outOffset];
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
int factor = CVT_FP4_SF_VEC_SIZE * 4;
@@ -442,6 +498,7 @@ void quant_impl(
void* input_global_scale,
void* input_offset_by_experts,
void* output_scale_offset_by_experts,
void* mask,
int m_topk,
int k,
int n_experts,
@@ -478,6 +535,7 @@ void quant_impl(
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
reinterpret_cast<int32_t*>(mask),
n_experts);
} else {
cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
@@ -489,6 +547,7 @@ void quant_impl(
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
reinterpret_cast<int32_t*>(mask),
n_experts);
}
} else {
@@ -502,6 +561,7 @@ void quant_impl(
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
reinterpret_cast<int32_t*>(mask),
n_experts,
/* bool low_latency */ true);
} else {
@@ -514,6 +574,7 @@ void quant_impl(
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
reinterpret_cast<int32_t*>(mask),
n_experts,
/* bool low_latency */ true);
}
@@ -590,6 +651,7 @@ void scaled_fp4_experts_quant_sm100a(
input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
nullptr, // mask
m_topk,
k,
n_experts,
@@ -602,6 +664,92 @@ void scaled_fp4_experts_quant_sm100a(
input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
nullptr, // mask
m_topk,
k,
n_experts,
stream);
} else {
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
}
}
void silu_and_mul_scaled_fp4_experts_quant_sm100a(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts,
torch::Tensor const& mask) {
CHECK_INPUT(output, "output must be a CUDA tensor");
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
CHECK_INPUT(input, "input must be a CUDA tensor");
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor");
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor");
CHECK_INPUT(mask, "mask must be a CUDA tensor");
TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2);
TORCH_CHECK(input.dim() == 2);
TORCH_CHECK(input_global_scale.dim() == 1);
TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(mask.scalar_type() == INT);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
TORCH_CHECK(output.scalar_type() == UINT8);
TORCH_CHECK(output_scale.scalar_type() == INT);
const int BLOCK_SIZE = 16;
auto m_topk = input.size(0);
auto k_by_2 = input.size(1);
TORCH_CHECK(k_by_2 % 2 == 0, "input last dim must be a multiple of 2");
auto k = k_by_2 / 2;
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(mask.size(0) == n_experts);
TORCH_CHECK(output.size(0) == m_topk);
TORCH_CHECK(output.size(1) == k / 2);
int scales_k = k / BLOCK_SIZE;
// 4 is the swizzle alignment required by NVIDIA nvfp4.
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
// 4 means 4 fp8 values are packed into one int32
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
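// Worked example (illustrative numbers): k = 2048 gives scales_k = 128,
// padded_k = (128 + 3) / 4 * 4 = 128, i.e. output_scale.size(1) = 32 int32
// columns; k = 1040 gives scales_k = 65, padded_k = 68, i.e. 17 columns.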
auto in_dtype = input.dtype();
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
if (in_dtype == at::ScalarType::Half) {
quant_impl<half>(
output.data_ptr(),
output_scale.data_ptr(),
input.data_ptr(),
input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
mask.data_ptr(),
m_topk,
k,
n_experts,
stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
quant_impl<__nv_bfloat16>(
output.data_ptr(),
output_scale.data_ptr(),
input.data_ptr(),
input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
mask.data_ptr(),
m_topk,
k,
n_experts,
stream);
} else {
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
}
}

@@ -27,6 +27,15 @@ void scaled_fp4_experts_quant_sm100a(
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant_sm100a(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts,
torch::Tensor const& mask);
#endif
void scaled_fp4_quant(
@@ -50,3 +59,18 @@ void scaled_fp4_experts_quant(
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
}
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts,
torch::Tensor const& mask) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return silu_and_mul_scaled_fp4_experts_quant_sm100a(
output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts, mask);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
}
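
For reference, a minimal host-side sketch of invoking the new entry point (assuming the declaration above is visible and a CUDA build with ENABLE_NVFP4). The shapes follow the TORCH_CHECKs in silu_and_mul_scaled_fp4_experts_quant_sm100a; the concrete sizes, offsets, and mask values are made up for illustration.

#include <torch/torch.h>

// Hypothetical example; m_topk, k, n_experts and the offset/mask values are
// illustrative only.
void silu_and_mul_quant_example() {
  const int64_t m_topk = 8, k = 256, n_experts = 2;
  auto f16 = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
  auto i32 = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);

  // Input packs both GEMM halves per row: [x | y], so the last dim is 2 * k.
  auto input = torch::randn({m_topk, 2 * k}, f16);
  auto input_global_scale = torch::ones({n_experts}, f16.dtype(torch::kFloat32));
  // Row offsets per expert (n_experts + 1 entries) and matching scale offsets.
  auto input_offset_by_experts = torch::tensor({0, 4, 8}, i32);
  auto output_scale_offset_by_experts = torch::tensor({0, 4, 8}, i32);
  // Valid-row count per expert; rows beyond the mask are skipped by the kernel.
  auto mask = torch::tensor({3, 4}, i32);

  // Two nvfp4 values per uint8; four fp8 scale values per int32 (see checks).
  const int scales_k = k / 16;
  const int padded_k = (scales_k + 3) / 4 * 4;
  auto output = torch::empty({m_topk, k / 2}, f16.dtype(torch::kByte));
  // The row count of output_scale is not constrained by the checks above;
  // m_topk is assumed here.
  auto output_scale = torch::empty({m_topk, padded_k / 4}, i32);

  silu_and_mul_scaled_fp4_experts_quant(
      output, output_scale, input, input_global_scale,
      input_offset_by_experts, output_scale_offset_by_experts, mask);
}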