Sync from v0.13
This commit is contained in:
177
csrc/quantization/vectorization_utils.cuh
Normal file
177
csrc/quantization/vectorization_utils.cuh
Normal file
@@ -0,0 +1,177 @@
|
||||
#pragma once
|
||||
#include "vectorization.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Pack-wise adapter around an element-wise functor.
//
// Calling an instance with a destination pack and a source pack invokes
// `scalar_op(dst.val[k], src.val[k])` once for every k in [0, VEC_SIZE),
// in increasing order of k.
//
// Template parameters:
//   VEC_SIZE - number of elements held by one pack.
//   InT      - element type of the source pack.
//   OutT     - element type of the destination pack.
//   ScaOp    - type of the element-wise functor stored in `scalar_op`.
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
struct DefaultVecOp {
  // Element-wise functor; stored by value.
  ScaOp scalar_op;

  __device__ __forceinline__ void operator()(
      vec_n_t<OutT, VEC_SIZE>& dst, const vec_n_t<InT, VEC_SIZE>& src) const {
    auto& dst_lanes = dst.val;
    const auto& src_lanes = src.val;
    // The trip count is a compile-time constant, so ask for a full unroll.
#pragma unroll
    for (int lane = 0; lane < VEC_SIZE; ++lane) {
      scalar_op(dst_lanes[lane], src_lanes[lane]);
    }
  }
};
|
||||
|
||||
// Applies a conversion from `in[0, len)` to `out[0, len)`, using pack-wide
// (VEC_SIZE-element) loads/stores wherever the pointers allow it and falling
// back to element-wise processing elsewhere.
//
// Every loop below visits indices tid, tid + stride, tid + 2 * stride, ...
// so the caller decides how the work is split through `tid` and `stride`
// (presumably thread index and number of cooperating threads - confirm
// against callers).
//
// Phases:
//   fast path - `in` is WIDTH-aligned, `out` is suitably aligned for an output
//               pack and `len` is a multiple of VEC_SIZE: pack-wise only.
//   1. prefix - fewer than VEC_SIZE leading elements, handled with
//               `scalar_op`, up to the first WIDTH-aligned input address.
//   2. body   - whole packs handled with `vec_op`.
//   3. tail   - whatever is left, handled with `scalar_op`.
//
// Template parameters:
//   VEC_SIZE - elements per pack; must be a positive power of two.
//   InT/OutT - input / output element types.
//   VecOp    - callable as vec_op(vec_n_t<OutT,VEC_SIZE>&, const
//              vec_n_t<InT,VEC_SIZE>&).
//   ScaOp    - callable as scalar_op(OutT&, const InT&).
//
// Parameters:
//   in, out   - source and destination arrays of `len` elements.
//   len       - number of elements to process.
//   tid       - first index handled by the caller in each phase.
//   stride    - step between consecutive indices handled by the caller.
//   vec_op    - pack-wise operation.
//   scalar_op - element-wise operation; expected to produce the same result
//               per element as `vec_op`, since which one is applied to a given
//               element depends only on pointer alignment.
//
// Only the alignment of `in` decides where the pack-wise region starts. If
// `out` does not satisfy the alignment of an output pack at that point, the
// pack-wise region is skipped and all remaining elements go through
// `scalar_op`, rather than issuing a misaligned pack store.
template <int VEC_SIZE, typename InT, typename OutT, typename VecOp,
          typename ScaOp>
