Sync from v0.13
This commit is contained in:
368
csrc/quantization/fp4/nvfp4_experts_quant.cu
Normal file
368
csrc/quantization/fp4/nvfp4_experts_quant.cu
Normal file
@@ -0,0 +1,368 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "nvfp4_utils.cuh"
|
||||
#include "launch_bounds_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
||||
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
||||
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
||||
uint32_t* input_offset_by_experts,
|
||||
uint32_t* output_scale_offset_by_experts, int n_experts,
|
||||
bool low_latency) {
|
||||
using PackedVec = PackedVec<Type>;
|
||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||
"Vec size is not matched.");
|
||||
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
||||
|
||||
// Each global thread processes one element
|
||||
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
|
||||
globalIdx += gridDim.x * blockDim.x) {
|
||||
// Calculate which row and column this global thread should process
|
||||
int rowIdx = globalIdx / colsPerRow;
|
||||
int colIdx = globalIdx % colsPerRow;
|
||||
|
||||
int64_t inOffset = rowIdx * colsPerRow + colIdx;
|
||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||
// Get the output tensor offset.
|
||||
// Same as inOffset because 8 elements are packed into one uint32_t.
|
||||
int64_t outOffset = inOffset;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
// Find index within the experts using different strategies based on expert
|
||||
// count
|
||||
int rowIdx_in_expert = 0;
|
||||
int expert_idx = 0;
|
||||
|
||||
if constexpr (SMALL_NUM_EXPERTS) {
|
||||
for (int i = 0; i < n_experts; i++) {
|
||||
uint32_t current_offset = __ldca(&input_offset_by_experts[i]);
|
||||
uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]);
|
||||
if (rowIdx >= current_offset && rowIdx < next_offset) {
|
||||
rowIdx_in_expert = rowIdx - current_offset;
|
||||
expert_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Load input offsets into registers first, then do the computation.
|
||||
// Local array size set to 17 because of register limit.
|
||||
uint32_t local_offsets[17];
|
||||
for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) {
|
||||
*reinterpret_cast<int4*>(local_offsets) =
|
||||
__ldca(reinterpret_cast<const int4*>(
|
||||
&input_offset_by_experts[chunk_start]));
|
||||
*reinterpret_cast<int4*>(local_offsets + 4) =
|
||||
__ldca(reinterpret_cast<const int4*>(
|
||||
&input_offset_by_experts[chunk_start + 4]));
|
||||
*reinterpret_cast<int4*>(local_offsets + 8) =
|
||||
__ldca(reinterpret_cast<const int4*>(
|
||||
&input_offset_by_experts[chunk_start + 8]));
|
||||
*reinterpret_cast<int4*>(local_offsets + 12) =
|
||||
__ldca(reinterpret_cast<const int4*>(
|
||||
&input_offset_by_experts[chunk_start + 12]));
|
||||
local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]);
|
||||
|
||||
// Check against the 16 loaded offsets
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {
|
||||
rowIdx_in_expert = rowIdx - local_offsets[i];
|
||||
expert_idx = chunk_start + i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get the global scaling factor, which will be applied to the SF.
|
||||
// Note SFScale is the same as next GEMM's alpha, which is
|
||||
// (448.f / (Alpha_A / 6.f)).
|
||||
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
||||
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
// The actual output_scales dim is computed from the padded numCols.
|
||||
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
||||
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
||||
uint32_t* SFout_in_expert =
|
||||
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
|
||||
|
||||
auto sf_out =
|
||||
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
||||
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
}
|
||||
}
|
||||
|
||||
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
|
||||
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
|
||||
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
|
||||
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out, uint32_t* SFout,
|
||||
uint32_t* input_offset_by_experts,
|
||||
uint32_t* output_scale_offset_by_experts, int n_experts) {
|
||||
using PackedVec = PackedVec<Type>;
|
||||
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
|
||||
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
|
||||
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
|
||||
"Vec size is not matched.");
|
||||
extern __shared__ uint32_t shared_input_offsets[];
|
||||
|
||||
// Load input offsets into shared memory.
|
||||
// If n_experts is larger than 4, use vectorized int4 to save instructions.
|
||||
// If n_experts is smaller than 4, read directly.
|
||||
if constexpr (SMALL_NUM_EXPERTS) {
|
||||
for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) {
|
||||
shared_input_offsets[i] = input_offset_by_experts[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {
|
||||
*reinterpret_cast<int4*>(&shared_input_offsets[i]) =
|
||||
*reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
|
||||
|
||||
// Each global thread processes one element
|
||||
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
|
||||
globalIdx += gridDim.x * blockDim.x) {
|
||||
// Calculate which row and column this global thread should process
|
||||
int rowIdx = globalIdx / colsPerRow;
|
||||
int colIdx = globalIdx % colsPerRow;
|
||||
|
||||
int64_t inOffset = rowIdx * colsPerRow + colIdx;
|
||||
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
|
||||
int64_t outOffset = inOffset;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
// Find expert using binary search for better performance with large m_topk
|
||||
int rowIdx_in_expert = 0;
|
||||
int expert_idx = 0;
|
||||
|
||||
// Binary search through experts using shared memory
|
||||
int left = 0, right = n_experts - 1;
|
||||
while (left <= right) {
|
||||
int mid = (left + right) / 2;
|
||||
// Get offsets: shared_input_offsets[i] corresponds to
|
||||
// input_offset_by_experts[i]
|
||||
uint32_t mid_offset = shared_input_offsets[mid];
|
||||
uint32_t next_offset = shared_input_offsets[mid + 1];
|
||||
|
||||
if (rowIdx >= mid_offset && rowIdx < next_offset) {
|
||||
rowIdx_in_expert = rowIdx - mid_offset;
|
||||
expert_idx = mid;
|
||||
break;
|
||||
} else if (rowIdx < mid_offset) {
|
||||
right = mid - 1;
|
||||
} else {
|
||||
left = mid + 1;
|
||||
}
|
||||
}
|
||||
|
||||
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
|
||||
|
||||
int factor = CVT_FP4_SF_VEC_SIZE * 4;
|
||||
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
|
||||
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
|
||||
uint32_t* SFout_in_expert =
|
||||
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
|
||||
|
||||
auto sf_out =
|
||||
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
|
||||
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void quant_impl(void* output, void* output_scale, void* input,
|
||||
void* input_global_scale, void* input_offset_by_experts,
|
||||
void* output_scale_offset_by_experts, int m_topk, int k,
|
||||
int n_experts, cudaStream_t stream) {
|
||||
// TODO: this multiProcessorCount should be cached.
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int multiProcessorCount;
|
||||
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount,
|
||||
device);
|
||||
|
||||
// Grid, Block size.
|
||||
// Each thread converts 8 values.
|
||||
int const workSizePerRow = k / ELTS_PER_THREAD;
|
||||
int const totalWorkSize = m_topk * workSizePerRow;
|
||||
dim3 block(std::min(workSizePerRow, 512));
|
||||
// Get number of blocks per SM
|
||||
int const numBlocksPerSM =
|
||||
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
||||
dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
|
||||
multiProcessorCount * numBlocksPerSM));
|
||||
while (grid.x <= multiProcessorCount && block.x > 64) {
|
||||
grid.x *= 2;
|
||||
block.x = (block.x + 1) / 2;
|
||||
}
|
||||
|
||||
int const blockRepeat =
|
||||
(totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
|
||||
if (blockRepeat > 1) {
|
||||
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
|
||||
if (n_experts >= 4) {
|
||||
cvt_fp16_to_fp4<T, false, false>
|
||||
<<<grid, block, shared_mem_size, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts);
|
||||
} else {
|
||||
cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts);
|
||||
}
|
||||
} else {
|
||||
if (n_experts >= 16) {
|
||||
cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts, /* bool low_latency */ true);
|
||||
} else {
|
||||
cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>(
|
||||
m_topk, k, reinterpret_cast<T*>(input),
|
||||
reinterpret_cast<float*>(input_global_scale),
|
||||
reinterpret_cast<uint32_t*>(output),
|
||||
reinterpret_cast<uint32_t*>(output_scale),
|
||||
reinterpret_cast<uint32_t*>(input_offset_by_experts),
|
||||
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
|
||||
n_experts, /* bool low_latency */ true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
/*Quantization entry for fp4 experts quantization*/
|
||||
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) \
|
||||
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
||||
#define CHECK_INPUT(x, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m);
|
||||
|
||||
constexpr auto HALF = at::ScalarType::Half;
|
||||
constexpr auto BF16 = at::ScalarType::BFloat16;
|
||||
constexpr auto FLOAT = at::ScalarType::Float;
|
||||
constexpr auto INT = at::ScalarType::Int;
|
||||
constexpr auto UINT8 = at::ScalarType::Byte;
|
||||
|
||||
void scaled_fp4_experts_quant_sm1xxa(
|
||||
torch::Tensor& output, torch::Tensor& output_scale,
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts) {
|
||||
CHECK_INPUT(output, "output must be a CUDA tensor");
|
||||
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
|
||||
CHECK_INPUT(input, "input must be a CUDA tensor");
|
||||
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
|
||||
CHECK_INPUT(input_offset_by_experts,
|
||||
"input_offset_by_experts must be a CUDA tensor");
|
||||
CHECK_INPUT(output_scale_offset_by_experts,
|
||||
"output_scale_offset_by_experts must be a CUDA tensor");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2);
|
||||
TORCH_CHECK(output_scale.dim() == 2);
|
||||
TORCH_CHECK(input.dim() == 2);
|
||||
TORCH_CHECK(input_global_scale.dim() == 1);
|
||||
TORCH_CHECK(input_offset_by_experts.dim() == 1);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
|
||||
|
||||
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
|
||||
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
|
||||
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
|
||||
// output is uint8 (two nvfp4 values are packed into one uint8)
|
||||
// output_scale is int32 (four fp8 values are packed into one int32)
|
||||
TORCH_CHECK(output.scalar_type() == UINT8);
|
||||
TORCH_CHECK(output_scale.scalar_type() == INT);
|
||||
|
||||
const int BLOCK_SIZE = 16;
|
||||
auto m_topk = input.size(0);
|
||||
auto k = input.size(1);
|
||||
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
|
||||
auto n_experts = input_global_scale.size(0);
|
||||
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
|
||||
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
|
||||
TORCH_CHECK(output.size(0) == m_topk);
|
||||
TORCH_CHECK(output.size(1) == k / 2);
|
||||
int scales_k = k / BLOCK_SIZE;
|
||||
// 4 means the swizzle requirement by nvidia nvfp4.
|
||||
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
|
||||
// 4 means 4 fp8 values are packed into one int32
|
||||
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(
|
||||
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
vllm::quant_impl<cuda_type>(
|
||||
output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
|
||||
input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
|
||||
output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
|
||||
stream);
|
||||
});
|
||||
}
|
||||
Reference in New Issue
Block a user