Optimize prefill performance on cpu backend (#8750)
@@ -105,7 +105,19 @@ namespace {
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)

// [NB] Parallel Routines
//
// * at::parallel_for - covers most generic use cases; compiled against
//   OpenMP in the default torch release.
//
// * parallel_for - same functionality as above, but the payload
//   partitioning scheme can be chosen via balance211.
//
// * parallel_2d - parallelism across 2 dimensions, used in GEMM, etc.;
//   balances the payload across both dimensions.
//

// grain size for each thread
constexpr int GRAIN_SIZE = 1024;
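
// A minimal usage sketch (illustrative; `buf` is a hypothetical buffer):
// at::parallel_for splits [0, n) into per-thread chunks of at least
// GRAIN_SIZE elements.
//
//   at::parallel_for(0, n, GRAIN_SIZE, [&](int64_t begin, int64_t end) {
//     for (int64_t i = begin; i < end; ++i) {
//       buf[i] = 0.f;
//     }
//   });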

template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
@@ -113,6 +125,17 @@ inline T div_up(T x, T y) {
  return (x + y - 1) / y;
}

// at::get_thread_num() can only be used inside at::parallel_for(), since
// the thread pool is lazily initialized; anywhere else it always returns 0.
inline int get_thread_num() {
#if defined(_OPENMP)
  return omp_get_thread_num();
#else
  return 0;
#endif
}

// balance payload across each thread
template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
#if 0
@@ -153,6 +176,100 @@ inline void parallel_for(int n, const func_t& f) {
#endif
}

// for 1d parallel (m == 1), use `actual_nth`;
// for 2d parallel, round down to an even number of threads, e.g. 43 -> 42
inline int adjust_num_threads(int m) {
  int actual_nth = at::get_num_threads();
  if (m == 1) {
    return actual_nth;
  }
  return std::max(1, (actual_nth >> 1) * 2);
}

template <typename func_t>
inline void parallel_2d(int m, int n, const func_t& f) {
  // make sure we have an even number of threads
  int nth = adjust_num_threads(m);

  // [NOTE] thread blocking:
  //
  //   1) prefer a square block per thread
  //   2) use an even number of CPU cores
  //   3) use all `num_threads` cores
  //
  // given:
  //   TM * TN = T
  //   BM / TM = BN / TN
  // it follows that:
  //   TM = ((BM / BN) * T) ^ 0.5
  //
  float r = float(m) / n;
  int nth_m = std::ceil(std::sqrt(r * nth));
  int nth_n = 1;
  for (; nth_m > 0; --nth_m) {
    nth_n = nth / nth_m;
    if (nth_m * nth_n == nth) {
      break;
    }
  }
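
  // e.g. m = 128, n = 512, nth = 32: r = 0.25 and ceil(sqrt(8)) = 3, but 3
  // does not divide 32 evenly, so the loop settles on nth_m = 2, nth_n = 16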
#if defined(_OPENMP)
#pragma omp parallel num_threads(nth)
  {
    int ith = omp_get_thread_num();
    int ith_m = ith / nth_n;
    int ith_n = ith % nth_n;

    int thread_block_m = div_up(m, nth_m);
    int thread_block_n = div_up(n, nth_n);

    int begin_m = ith_m * thread_block_m;
    int end_m = std::min(m, begin_m + thread_block_m);
    int begin_n = ith_n * thread_block_n;
    int end_n = std::min(n, begin_n + thread_block_n);

    f(begin_m, end_m, begin_n, end_n);
  }
#else
  f(0, m, 0, n);
#endif
}

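// A minimal usage sketch (illustrative; `C`, `M` and `N` are hypothetical):
// each thread receives a private [begin_m, end_m) x [begin_n, end_n) tile,
// so no two threads touch the same element of C.
//
//   parallel_2d(M, N, [&](int begin_m, int end_m, int begin_n, int end_n) {
//     for (int i = begin_m; i < end_m; ++i) {
//       for (int j = begin_n; j < end_n; ++j) {
//         C[i * N + j] = 0.f;
//       }
//     }
//   });
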
// limit the max number of cache blocks
// when weights need to be pre-unpacked, e.g. fp8
#define MAX_CACHE_BLOCK_SIZE 4

template <typename T>
inline int get_cache_blocks(int chunk_size) {
  // assume a 2MB L2 cache and use 50% of it
  const int L2_size = 2048 * 1024 >> 1;
  return std::max(1, int(L2_size / (chunk_size * sizeof(T))));
}
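
// e.g. for at::BFloat16 (2 bytes per element) with chunk_size = 4096:
//   (1024 * 1024) / (4096 * 2) = 128 cache blocks
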
template <>
inline int get_cache_blocks<at::Float8_e4m3fn>(int chunk_size) {
  // fp8 uses bf16 as the accumulation type
  int cache_block_size = get_cache_blocks<at::BFloat16>(chunk_size);
  return std::min(MAX_CACHE_BLOCK_SIZE, cache_block_size);
}
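
// e.g. with chunk_size = 4096 the generic bf16 path gives 128 blocks,
// which the fp8 specialization clamps to MAX_CACHE_BLOCK_SIZE = 4
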
// 2d sequential loop over the range [mb0, mb1) x [nb0, nb1)
template <typename T, typename func_t>
inline void loop_2d(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1, int64_t chunk_size, const func_t& f) {
  // number of blocks that fit in L2 for the innermost loop
  int64_t cache_blocks_nb = get_cache_blocks<T>(chunk_size);

  // loop order: [NB / cache_blocks_nb, MB, cache_blocks_nb]
  // TODO: implement the reverse order, [MB / cache_blocks_mb, NB, cache_blocks_mb]
  for (int64_t nbb = nb0; nbb < nb1; nbb += cache_blocks_nb) {
    for (int64_t mb = mb0; mb < mb1; ++mb) {
      for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, nb1); ++nb) {
        f(mb, nb, nb - nbb);
      }
    }
  }
}
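
// A sketch of how loop_2d nests inside parallel_2d for a GEMM-style kernel
// (illustrative; `MB`, `NB`, `chunk_size` and `run_microkernel` are
// hypothetical names):
//
//   parallel_2d(MB, NB, [&](int mb0, int mb1, int nb0, int nb1) {
//     loop_2d<at::BFloat16>(mb0, mb1, nb0, nb1, chunk_size,
//         [&](int64_t mb, int64_t nb, int64_t slot) {
//       // `slot` (= nb - nbb) indexes per-thread scratch for the cache
//       // block that is currently hot in L2
//       run_microkernel(mb, nb, slot);
//     });
//   });
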
// data indexing for dimension collapse
template <typename T>
inline T data_index_init(T offset) {
  return offset;
}
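
// The variadic overload that typically accompanies the base case above,
// following ATen's standard data_index_init pattern (a sketch; the hunk
// shown here ends at the base case): peel off one (index, size) pair per
// recursion step, decomposing a flat offset into per-dimension indices.
//
//   template <typename T, typename... Args>
//   inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
//     offset = data_index_init(offset, std::forward<Args>(args)...);
//     x = offset % X;
//     return offset / X;
//   }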