__device__ inline void vectorize_with_alignment(
    const InT* in, OutT* out, int len, int tid, int stride,
    VecOp&& vec_op,       // vec_n_t<InT,VEC_SIZE> -> vec_n_t<OutT,VEC_SIZE>
    ScaOp&& scalar_op) {  // InT -> OutT
  static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
                "VEC_SIZE must be a positive power-of-two");
  using vin_t = vec_n_t<InT, VEC_SIZE>;
  using vout_t = vec_n_t<OutT, VEC_SIZE>;

  constexpr int WIDTH = VEC_SIZE * sizeof(InT);  // bytes per input pack
  uintptr_t addr = reinterpret_cast<uintptr_t>(in);

  // fast path when the whole region is already aligned (input and output)
  bool in_aligned = (addr & (WIDTH - 1)) == 0;
  bool out_aligned =
      (reinterpret_cast<uintptr_t>(out) & (alignof(vout_t) - 1)) == 0;
  bool can_vec = in_aligned && out_aligned && ((len & (VEC_SIZE - 1)) == 0);
  if (can_vec) {
    int num_vec = len / VEC_SIZE;

    auto* v_in = reinterpret_cast<const vin_t*>(in);
    auto* v_out = reinterpret_cast<vout_t*>(out);

    for (int i = tid; i < num_vec; i += stride) {
      vout_t tmp;
      // Make a local copy of the entire pack
      vin_t src = v_in[i];  // <- encourages a single vector ld
      vec_op(tmp, src);
      v_out[i] = tmp;  // <- encourages a single vector st
    }
    return;
  }

  int misalignment_offset = addr & (WIDTH - 1);       // addr % WIDTH
  int alignment_bytes = WIDTH - misalignment_offset;  // WIDTH - (addr % WIDTH)
  int prefix_elems = alignment_bytes & (WIDTH - 1);   // maps WIDTH back to 0
  prefix_elems /= sizeof(InT);
  prefix_elems = min(prefix_elems, len);  // 0 <= prefix < VEC_SIZE

  // 1. scalar prefix: elements in front of the first WIDTH-aligned input
  //    address, where a pack-wide load would be unsafe
  for (int i = tid; i < prefix_elems; i += stride) {
    scalar_op(out[i], in[i]);
  }

  in += prefix_elems;
  out += prefix_elems;
  len -= prefix_elems;

  // `in` is now WIDTH-aligned (or len is 0). `out` was advanced by the same
  // element count, which does not by itself make it aligned for vout_t.
  // When it is not, use zero packs so that the tail loop below processes
  // every remaining element with scalar_op.
  bool out_vec_ok =
      (reinterpret_cast<uintptr_t>(out) & (alignof(vout_t) - 1)) == 0;
  int num_vec = out_vec_ok ? len / VEC_SIZE : 0;
  auto* v_in = reinterpret_cast<const vin_t*>(in);
  auto* v_out = reinterpret_cast<vout_t*>(out);

  // 2. vectorize the main part
  for (int i = tid; i < num_vec; i += stride) {
    vout_t tmp;
    // Make a local copy of the entire pack
    vin_t src = v_in[i];  // <- encourages a single vector ld
    vec_op(tmp, src);
    v_out[i] = tmp;  // <- encourages a single vector st
  }

  // 3. handle the tail
  int tail_start = num_vec * VEC_SIZE;
  for (int i = tid + tail_start; i < len; i += stride) {
    scalar_op(out[i], in[i]);
  }
}
|
||||
|
||||
// Convenience overload for callers that only have an element-wise functor.
//
// Builds a DefaultVecOp around a copy of `scalar_op` to serve as the
// pack-wise operation, then forwards everything to the overload that takes
// both a pack-wise and an element-wise operation.
//
// Parameters:
//   in, out   - source and destination arrays of `len` elements.
//   len       - number of elements to process.
//   tid       - first index handled by the caller.
//   stride    - step between consecutive indices handled by the caller.
//   scalar_op - element-wise operation, called as scalar_op(out_elem, in_elem).
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
__device__ __forceinline__ void vectorize_with_alignment(const InT* in,
                                                         OutT* out, int len,
                                                         int tid, int stride,
                                                         ScaOp&& scalar_op) {
  using ScalarFn = std::decay_t<ScaOp>;
  using PackFn = DefaultVecOp<VEC_SIZE, InT, OutT, ScalarFn>;

  // The pack functor holds its own copy of scalar_op; the original is
  // forwarded unchanged as the element-wise operation.
  vectorize_with_alignment<VEC_SIZE>(in, out, len, tid, stride,
                                     PackFn{scalar_op},
                                     std::forward<ScaOp>(scalar_op));
}
|
||||
|
||||
// Read-only pack-wise adapter around an element-wise functor.
//
// Calling an instance with a source pack invokes `scalar_op(src.val[k])`
// once for every k in [0, VEC_SIZE), in increasing order of k. Nothing is
// written by the adapter itself.
//
// Template parameters:
//   VEC_SIZE - number of elements held by one pack.
//   InT      - element type of the pack.
//   ScaOp    - type of the element-wise functor stored in `scalar_op`.
template <int VEC_SIZE, typename InT, typename ScaOp>
struct DefaultReadVecOp {
  // Element-wise functor; stored by value.
  ScaOp scalar_op;

  __device__ __forceinline__ void operator()(
      const vec_n_t<InT, VEC_SIZE>& src) const {
    const auto& src_lanes = src.val;
    // The trip count is a compile-time constant, so ask for a full unroll.
#pragma unroll
    for (int lane = 0; lane < VEC_SIZE; ++lane) {
      scalar_op(src_lanes[lane]);
    }
  }
};
|
||||
|
||||
// read-only version: iterate over the input with alignment guarantees
//
// Visits `in[0, len)`, handing whole packs of VEC_SIZE elements to `vec_op`
// wherever the input address is aligned to a pack and single elements to
// `scalar_op` elsewhere. Nothing is written by this function itself.
//
// Every loop below visits indices tid, tid + stride, tid + 2 * stride, ...
// so the caller decides how the work is split through `tid` and `stride`
// (presumably thread index and number of cooperating threads - confirm
// against callers).
//
// Phases:
//   fast path - `in` is WIDTH-aligned and `len` is a multiple of VEC_SIZE:
//               pack-wise only.
//   1. prefix - fewer than VEC_SIZE leading elements, up to the first
//               WIDTH-aligned address, via `scalar_op`.
//   2. body   - whole packs via `vec_op`.
//   3. tail   - remaining elements via `scalar_op`.
//
// Template parameters:
//   VEC_SIZE - elements per pack; must be a positive power of two.
//   InT      - input element type.
//   VecOp    - callable with a vec_n_t<InT, VEC_SIZE>.
//   ScaOp    - callable with a single InT.
//
// Parameters:
//   in        - array of `len` elements to read.
//   len       - number of elements to visit.
//   tid       - first index handled by the caller in each phase.
//   stride    - step between consecutive indices handled by the caller.
//   vec_op    - pack-wise operation.
//   scalar_op - element-wise operation.
template <int VEC_SIZE, typename InT, typename VecOp, typename ScaOp>
__device__ inline void vectorize_read_with_alignment(const InT* in, int len,
                                                     int tid, int stride,
                                                     VecOp&& vec_op,
                                                     ScaOp&& scalar_op) {
  static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
                "VEC_SIZE must be a positive power-of-two");
  using vin_t = vec_n_t<InT, VEC_SIZE>;

  constexpr int WIDTH = VEC_SIZE * sizeof(InT);  // bytes per pack
  uintptr_t addr = reinterpret_cast<uintptr_t>(in);

  // fast path when the whole region is already aligned
  bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
  if (can_vec) {
    int num_vec = len / VEC_SIZE;
    auto* v_in = reinterpret_cast<const vin_t*>(in);

    for (int i = tid; i < num_vec; i += stride) {
      // Make a local copy of the entire pack before handing it to vec_op
      vin_t tmp = v_in[i];
      vec_op(tmp);
    }
    return;
  }

  int misalignment_offset = addr & (WIDTH - 1);       // addr % WIDTH
  int alignment_bytes = WIDTH - misalignment_offset;  // WIDTH - (addr % WIDTH)
  int prefix_elems = alignment_bytes & (WIDTH - 1);   // maps WIDTH back to 0
  prefix_elems /= sizeof(InT);
  prefix_elems = min(prefix_elems, len);  // 0 <= prefix < VEC_SIZE

  // 1. handle the possibly unaligned prefix with scalar access.
  for (int i = tid; i < prefix_elems; i += stride) {
    scalar_op(in[i]);
  }

  in += prefix_elems;
  len -= prefix_elems;

  int num_vec = len / VEC_SIZE;
  auto* v_in = reinterpret_cast<const vin_t*>(in);

  // 2. vectorized traversal of the main aligned region.
  for (int i = tid; i < num_vec; i += stride) {
    // Same pattern as the fast path: copy the pack locally, then process it.
    vin_t tmp = v_in[i];
    vec_op(tmp);
  }

  // 3. handle remaining tail elements.
  int tail_start = num_vec * VEC_SIZE;
  for (int i = tid + tail_start; i < len; i += stride) {
    scalar_op(in[i]);
  }
}
|
||||
|
||||
// Convenience overload of the read-only traversal for callers that only have
// an element-wise functor.
//
// Builds a DefaultReadVecOp around a copy of `scalar_op` to serve as the
// pack-wise operation, then forwards everything to the overload that takes
// both a pack-wise and an element-wise operation.
//
// Parameters:
//   in        - array of `len` elements to read.
//   len       - number of elements to visit.
//   tid       - first index handled by the caller.
//   stride    - step between consecutive indices handled by the caller.
//   scalar_op - element-wise operation, called as scalar_op(in_elem).
template <int VEC_SIZE, typename InT, typename ScaOp>
__device__ __forceinline__ void vectorize_read_with_alignment(
    const InT* in, int len, int tid, int stride, ScaOp&& scalar_op) {
  using ScalarFn = std::decay_t<ScaOp>;
  using PackFn = DefaultReadVecOp<VEC_SIZE, InT, ScalarFn>;

  // The pack functor holds its own copy of scalar_op; the original is
  // forwarded unchanged as the element-wise operation.
  vectorize_read_with_alignment<VEC_SIZE>(in, len, tid, stride,
                                          PackFn{scalar_op},
                                          std::forward<ScaOp>(scalar_op));
}
|
||||
|
||||
} // namespace vllm
|
||||
Reference in New Issue
Block a user